gaoqiong / composable_kernel · Commits

Commit 73156add, authored Jul 25, 2022 by ltqin

    implement one block

Parent: 848ceeb3
Showing 2 changed files with 8 additions and 10 deletions:

  example/01_gemm/gemm_xdl_fp16_flash_attention.cpp               (+3, -3)
  include/ck/tensor_operation/gpu/block/blockwise_softmax_v1.hpp  (+5, -7)
example/01_gemm/gemm_xdl_fp16_flash_attention.cpp (view file @ 73156add)
@@ -51,7 +51,7 @@ using DeviceGemmInstance0 = ck::tensor_operation::device::DeviceGemmXdl
 //######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
 //######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
 //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-        < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>;
+        < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 128, 32, 32, 4, 8, 16, 16, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>;
 // clang-format on

 using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
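For orientation (an editorial note, not part of the commit): reading the arguments against the column header above, the instance goes from a 64-thread block computing a 16x16 C tile to a 128-thread block computing a 32x32 tile, with NXdlPerWave raised from 1 to 2 and the A/B transfer clusters widened from S<4, 16, 1> to S<4, 32, 1>. A minimal compile-time sketch, assuming the usual 64-lane AMD wavefront and the parameter meanings shown in the header comment, checks that the wave layout matches the thread count:

```cpp
// Hypothetical sanity check, not repo code: the new tile/wave parameters.
constexpr int BlockSize   = 128; // was 64
constexpr int MPerBlock   = 32;  // was 16
constexpr int NPerBlock   = 32;  // was 16
constexpr int MPerXDL     = 16;
constexpr int NPerXDL     = 16;
constexpr int MXdlPerWave = 1;
constexpr int NXdlPerWave = 2;   // was 1
constexpr int WaveSize    = 64;  // assumed AMD wavefront size

// Waves along M and N needed to cover the block tile:
constexpr int MWaves = MPerBlock / (MXdlPerWave * MPerXDL); // 2
constexpr int NWaves = NPerBlock / (NXdlPerWave * NPerXDL); // 1
static_assert(MWaves * NWaves == BlockSize / WaveSize,
              "wave grid must equal the number of waves in the block");
```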
@@ -92,8 +92,8 @@ int main(int argc, char* argv[])
     bool time_kernel = false;

     // GEMM shape
-    ck::index_t M = 16;
-    ck::index_t N = 16;
+    ck::index_t M = 32;
+    ck::index_t N = 32;
     ck::index_t K = 64;
     ck::index_t StrideA = K;
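Raising M and N from 16 to 32 makes the problem exactly one MPerBlock x NPerBlock tile of the new instance, which is presumably what the commit message's "one block" refers to. A small sketch, assuming the grid is sized by tiling C with MPerBlock x NPerBlock blocks:

```cpp
// Hypothetical estimate, not repo code: workgroups needed if each block
// computes one MPerBlock x NPerBlock tile of C.
#include <cstdio>

int main()
{
    constexpr int M = 32, N = 32;                 // new problem size
    constexpr int MPerBlock = 32, NPerBlock = 32; // new tile size
    constexpr int grid_size =
        ((M + MPerBlock - 1) / MPerBlock) * ((N + NPerBlock - 1) / NPerBlock);
    std::printf("grid_size = %d\n", grid_size);   // 1 -> "one block"
}
```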
include/ck/tensor_operation/gpu/block/blockwise_softmax_v1.hpp (view file @ 73156add)
@@ -26,16 +26,14 @@ struct BlockwiseSoftmax_V1
 struct BlockToMKMap_M0_K_M1Adapt
 {
-    using ThreadClusterLengths_M_K  = Sequence<MPerXDL, WaveSize / MPerXDL>;
-    using ThreadClusterArrangeOrder = Sequence<1, 0>;
-
     __host__ __device__ BlockToMKMap_M0_K_M1Adapt() = default;

     template <typename TopIdx>
     __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
     {
-        constexpr auto thread_cluster_desc =
-            make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
-        return thread_cluster_desc.CalculateBottomIndex(idx_top);
+        const auto index = idx_top[I0];
+        const auto m     = (index / WaveSize) * MPerXDL + index % MPerXDL;
+        const auto k     = (index % WaveSize) / MPerXDL;
+        return make_tuple(m, k);
     }
 };

 static constexpr auto I0 = Number<0>{};
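The removed mapping went through a cluster descriptor built from Sequence<MPerXDL, WaveSize / MPerXDL>, which only spans the MPerXDL rows one wave can own; the replacement computes (m, k) directly, offsetting each wave to its own band of MPerXDL rows. A standalone sketch of the new formula (the constants are assumptions matching the 128-thread example above):

```cpp
// Hypothetical demo, not repo code: the new wave-aware thread -> (m, k) map.
#include <cstdio>

int main()
{
    constexpr int WaveSize = 64; // assumed AMD wavefront size
    constexpr int MPerXDL  = 16;

    for (int index : {0, 15, 16, 63, 64, 79, 127})
    {
        const int m = (index / WaveSize) * MPerXDL + index % MPerXDL;
        const int k = (index % WaveSize) / MPerXDL;
        std::printf("thread %3d -> (m = %2d, k = %d)\n", index, m, k);
    }
    // Wave 0 (threads 0..63) lands on rows 0..15 with k in 0..3;
    // wave 1 (threads 64..127) lands on rows 16..31.
}
```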
@@ -59,7 +57,7 @@ struct BlockwiseSoftmax_V1
                                    false, // param ignored
                                    detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>;

-    using ThreadClusterLengths_M_K = Sequence<MPerXDL, WaveSize / MPerXDL>;
+    using ThreadClusterLengths_M_K = Sequence<MPerXDL * BlockSize / WaveSize, WaveSize / MPerXDL>;

     using BlockwiseMaxReduce = PartitionedBlockwiseReduction2<AccDataType,
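With the adapter now spreading waves along M, the reduction's thread cluster is widened the same way: its M extent scales by the number of waves (BlockSize / WaveSize). A one-line compile-time check under the same assumed constants:

```cpp
// Hypothetical check, not repo code: the widened cluster covers every thread.
constexpr int BlockSize = 128, WaveSize = 64, MPerXDL = 16;
constexpr int ClusterM = MPerXDL * BlockSize / WaveSize; // 32 rows
constexpr int ClusterK = WaveSize / MPerXDL;             // 4 lanes per row
static_assert(ClusterM * ClusterK == BlockSize,
              "ThreadClusterLengths_M_K must cover the whole block");
```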