Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
1cda3b80
Commit
1cda3b80
authored
Apr 21, 2023
by
ltqin
Browse files
add builder
parent
7b73260c
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
190 additions
and
49 deletions
+190
-49
library/include/ck/library/tensor_operation_instance/add_device_operation_instance.hpp
...nsor_operation_instance/add_device_operation_instance.hpp
+3
-0
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/batched_gemm_softmax_gemm_permute.hpp
...oftmax_gemm_permute/batched_gemm_softmax_gemm_permute.hpp
+26
-22
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
...cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
+40
-2
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
...xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
+39
-2
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
...cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
+39
-2
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
...xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
+39
-2
src_example/01_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute.cpp
...s_softmax_gemm_permute/gemm_bias_softmax_gemm_permute.cpp
+4
-19
No files found.
library/include/ck/library/tensor_operation_instance/add_device_operation_instance.hpp
View file @
1cda3b80
...
@@ -36,6 +36,9 @@ enum struct ArchitectureEnum
...
@@ -36,6 +36,9 @@ enum struct ArchitectureEnum
};
};
template
<
typename
DeviceOp
,
ArchitectureEnum
Arch
=
ArchitectureEnum
::
Xdl
>
template
<
typename
DeviceOp
,
ArchitectureEnum
Arch
=
ArchitectureEnum
::
Xdl
>
struct
DeviceOperationInstanceCreator
;
struct
DeviceOperationInstanceCreator
;
template
<
typename
DeviceOp
,
ArchitectureEnum
Arch
>
struct
DeviceOperationInstanceBuilder
;
}
// namespace instance
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/batched_gemm_softmax_gemm_permute.hpp
View file @
1cda3b80
...
@@ -35,25 +35,25 @@ template <index_t NumDimG,
...
@@ -35,25 +35,25 @@ template <index_t NumDimG,
typename
B1ElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
C1DEElementwiseOperation
,
typename
C1DEElementwiseOperation
,
MaskingSpecialization
MaskingSpec
,
MaskingSpecialization
MaskingSpec
,
ArchitectureEnum
Arch
=
ArchitectureEnum
::
Xdl
>
ArchitectureEnum
Arch
>
void
add_device_instances
(
struct
DeviceOperationInstanceBuilder
<
DeviceBatchedGemmSoftmaxGemmPermute
<
NumDimG
,
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
NumDim
G
,
NumDim
M
,
NumDim
M
,
NumDim
N
,
NumDim
N
,
NumDim
K
,
NumDim
K
,
NumDim
O
,
NumDimO
,
ADataType
,
A
DataType
,
B0
DataType
,
B0
DataType
,
B1
DataType
,
B1
DataType
,
C
DataType
,
C
DataType
,
Acc0Bias
DataType
,
Acc
0
BiasDataType
,
Acc
1
BiasDataType
,
Acc1BiasDataType
,
AElementwiseOperation
,
A
ElementwiseOperation
,
B0
ElementwiseOperation
,
B0
ElementwiseOperation
,
C0DE
ElementwiseOperation
,
C0DE
ElementwiseOperation
,
B1
ElementwiseOperation
,
B1
ElementwiseOperation
,
C1DE
ElementwiseOperation
,
C1DEElementwiseOperation
,
MaskingSpec
>
,
MaskingSpec
>>>&
instances
)
Arch
>
{
{
using
DeviceOp
=
DeviceBatchedGemmSoftmaxGemmPermute
<
NumDimG
,
using
DeviceOp
=
DeviceBatchedGemmSoftmaxGemmPermute
<
NumDimG
,
NumDimM
,
NumDimM
,
...
@@ -72,9 +72,13 @@ void add_device_instances(
...
@@ -72,9 +72,13 @@ void add_device_instances(
B1ElementwiseOperation
,
B1ElementwiseOperation
,
C1DEElementwiseOperation
,
C1DEElementwiseOperation
,
MaskingSpec
>
;
MaskingSpec
>
;
add_device_operation_instances
(
static
void
add_device_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>&
instances
)
instances
,
DeviceOperationInstanceCreator
<
DeviceOp
,
Arch
>::
create_device_instances
());
{
}
add_device_operation_instances
(
instances
,
DeviceOperationInstanceCreator
<
DeviceOp
,
Arch
>::
create_device_instances
());
}
};
}
// namespace instance
}
// namespace instance
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
View file @
1cda3b80
...
@@ -47,7 +47,27 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
...
@@ -47,7 +47,27 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
MaskingSpecialization
::
MaskOutUpperTriangle
>>>&
MaskingSpecialization
::
MaskOutUpperTriangle
>>>&
instances
)
instances
)
{
{
add_device_instances
(
instances
);
using
DeviceOp
=
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
1
,
1
,
BF16
,
BF16
,
BF16
,
BF16
,
ck
::
Tuple
<
BF16
>
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
ScaleAdd
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskOutUpperTriangle
>
;
DeviceOperationInstanceBuilder
<
DeviceOp
,
ArchitectureEnum
::
Xdl
>::
add_device_instances
(
instances
);
}
}
void
add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
void
add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
...
@@ -71,7 +91,25 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
...
@@ -71,7 +91,25 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
MaskingSpecialization
::
MaskDisabled
>>>&
MaskingSpecialization
::
MaskDisabled
>>>&
instances
)
instances
)
{
{
add_device_instances
(
instances
);
using
DeviceOp
=
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
1
,
1
,
BF16
,
BF16
,
BF16
,
BF16
,
ck
::
Tuple
<
BF16
>
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
ScaleAdd
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskDisabled
>
;
DeviceOperationInstanceBuilder
<
DeviceOp
,
ArchitectureEnum
::
Xdl
>::
add_device_instances
(
instances
);
}
}
}
// namespace instance
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
View file @
1cda3b80
...
@@ -47,7 +47,26 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
...
@@ -47,7 +47,26 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
MaskingSpecialization
::
MaskOutUpperTriangle
>>>&
MaskingSpecialization
::
MaskOutUpperTriangle
>>>&
instances
)
instances
)
{
{
add_device_instances
(
instances
);
using
DeviceOp
=
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
1
,
1
,
F16
,
F16
,
F16
,
F16
,
ck
::
Tuple
<
F16
>
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
ScaleAdd
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskOutUpperTriangle
>
;
DeviceOperationInstanceBuilder
<
DeviceOp
,
ArchitectureEnum
::
Xdl
>::
add_device_instances
(
instances
);
}
}
void
add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
void
add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
...
@@ -71,7 +90,25 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
...
@@ -71,7 +90,25 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
MaskingSpecialization
::
MaskDisabled
>>>&
MaskingSpecialization
::
MaskDisabled
>>>&
instances
)
instances
)
{
{
add_device_instances
(
instances
);
using
DeviceOp
=
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
1
,
1
,
F16
,
F16
,
F16
,
F16
,
ck
::
Tuple
<
F16
>
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
ScaleAdd
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskDisabled
>
;
DeviceOperationInstanceBuilder
<
DeviceOp
,
ArchitectureEnum
::
Xdl
>::
add_device_instances
(
instances
);
}
}
}
// namespace instance
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
View file @
1cda3b80
...
@@ -47,7 +47,26 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
...
@@ -47,7 +47,26 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
MaskingSpecialization
::
MaskOutUpperTriangle
>>>&
MaskingSpecialization
::
MaskOutUpperTriangle
>>>&
instances
)
instances
)
{
{
add_device_instances
(
instances
);
using
DeviceOp
=
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
1
,
1
,
BF16
,
BF16
,
BF16
,
BF16
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
Scale
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskOutUpperTriangle
>
;
DeviceOperationInstanceBuilder
<
DeviceOp
,
ArchitectureEnum
::
Xdl
>::
add_device_instances
(
instances
);
}
}
void
add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
void
add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
...
@@ -71,7 +90,25 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
...
@@ -71,7 +90,25 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
MaskingSpecialization
::
MaskDisabled
>>>&
MaskingSpecialization
::
MaskDisabled
>>>&
instances
)
instances
)
{
{
add_device_instances
(
instances
);
using
DeviceOp
=
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
1
,
1
,
BF16
,
BF16
,
BF16
,
BF16
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
Scale
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskDisabled
>
;
DeviceOperationInstanceBuilder
<
DeviceOp
,
ArchitectureEnum
::
Xdl
>::
add_device_instances
(
instances
);
}
}
}
// namespace instance
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
View file @
1cda3b80
...
@@ -47,7 +47,26 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
...
@@ -47,7 +47,26 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
MaskingSpecialization
::
MaskOutUpperTriangle
>>>&
MaskingSpecialization
::
MaskOutUpperTriangle
>>>&
instances
)
instances
)
{
{
add_device_instances
(
instances
);
using
DeviceOp
=
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
1
,
1
,
F16
,
F16
,
F16
,
F16
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
Scale
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskOutUpperTriangle
>
;
DeviceOperationInstanceBuilder
<
DeviceOp
,
ArchitectureEnum
::
Xdl
>::
add_device_instances
(
instances
);
}
}
void
add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
void
add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
vector
<
...
@@ -70,7 +89,25 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
...
@@ -70,7 +89,25 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_i
MaskingSpecialization
::
MaskDisabled
>>>&
MaskingSpecialization
::
MaskDisabled
>>>&
instances
)
instances
)
{
{
add_device_instances
(
instances
);
using
DeviceOp
=
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
1
,
1
,
F16
,
F16
,
F16
,
F16
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
Scale
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskDisabled
>
;
DeviceOperationInstanceBuilder
<
DeviceOp
,
ArchitectureEnum
::
Xdl
>::
add_device_instances
(
instances
);
}
}
}
// namespace instance
}
// namespace instance
...
...
src_example/01_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute.cpp
View file @
1cda3b80
...
@@ -135,25 +135,10 @@ int main()
...
@@ -135,25 +135,10 @@ int main()
// get device op instances
// get device op instances
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
ck
::
tensor_operation
::
device
::
instance
::
add_device_instances
<
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceBuilder
<
2
,
DeviceOp
,
1
,
ck
::
tensor_operation
::
device
::
instance
::
ArchitectureEnum
::
Xdl
>::
1
,
add_device_instances
(
op_ptrs
);
1
,
1
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
ck
::
Tuple
<
D00DataType
,
D01DataType
>
,
ck
::
Tuple
<>
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
MaskingSpec
,
ck
::
tensor_operation
::
device
::
instance
::
ArchitectureEnum
::
Xdl
>
(
op_ptrs
);
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment