Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
46233520
Commit
46233520
authored
Apr 19, 2023
by
ltqin
Browse files
add device creator
parent
ad45a6cd
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
113 additions
and
89 deletions
+113
-89
library/include/ck/library/tensor_operation_instance/add_device_operation_instance.hpp
...nsor_operation_instance/add_device_operation_instance.hpp
+2
-0
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/batched_gemm_softmax_gemm_permute.hpp
...oftmax_gemm_permute/batched_gemm_softmax_gemm_permute.hpp
+19
-19
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_bf16_gmk_gnk_gno_gmo_instance.hpp
...mm_permute_xdl_cshuffle_bf16_gmk_gnk_gno_gmo_instance.hpp
+46
-35
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_fp16_gmk_gnk_gno_gmo_instance.hpp
...mm_permute_xdl_cshuffle_fp16_gmk_gnk_gno_gmo_instance.hpp
+46
-35
No files found.
library/include/ck/library/tensor_operation_instance/add_device_operation_instance.hpp
View file @
46233520
...
...
@@ -29,6 +29,8 @@ void add_device_operation_instances(std::vector<std::unique_ptr<BaseOp>>& op_ins
});
}
template
<
typename
DeviceOp
,
typename
Tag
=
void
>
struct
DeviceOperationInstanceCreator
;
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
...
...
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/batched_gemm_softmax_gemm_permute.hpp
View file @
46233520
...
...
@@ -10,7 +10,6 @@
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_fp16_gmk_gnk_gno_gmo_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_bf16_gmk_gnk_gno_gmo_instance.hpp"
...
...
@@ -55,24 +54,25 @@ void add_device_instances(
C1DEElementwiseOperation
,
MaskingSpec
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
create_device_instances
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AElementwiseOperation
,
B0ElementwiseOperation
,
C0DEElementwiseOperation
,
B1ElementwiseOperation
,
C1DEElementwiseOperation
,
MaskingSpec
>
());
using
DeviceOp
=
DeviceBatchedGemmSoftmaxGemmPermute
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AElementwiseOperation
,
B0ElementwiseOperation
,
C0DEElementwiseOperation
,
B1ElementwiseOperation
,
C1DEElementwiseOperation
,
MaskingSpec
>
;
add_device_operation_instances
(
instances
,
DeviceOperationInstanceCreator
<
DeviceOp
>::
create_device_instances
());
}
}
// namespace instance
}
// namespace device
...
...
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_bf16_gmk_gnk_gno_gmo_instance.hpp
View file @
46233520
...
...
@@ -14,8 +14,6 @@ namespace tensor_operation {
namespace
device
{
namespace
instance
{
namespace
bf16_data
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
...
...
@@ -65,40 +63,53 @@ using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo
// clang-format on
>
;
}
// namespace bf16_data
template
<
index_t
NumDimG
,
index_t
NumDimM
,
index_t
NumDimN
,
index_t
NumDimK
,
index_t
NumDimO
,
typename
ADataType
,
typename
B0DataType
,
typename
B1DataType
,
typename
CDataType
,
typename
Acc0BiasDataType
,
typename
Acc1BiasDataType
,
typename
AElementwiseOperation
,
typename
B0ElementwiseOperation
,
typename
C0DEElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
C1DEElementwiseOperation
,
MaskingSpecialization
MaskingSpec
,
typename
enable_if
<
is_same
<
remove_cvref_t
<
ADataType
>,
ck
::
bhalf_t
>::
value
,
bool
>::
type
=
false
>
auto
create_device_instances
()
template
<
index_t
NumDimG
,
index_t
NumDimM
,
index_t
NumDimN
,
index_t
NumDimK
,
index_t
NumDimO
,
typename
Acc0BiasDataType
,
typename
Acc1BiasDataType
,
typename
AElementwiseOperation
,
typename
B0ElementwiseOperation
,
typename
C0DEElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
C1DEElementwiseOperation
,
MaskingSpecialization
MaskingSpec
>
struct
DeviceOperationInstanceCreator
<
DeviceBatchedGemmSoftmaxGemmPermute
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AElementwiseOperation
,
B0ElementwiseOperation
,
C0DEElementwiseOperation
,
B1ElementwiseOperation
,
C1DEElementwiseOperation
,
MaskingSpec
>>
{
return
bf16_data
::
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
F32
,
Acc0BiasDataType
,
C0DEElementwiseOperation
,
MaskingSpec
>
{};
}
static
auto
create_device_instances
()
{
return
bf16_data
::
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ck
::
bhalf_t
,
float
,
Acc0BiasDataType
,
C0DEElementwiseOperation
,
MaskingSpec
>
{};
}
};
}
// namespace instance
}
// namespace device
...
...
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_fp16_gmk_gnk_gno_gmo_instance.hpp
View file @
46233520
...
...
@@ -14,8 +14,6 @@ namespace tensor_operation {
namespace
device
{
namespace
instance
{
namespace
fp16_data
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
...
...
@@ -65,40 +63,53 @@ using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo
// clang-format on
>
;
}
// namespace fp16_data
template
<
index_t
NumDimG
,
index_t
NumDimM
,
index_t
NumDimN
,
index_t
NumDimK
,
index_t
NumDimO
,
typename
ADataType
,
typename
B0DataType
,
typename
B1DataType
,
typename
CDataType
,
typename
Acc0BiasDataType
,
typename
Acc1BiasDataType
,
typename
AElementwiseOperation
,
typename
B0ElementwiseOperation
,
typename
C0DEElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
C1DEElementwiseOperation
,
MaskingSpecialization
MaskingSpec
,
typename
enable_if
<
is_same
<
remove_cvref_t
<
ADataType
>,
ck
::
half_t
>::
value
,
bool
>::
type
=
false
>
auto
create_device_instances
()
template
<
index_t
NumDimG
,
index_t
NumDimM
,
index_t
NumDimN
,
index_t
NumDimK
,
index_t
NumDimO
,
typename
Acc0BiasDataType
,
typename
Acc1BiasDataType
,
typename
AElementwiseOperation
,
typename
B0ElementwiseOperation
,
typename
C0DEElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
C1DEElementwiseOperation
,
MaskingSpecialization
MaskingSpec
>
struct
DeviceOperationInstanceCreator
<
DeviceBatchedGemmSoftmaxGemmPermute
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AElementwiseOperation
,
B0ElementwiseOperation
,
C0DEElementwiseOperation
,
B1ElementwiseOperation
,
C1DEElementwiseOperation
,
MaskingSpec
>>
{
return
fp16_data
::
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
F32
,
Acc0BiasDataType
,
C0DEElementwiseOperation
,
MaskingSpec
>
{};
}
static
auto
create_device_instances
()
{
return
fp16_data
::
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ck
::
half_t
,
float
,
Acc0BiasDataType
,
C0DEElementwiseOperation
,
MaskingSpec
>
{};
}
};
}
// namespace instance
}
// namespace device
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment