Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
661b0454
Commit
661b0454
authored
Apr 13, 2023
by
ltqin
Browse files
change add device instances function name
parent
5595f635
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
209 additions
and
263 deletions
+209
-263
example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_nolib.cpp
...max_gemm_permute/gemm_bias_softmax_gemm_permute_nolib.cpp
+2
-2
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp
...ration_instance/gpu/batched_gemm_softmax_gemm_permute.hpp
+78
-91
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/batched_gemm_softmax_gemm_permute.hpp
...oftmax_gemm_permute/batched_gemm_softmax_gemm_permute.hpp
+37
-43
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_half_gmk_gnk_gno_gmo_instance.hpp
...mm_permute_xdl_cshuffle_half_gmk_gnk_gno_gmo_instance.hpp
+12
-31
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
...cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
+20
-24
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
...xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
+20
-24
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
...cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
+20
-24
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
...xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
+20
-24
No files found.
example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_nolib.cpp
View file @
661b0454
...
@@ -134,8 +134,8 @@ int main()
...
@@ -134,8 +134,8 @@ int main()
MaskingSpec
>
;
MaskingSpec
>
;
// get device op instances
// get device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
DeviceOp
>::
GetI
nstances
();
ck
::
tensor_operation
::
device
::
instance
::
add_device_i
nstances
(
op_ptrs
);
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
...
...
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp
View file @
661b0454
...
@@ -11,7 +11,7 @@
...
@@ -11,7 +11,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/
device_
batched_gemm_
multiple_d_
softmax_gemm_permute
_xdl_cshuffle_half_gmk_gnk_gno_gmo_instance
.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/batched_gemm_softmax_gemm_permute.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
@@ -19,29 +19,26 @@ namespace device {
...
@@ -19,29 +19,26 @@ namespace device {
namespace
instance
{
namespace
instance
{
extern
template
void
extern
template
void
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
add_device_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
std
::
vector
<
std
::
unique_ptr
<
2
,
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
F16
,
F16
,
F16
,
F16
,
F16
,
F16
,
F16
,
F16
,
ck
::
Tuple
<
F16
>,
ck
::
Tuple
<
F16
>,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ScaleAdd
,
ScaleAdd
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskOutUpperTriangle
>>>&
instances
);
MaskingSpecialization
::
MaskOutUpperTriangle
>>>&
instances
);
extern
template
void
extern
template
void
add_device_instances
(
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
...
@@ -63,29 +60,26 @@ add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_
...
@@ -63,29 +60,26 @@ add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_
instances
);
instances
);
extern
template
void
extern
template
void
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
add_device_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
std
::
vector
<
std
::
unique_ptr
<
2
,
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
BF16
,
BF16
,
BF16
,
BF16
,
BF16
,
BF16
,
BF16
,
BF16
,
ck
::
Tuple
<
BF16
>,
ck
::
Tuple
<
BF16
>,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ScaleAdd
,
ScaleAdd
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskOutUpperTriangle
>>>&
instances
);
MaskingSpecialization
::
MaskOutUpperTriangle
>>>&
instances
);
extern
template
void
extern
template
void
add_device_instances
(
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
...
@@ -107,29 +101,26 @@ add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_
...
@@ -107,29 +101,26 @@ add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_
instances
);
instances
);
extern
template
void
extern
template
void
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
add_device_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
std
::
vector
<
std
::
unique_ptr
<
2
,
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
F16
,
F16
,
F16
,
F16
,
F16
,
F16
,
F16
,
F16
,
ck
::
Tuple
<
>,
ck
::
Tuple
<
>,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
Scale
,
Scale
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskOutUpperTriangle
>>>&
instances
);
MaskingSpecialization
::
MaskOutUpperTriangle
>>>&
instances
);
extern
template
void
extern
template
void
add_device_instances
(
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
...
@@ -151,29 +142,26 @@ add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_
...
@@ -151,29 +142,26 @@ add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_
instances
);
instances
);
extern
template
void
extern
template
void
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
add_device_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
std
::
vector
<
std
::
unique_ptr
<
2
,
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
1
,
BF16
,
BF16
,
BF16
,
BF16
,
BF16
,
BF16
,
BF16
,
BF16
,
ck
::
Tuple
<
>,
ck
::
Tuple
<
>,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
Scale
,
Scale
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskOutUpperTriangle
>>>&
instances
);
MaskingSpecialization
::
MaskOutUpperTriangle
>>>&
instances
);
extern
template
void
extern
template
void
add_device_instances
(
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
...
@@ -249,8 +237,7 @@ struct DeviceOperationInstanceFactory<
...
@@ -249,8 +237,7 @@ struct DeviceOperationInstanceFactory<
static
auto
GetInstances
()
static
auto
GetInstances
()
{
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
add_device_instances
(
op_ptrs
);
op_ptrs
);
return
op_ptrs
;
return
op_ptrs
;
}
}
};
};
...
...
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/batched_gemm_softmax_gemm_permute.hpp
View file @
661b0454
...
@@ -35,50 +35,44 @@ template <index_t NumDimG,
...
@@ -35,50 +35,44 @@ template <index_t NumDimG,
typename
B1ElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
C1DEElementwiseOperation
,
typename
C1DEElementwiseOperation
,
MaskingSpecialization
MaskingSpec
>
MaskingSpecialization
MaskingSpec
>
struct
DeviceOperationInstanceFactory
<
void
add_device_instances
(
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute
<
NumDimG
,
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
NumDimG
,
NumDimM
,
NumDimM
,
NumDimN
,
NumDimN
,
NumDimK
,
NumDimK
,
NumDimO
,
NumDimO
,
ADataType
,
ADataType
,
B0DataType
,
B0DataType
,
B1DataType
,
B1DataType
,
CDataType
,
CDataType
,
Acc0BiasDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
Acc1BiasDataType
,
AElementwiseOperation
,
AElementwiseOperation
,
B0ElementwiseOperation
,
B0ElementwiseOperation
,
C0DEElementwiseOperation
,
C0DEElementwiseOperation
,
B1ElementwiseOperation
,
B1ElementwiseOperation
,
C1DEElementwiseOperation
,
C1DEElementwiseOperation
,
MaskingSpec
>>
MaskingSpec
>>
>&
instances
)
{
{
using
DeviceOp
=
DeviceBatchedGemmSoftmaxGemmPermute
<
NumDimG
,
add_device_operation_instances
(
instances
,
NumDimM
,
create_device_instances
<
NumDimG
,
NumDimN
,
NumDimM
,
NumDimK
,
NumDimN
,
NumDimO
,
NumDimK
,
ADataType
,
NumDimO
,
B0DataType
,
ADataType
,
B1DataType
,
B0DataType
,
CDataType
,
B1DataType
,
Acc0BiasDataType
,
CDataType
,
Acc1BiasDataType
,
Acc0BiasDataType
,
AElementwiseOperation
,
Acc1BiasDataType
,
B0ElementwiseOperation
,
AElementwiseOperation
,
C0DEElementwiseOperation
,
B0ElementwiseOperation
,
B1ElementwiseOperation
,
C0DEElementwiseOperation
,
C1DEElementwiseOperation
,
B1ElementwiseOperation
,
MaskingSpec
>
;
C1DEElementwiseOperation
,
static
auto
GetInstances
()
MaskingSpec
>
());
{
}
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
op_ptrs
);
return
op_ptrs
;
}
};
}
// namespace instance
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_half_gmk_gnk_gno_gmo_instance.hpp
View file @
661b0454
...
@@ -85,38 +85,19 @@ template <index_t NumDimG,
...
@@ -85,38 +85,19 @@ template <index_t NumDimG,
typename
enable_if
<
is_same
<
remove_cvref_t
<
ADataType
>,
ck
::
half_t
>::
value
||
typename
enable_if
<
is_same
<
remove_cvref_t
<
ADataType
>,
ck
::
half_t
>::
value
||
is_same
<
remove_cvref_t
<
ADataType
>
,
ck
::
bhalf_t
>::
value
,
is_same
<
remove_cvref_t
<
ADataType
>
,
ck
::
bhalf_t
>::
value
,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
void
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
auto
create_device_instances
()
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AElementwiseOperation
,
B0ElementwiseOperation
,
C0DEElementwiseOperation
,
B1ElementwiseOperation
,
C1DEElementwiseOperation
,
MaskingSpec
>>>&
instances
)
{
{
add_device_operation_instances
(
return
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
<
instances
,
NumDimG
,
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
<
NumDimM
,
NumDimG
,
NumDimN
,
NumDimM
,
NumDimK
,
NumDimN
,
NumDimO
,
NumDimK
,
ADataType
,
NumDimO
,
F32
,
ADataType
,
Acc0BiasDataType
,
F32
,
C0DEElementwiseOperation
,
Acc0BiasDataType
,
MaskingSpec
>
{};
C0DEElementwiseOperation
,
MaskingSpec
>
{});
}
}
}
// namespace instance
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
View file @
661b0454
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/
device_
batched_gemm_
multiple_d_
softmax_gemm_permute
_xdl_cshuffle_half_gmk_gnk_gno_gmo_instance
.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/batched_gemm_softmax_gemm_permute.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -26,30 +26,26 @@ using S = ck::Sequence<Is...>;
...
@@ -26,30 +26,26 @@ using S = ck::Sequence<Is...>;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ScaleAdd
=
ck
::
tensor_operation
::
element_wise
::
ScaleAdd
;
using
ScaleAdd
=
ck
::
tensor_operation
::
element_wise
::
ScaleAdd
;
template
void
template
void
add_device_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
2
,
std
::
vector
<
std
::
unique_ptr
<
1
,
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
1
,
1
,
1
,
1
,
BF16
,
1
,
BF16
,
BF16
,
BF16
,
BF16
,
BF16
,
BF16
,
ck
::
Tuple
<
BF16
>,
BF16
,
ck
::
Tuple
<>
,
ck
::
Tuple
<
BF16
>,
PassThrough
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
ScaleAdd
,
PassThrough
,
PassThrough
,
ScaleAdd
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskOutUpperTriangle
>>>&
instances
);
PassThrough
,
MaskingSpecialization
::
MaskOutUpperTriangle
>>>&
instances
);
template
void
template
void
add_device_instances
(
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
View file @
661b0454
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/
device_
batched_gemm_
multiple_d_
softmax_gemm_permute
_xdl_cshuffle_half_gmk_gnk_gno_gmo_instance
.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/batched_gemm_softmax_gemm_permute.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -26,30 +26,26 @@ using S = ck::Sequence<Is...>;
...
@@ -26,30 +26,26 @@ using S = ck::Sequence<Is...>;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ScaleAdd
=
ck
::
tensor_operation
::
element_wise
::
ScaleAdd
;
using
ScaleAdd
=
ck
::
tensor_operation
::
element_wise
::
ScaleAdd
;
template
void
template
void
add_device_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
2
,
std
::
vector
<
std
::
unique_ptr
<
1
,
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
1
,
1
,
1
,
1
,
F16
,
1
,
F16
,
F16
,
F16
,
F16
,
F16
,
F16
,
ck
::
Tuple
<
F16
>,
F16
,
ck
::
Tuple
<>
,
ck
::
Tuple
<
F16
>,
PassThrough
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
ScaleAdd
,
PassThrough
,
PassThrough
,
ScaleAdd
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskOutUpperTriangle
>>>&
instances
);
PassThrough
,
MaskingSpecialization
::
MaskOutUpperTriangle
>>>&
instances
);
template
void
template
void
add_device_instances
(
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
View file @
661b0454
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/
device_
batched_gemm_
multiple_d_
softmax_gemm_permute
_xdl_cshuffle_half_gmk_gnk_gno_gmo_instance
.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/batched_gemm_softmax_gemm_permute.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -26,30 +26,26 @@ using S = ck::Sequence<Is...>;
...
@@ -26,30 +26,26 @@ using S = ck::Sequence<Is...>;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Scale
=
ck
::
tensor_operation
::
element_wise
::
Scale
;
using
Scale
=
ck
::
tensor_operation
::
element_wise
::
Scale
;
template
void
template
void
add_device_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
2
,
std
::
vector
<
std
::
unique_ptr
<
1
,
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
1
,
1
,
1
,
1
,
BF16
,
1
,
BF16
,
BF16
,
BF16
,
BF16
,
BF16
,
BF16
,
ck
::
Tuple
<
>,
BF16
,
ck
::
Tuple
<>
,
ck
::
Tuple
<
>,
PassThrough
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
Scale
,
PassThrough
,
PassThrough
,
Scale
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskOutUpperTriangle
>>>&
instances
);
PassThrough
,
MaskingSpecialization
::
MaskOutUpperTriangle
>>>&
instances
);
template
void
template
void
add_device_instances
(
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
View file @
661b0454
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/
device_
batched_gemm_
multiple_d_
softmax_gemm_permute
_xdl_cshuffle_half_gmk_gnk_gno_gmo_instance
.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/batched_gemm_softmax_gemm_permute.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -26,30 +26,26 @@ using S = ck::Sequence<Is...>;
...
@@ -26,30 +26,26 @@ using S = ck::Sequence<Is...>;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Scale
=
ck
::
tensor_operation
::
element_wise
::
Scale
;
using
Scale
=
ck
::
tensor_operation
::
element_wise
::
Scale
;
template
void
template
void
add_device_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
2
,
std
::
vector
<
std
::
unique_ptr
<
1
,
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
1
,
1
,
1
,
1
,
F16
,
1
,
F16
,
F16
,
F16
,
F16
,
F16
,
F16
,
ck
::
Tuple
<
>,
F16
,
ck
::
Tuple
<>
,
ck
::
Tuple
<
>,
PassThrough
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
Scale
,
PassThrough
,
PassThrough
,
Scale
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskOutUpperTriangle
>>>&
instances
);
PassThrough
,
MaskingSpecialization
::
MaskOutUpperTriangle
>>>&
instances
);
template
void
template
void
add_device_instances
(
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment