Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
05fc2f8e
Commit
05fc2f8e
authored
Mar 23, 2023
by
ltqin
Browse files
add sourecode interface to factory
parent
cf33526e
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
76 additions
and
100 deletions
+76
-100
client_example/08_fused_attention/fused_attention_bias_mask.cpp
..._example/08_fused_attention/fused_attention_bias_mask.cpp
+2
-2
client_example/08_fused_attention/fused_attention_no_lib.cpp
client_example/08_fused_attention/fused_attention_no_lib.cpp
+69
-57
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instance.hpp
...ax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instance.hpp
+4
-1
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute_general.hpp
...nstance/gpu/batched_gemm_softmax_gemm_permute_general.hpp
+1
-40
No files found.
client_example/08_fused_attention/fused_attention_bias_mask.cpp
View file @
05fc2f8e
...
@@ -138,7 +138,7 @@ int main(int argc, char* argv[])
...
@@ -138,7 +138,7 @@ int main(int argc, char* argv[])
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
2
>
{
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
2
>
{
d00_gs_ms_ns_lengths
,
d01_gs_ms_ns_lengths
},
// acc0_biases_gs_ms_ns_lengths
d00_gs_ms_ns_lengths
,
d01_gs_ms_ns_lengths
},
// acc0_biases_gs_ms_ns_lengths
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
2
>
{
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
2
>
{
d0
1
_gs_ms_ns_strides
,
d01_gs_ms_ns_strides
},
// acc0_biases_gs_ms_ns_strides
d0
0
_gs_ms_ns_strides
,
d01_gs_ms_ns_strides
},
// acc0_biases_gs_ms_ns_strides
{},
// acc1_biases_gs_ms_os_lengths
{},
// acc1_biases_gs_ms_os_lengths
{},
// acc1_biases_gs_ms_os_strides
{},
// acc1_biases_gs_ms_os_strides
AElementOp
{},
AElementOp
{},
...
@@ -210,7 +210,7 @@ int main(int argc, char* argv[])
...
@@ -210,7 +210,7 @@ int main(int argc, char* argv[])
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
2
>
{
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
2
>
{
d00_gs_ms_ns_lengths
,
d01_gs_ms_ns_lengths
},
// acc0_biases_gs_ms_ns_lengths
d00_gs_ms_ns_lengths
,
d01_gs_ms_ns_lengths
},
// acc0_biases_gs_ms_ns_lengths
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
2
>
{
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
2
>
{
d0
1
_gs_ms_ns_strides
,
d01_gs_ms_ns_strides
},
// acc0_biases_gs_ms_ns_strides
d0
0
_gs_ms_ns_strides
,
d01_gs_ms_ns_strides
},
// acc0_biases_gs_ms_ns_strides
{},
// acc1_biases_gs_ms_os_lengths
{},
// acc1_biases_gs_ms_os_lengths
{},
// acc1_biases_gs_ms_os_strides
{},
// acc1_biases_gs_ms_os_strides
AElementOp
{},
AElementOp
{},
...
...
client_example/08_fused_attention/fused_attention_no_lib.cpp
View file @
05fc2f8e
...
@@ -5,14 +5,14 @@
...
@@ -5,14 +5,14 @@
#include <vector>
#include <vector>
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute
/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instance
.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute
_general
.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
B0ElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
B0ElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Acc0ElementOp
=
ck
::
tensor_operation
::
element_wise
::
Scale
;
using
Acc0ElementOp
=
ck
::
tensor_operation
::
element_wise
::
Scale
Mask
;
using
B1ElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
B1ElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
...
@@ -23,6 +23,7 @@ using ADataType = ck::half_t;
...
@@ -23,6 +23,7 @@ using ADataType = ck::half_t;
using
B0DataType
=
ck
::
half_t
;
using
B0DataType
=
ck
::
half_t
;
using
B1DataType
=
ck
::
half_t
;
using
B1DataType
=
ck
::
half_t
;
using
CDataType
=
ck
::
half_t
;
using
CDataType
=
ck
::
half_t
;
using
D00DataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
using
AccDataType
=
float
;
struct
SimpleDeviceMem
struct
SimpleDeviceMem
...
@@ -41,7 +42,7 @@ struct SimpleDeviceMem
...
@@ -41,7 +42,7 @@ struct SimpleDeviceMem
void
*
p_mem_
;
void
*
p_mem_
;
};
};
int
main
()
int
main
(
int
argc
,
char
*
argv
[]
)
{
{
int
G0
=
48
;
int
G0
=
48
;
int
G1
=
16
;
int
G1
=
16
;
...
@@ -66,8 +67,13 @@ int main()
...
@@ -66,8 +67,13 @@ int main()
std
::
vector
<
ck
::
index_t
>
c_gs_ms_os_lengths
{
G0
,
G1
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
c_gs_ms_os_lengths
{
G0
,
G1
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
c_gs_ms_os_strides
{
M
*
G1
*
O
,
O
,
G1
*
O
,
1
};
std
::
vector
<
ck
::
index_t
>
c_gs_ms_os_strides
{
M
*
G1
*
O
,
O
,
G1
*
O
,
1
};
// D00 layout [G0, M, G1, N]
std
::
vector
<
ck
::
index_t
>
d00_gs_ms_ns_lengths
{
G0
,
G1
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
d00_gs_ms_ns_strides
{
M
*
G1
*
N
,
N
,
G1
*
N
,
1
};
SimpleDeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
G0
*
G1
*
M
*
K
);
SimpleDeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
G0
*
G1
*
M
*
K
);
SimpleDeviceMem
b0_device_buf
(
sizeof
(
B0DataType
)
*
G0
*
G1
*
N
*
K
);
SimpleDeviceMem
b0_device_buf
(
sizeof
(
B0DataType
)
*
G0
*
G1
*
N
*
K
);
SimpleDeviceMem
d00_device_buf
(
sizeof
(
D00DataType
)
*
G0
*
G1
*
M
*
N
);
SimpleDeviceMem
b1_device_buf
(
sizeof
(
B1DataType
)
*
G0
*
G1
*
O
*
N
);
SimpleDeviceMem
b1_device_buf
(
sizeof
(
B1DataType
)
*
G0
*
G1
*
O
*
N
);
SimpleDeviceMem
c_device_buf
(
sizeof
(
CDataType
)
*
G0
*
G1
*
M
*
O
);
SimpleDeviceMem
c_device_buf
(
sizeof
(
CDataType
)
*
G0
*
G1
*
M
*
O
);
...
@@ -81,7 +87,7 @@ int main()
...
@@ -81,7 +87,7 @@ int main()
B0DataType
,
B0DataType
,
B1DataType
,
B1DataType
,
CDataType
,
CDataType
,
ck
::
Tuple
<>
,
ck
::
Tuple
<
D00DataType
>
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
AElementOp
,
AElementOp
,
B0ElementOp
,
B0ElementOp
,
...
@@ -89,11 +95,10 @@ int main()
...
@@ -89,11 +95,10 @@ int main()
B1ElementOp
,
B1ElementOp
,
CElementOp
,
CElementOp
,
MaskingSpec
>
;
MaskingSpec
>
;
// get device op instances
// get device op instances
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
instance
::
DeviceOp
>::
GetInstances
();
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
op_ptrs
);
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
...
@@ -106,14 +111,15 @@ int main()
...
@@ -106,14 +111,15 @@ int main()
// profile device op instances
// profile device op instances
std
::
cout
<<
"Run all instances and do timing"
<<
std
::
endl
;
std
::
cout
<<
"Run all instances and do timing"
<<
std
::
endl
;
for
(
size_
t
i
=
0
;
i
<
op_ptrs
.
size
();
++
i
)
for
(
in
t
i
=
0
;
i
<
op_ptrs
.
size
();
++
i
)
{
{
auto
&
op_ptr
=
op_ptrs
[
i
];
auto
&
op_ptr
=
op_ptrs
[
i
];
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
a_device_buf
.
GetDeviceBuffer
(),
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
a_device_buf
.
GetDeviceBuffer
(),
b0_device_buf
.
GetDeviceBuffer
(),
b0_device_buf
.
GetDeviceBuffer
(),
b1_device_buf
.
GetDeviceBuffer
(),
b1_device_buf
.
GetDeviceBuffer
(),
c_device_buf
.
GetDeviceBuffer
(),
c_device_buf
.
GetDeviceBuffer
(),
{
},
// p_acc0_biases
std
::
array
<
void
*
,
1
>
{
d00_device_buf
.
GetDeviceBuffer
()
},
// p_acc0_biases
{},
// p_acc1_biases
{},
// p_acc1_biases
a_gs_ms_ks_lengths
,
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
,
a_gs_ms_ks_strides
,
...
@@ -123,13 +129,15 @@ int main()
...
@@ -123,13 +129,15 @@ int main()
b1_gs_os_ns_strides
,
b1_gs_os_ns_strides
,
c_gs_ms_os_lengths
,
c_gs_ms_os_lengths
,
c_gs_ms_os_strides
,
c_gs_ms_os_strides
,
{},
// acc0_biases_gs_ms_ns_lengths
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
{},
// acc0_biases_gs_ms_ns_strides
d00_gs_ms_ns_lengths
},
// acc0_biases_gs_ms_ns_lengths
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d00_gs_ms_ns_strides
},
// acc0_biases_gs_ms_ns_strides
{},
// acc1_biases_gs_ms_os_lengths
{},
// acc1_biases_gs_ms_os_lengths
{},
// acc1_biases_gs_ms_os_strides
{},
// acc1_biases_gs_ms_os_strides
AElementOp
{},
AElementOp
{},
B0ElementOp
{},
B0ElementOp
{},
Acc0ElementOp
{
1
/
sqrtf
(
K
)},
Acc0ElementOp
{
1
/
sqrtf
(
K
)
,
0.1
},
B1ElementOp
{},
B1ElementOp
{},
CElementOp
{});
CElementOp
{});
...
@@ -143,7 +151,8 @@ int main()
...
@@ -143,7 +151,8 @@ int main()
std
::
size_t
flop
=
(
size_t
(
M
)
*
N
*
K
*
2
+
size_t
(
M
)
*
N
*
O
*
2
)
*
G0
*
G1
;
std
::
size_t
flop
=
(
size_t
(
M
)
*
N
*
K
*
2
+
size_t
(
M
)
*
N
*
O
*
2
)
*
G0
*
G1
;
std
::
size_t
num_btype
=
(
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
B0DataType
)
*
K
*
N
+
std
::
size_t
num_btype
=
(
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
B0DataType
)
*
K
*
N
+
sizeof
(
B1DataType
)
*
N
*
O
+
sizeof
(
CDataType
)
*
M
*
O
)
*
sizeof
(
B1DataType
)
*
N
*
O
+
sizeof
(
CDataType
)
*
M
*
O
+
sizeof
(
D00DataType
)
*
M
*
N
)
*
G0
*
G1
;
G0
*
G1
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
...
@@ -176,11 +185,12 @@ int main()
...
@@ -176,11 +185,12 @@ int main()
auto
&
op_ptr
=
op_ptrs
[
best_op_id
];
auto
&
op_ptr
=
op_ptrs
[
best_op_id
];
std
::
cout
<<
"Run the best instance without timing: "
<<
op_ptr
->
GetTypeString
()
std
::
cout
<<
"Run the best instance without timing: "
<<
op_ptr
->
GetTypeString
()
<<
std
::
endl
;
<<
std
::
endl
;
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
a_device_buf
.
GetDeviceBuffer
(),
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
a_device_buf
.
GetDeviceBuffer
(),
b0_device_buf
.
GetDeviceBuffer
(),
b0_device_buf
.
GetDeviceBuffer
(),
b1_device_buf
.
GetDeviceBuffer
(),
b1_device_buf
.
GetDeviceBuffer
(),
c_device_buf
.
GetDeviceBuffer
(),
c_device_buf
.
GetDeviceBuffer
(),
{
},
// p_acc0_biases
std
::
array
<
void
*
,
1
>
{
d00_device_buf
.
GetDeviceBuffer
()
},
// p_acc0_biases
{},
// p_acc1_biases
{},
// p_acc1_biases
a_gs_ms_ks_lengths
,
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
,
a_gs_ms_ks_strides
,
...
@@ -190,13 +200,15 @@ int main()
...
@@ -190,13 +200,15 @@ int main()
b1_gs_os_ns_strides
,
b1_gs_os_ns_strides
,
c_gs_ms_os_lengths
,
c_gs_ms_os_lengths
,
c_gs_ms_os_strides
,
c_gs_ms_os_strides
,
{},
// acc0_biases_gs_ms_ns_lengths
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
{},
// acc0_biases_gs_ms_ns_strides
d00_gs_ms_ns_lengths
},
// acc0_biases_gs_ms_ns_lengths
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d00_gs_ms_ns_strides
},
// acc0_biases_gs_ms_ns_strides
{},
// acc1_biases_gs_ms_os_lengths
{},
// acc1_biases_gs_ms_os_lengths
{},
// acc1_biases_gs_ms_os_strides
{},
// acc1_biases_gs_ms_os_strides
AElementOp
{},
AElementOp
{},
B0ElementOp
{},
B0ElementOp
{},
Acc0ElementOp
{
1
/
sqrtf
(
K
)},
Acc0ElementOp
{
1
/
sqrtf
(
K
)
,
0.1
},
B1ElementOp
{},
B1ElementOp
{},
CElementOp
{});
CElementOp
{});
...
...
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instance.hpp
View file @
05fc2f8e
...
@@ -83,7 +83,10 @@ template <index_t NumDimG,
...
@@ -83,7 +83,10 @@ template <index_t NumDimG,
typename
C0DEElementwiseOperation
,
typename
C0DEElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
C1DEElementwiseOperation
,
typename
C1DEElementwiseOperation
,
MaskingSpecialization
MaskingSpec
>
MaskingSpecialization
MaskingSpec
,
typename
enable_if
<
is_same
<
remove_cvref_t
<
ADataType
>,
ck
::
half_t
>::
value
||
is_same
<
remove_cvref_t
<
ADataType
>
,
ck
::
bhalf_t
>::
value
,
bool
>::
type
=
false
>
void
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
void
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
NumDimG
,
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
NumDimG
,
NumDimM
,
NumDimM
,
...
...
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute_general.hpp
View file @
05fc2f8e
...
@@ -11,52 +11,13 @@
...
@@ -11,52 +11,13 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instance.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
instance
{
namespace
instance
{
void
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
1
,
1
,
F16
,
F16
,
F16
,
F16
,
ck
::
Tuple
<
F16
>
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
element_wise
::
ScaleMask
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskOutUpperTriangle
>>>&
instances
);
void
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
1
,
1
,
F16
,
F16
,
F16
,
F16
,
ck
::
Tuple
<
F16
>
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
element_wise
::
ScaleMask
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskDisabled
>>>&
instances
);
void
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
void
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment