Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
cec9e840
Commit
cec9e840
authored
Jul 25, 2022
by
ltqin
Browse files
change name form c_thread_buffer to in_thread_buffer
parent
c0252636
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
16 additions
and
16 deletions
+16
-16
example/01_gemm/gemm_xdl_fp16_flash_attention.cpp
example/01_gemm/gemm_xdl_fp16_flash_attention.cpp
+3
-3
include/ck/tensor_operation/gpu/block/blockwise_softmax_v1.hpp
...de/ck/tensor_operation/gpu/block/blockwise_softmax_v1.hpp
+13
-13
No files found.
example/01_gemm/gemm_xdl_fp16_flash_attention.cpp
View file @
cec9e840
...
@@ -51,7 +51,7 @@ using DeviceGemmInstance0 = ck::tensor_operation::device::DeviceGemmXdl
...
@@ -51,7 +51,7 @@ using DeviceGemmInstance0 = ck::tensor_operation::device::DeviceGemmXdl
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ALayout
,
BLayout
,
CLayout
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
64
,
32
,
32
,
4
,
8
,
32
,
32
,
1
,
1
,
S
<
2
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
2
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
;
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ALayout
,
BLayout
,
CLayout
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
64
,
16
,
16
,
4
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
;
// clang-format on
// clang-format on
using
DeviceGemmInstance1
=
ck
::
tensor_operation
::
device
::
DeviceGemm_Xdl_CShuffle
using
DeviceGemmInstance1
=
ck
::
tensor_operation
::
device
::
DeviceGemm_Xdl_CShuffle
...
@@ -92,8 +92,8 @@ int main(int argc, char* argv[])
...
@@ -92,8 +92,8 @@ int main(int argc, char* argv[])
bool
time_kernel
=
false
;
bool
time_kernel
=
false
;
// GEMM shape
// GEMM shape
ck
::
index_t
M
=
32
;
ck
::
index_t
M
=
16
;
ck
::
index_t
N
=
32
;
ck
::
index_t
N
=
16
;
ck
::
index_t
K
=
64
;
ck
::
index_t
K
=
64
;
ck
::
index_t
StrideA
=
K
;
ck
::
index_t
StrideA
=
K
;
...
...
include/ck/tensor_operation/gpu/block/blockwise_softmax_v1.hpp
View file @
cec9e840
...
@@ -28,7 +28,7 @@ struct BlockwiseSoftmax_V1
...
@@ -28,7 +28,7 @@ struct BlockwiseSoftmax_V1
static
constexpr
index_t
MThreadSliceSize
=
1
;
static
constexpr
index_t
MThreadSliceSize
=
1
;
static
constexpr
index_t
WaveSize
=
64
;
static
constexpr
index_t
WaveSize
=
64
;
constexpr
static
auto
c
_thread_desc
=
make_naive_tensor_descriptor_packed
(
constexpr
static
auto
in
_thread_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
Number
<
RegSizePerXdlops
>
{}));
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
Number
<
RegSizePerXdlops
>
{}));
using
ThreadReduceSrcDesc_M_K
=
decltype
(
using
ThreadReduceSrcDesc_M_K
=
decltype
(
...
@@ -72,10 +72,10 @@ struct BlockwiseSoftmax_V1
...
@@ -72,10 +72,10 @@ struct BlockwiseSoftmax_V1
false
,
// ignored
false
,
// ignored
detail
::
AccumulateWithNanIgnore
<
reduce
::
Add
,
AccDataType
>>
;
detail
::
AccumulateWithNanIgnore
<
reduce
::
Add
,
AccDataType
>>
;
template
<
typename
CThreadBuffer
>
template
<
typename
CThreadBuffer
>
__host__
__device__
static
void
Run
(
CThreadBuffer
&
c
_thread_buf
,
void
*
__restrict__
p_shared
)
__host__
__device__
static
void
Run
(
CThreadBuffer
&
in
_thread_buf
,
void
*
__restrict__
p_shared
)
{
{
// printf("
c
_thread_desc: {%d, %d, %d}",
c
_thread_desc.GetLength(I0).value,
// printf("
in
_thread_desc: {%d, %d, %d}",
in
_thread_desc.GetLength(I0).value,
//
c
_thread_desc.GetLength(I1).value,
c
_thread_desc.GetLength(I2).value);
//
in
_thread_desc.GetLength(I1).value,
in
_thread_desc.GetLength(I2).value);
auto
reduce_work_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
reduce_work_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
AccDataType
*>
(
p_shared
),
BlockSize
);
static_cast
<
AccDataType
*>
(
p_shared
),
BlockSize
);
...
@@ -89,8 +89,8 @@ struct BlockwiseSoftmax_V1
...
@@ -89,8 +89,8 @@ struct BlockwiseSoftmax_V1
// max value for one thread
// max value for one thread
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n
)
{
constexpr
index_t
c
_offset
=
c
_thread_desc
.
CalculateOffset
(
make_tuple
(
0
,
n
,
0
));
constexpr
index_t
in
_offset
=
in
_thread_desc
.
CalculateOffset
(
make_tuple
(
0
,
n
,
0
));
auto
&
xdlops_out
=
c
_thread_buf
.
GetVectorTypeReference
(
Number
<
c
_offset
>
{});
auto
&
xdlops_out
=
in
_thread_buf
.
GetVectorTypeReference
(
Number
<
in
_offset
>
{});
ThreadwiseMaxReduce
::
Reduce
(
xdlops_out
.
template
AsType
<
float
>(),
max_value_buf
);
ThreadwiseMaxReduce
::
Reduce
(
xdlops_out
.
template
AsType
<
float
>(),
max_value_buf
);
});
});
...
@@ -115,8 +115,8 @@ struct BlockwiseSoftmax_V1
...
@@ -115,8 +115,8 @@ struct BlockwiseSoftmax_V1
});
});
// calculate exp for elements
// calculate exp for elements
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n
)
{
constexpr
index_t
c
_offset
=
c
_thread_desc
.
CalculateOffset
(
make_tuple
(
0
,
n
,
0
));
constexpr
index_t
in
_offset
=
in
_thread_desc
.
CalculateOffset
(
make_tuple
(
0
,
n
,
0
));
auto
&
xdlops_out
=
c
_thread_buf
.
GetVectorTypeReference
(
Number
<
c
_offset
>
{});
auto
&
xdlops_out
=
in
_thread_buf
.
GetVectorTypeReference
(
Number
<
in
_offset
>
{});
static_for
<
0
,
RegSizePerXdlops
,
1
>
{}([
&
](
auto
iK
)
{
static_for
<
0
,
RegSizePerXdlops
,
1
>
{}([
&
](
auto
iK
)
{
xdlops_out
.
template
AsType
<
float
>()(
iK
)
=
xdlops_out
.
template
AsType
<
float
>()(
iK
)
=
...
@@ -125,8 +125,8 @@ struct BlockwiseSoftmax_V1
...
@@ -125,8 +125,8 @@ struct BlockwiseSoftmax_V1
});
});
// sum data
// sum data
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n
)
{
constexpr
index_t
c
_offset
=
c
_thread_desc
.
CalculateOffset
(
make_tuple
(
0
,
n
,
0
));
constexpr
index_t
in
_offset
=
in
_thread_desc
.
CalculateOffset
(
make_tuple
(
0
,
n
,
0
));
auto
&
xdlops_out
=
c
_thread_buf
.
GetVectorTypeReference
(
Number
<
c
_offset
>
{});
auto
&
xdlops_out
=
in
_thread_buf
.
GetVectorTypeReference
(
Number
<
in
_offset
>
{});
ThreadwiseSumReduce
::
Reduce
(
xdlops_out
.
template
AsType
<
float
>(),
accu_value_buf
);
ThreadwiseSumReduce
::
Reduce
(
xdlops_out
.
template
AsType
<
float
>(),
accu_value_buf
);
block_sync_lds
();
block_sync_lds
();
});
});
...
@@ -135,10 +135,10 @@ struct BlockwiseSoftmax_V1
...
@@ -135,10 +135,10 @@ struct BlockwiseSoftmax_V1
// change elements
// change elements
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n
)
{
constexpr
index_t
c
_offset
=
c
_thread_desc
.
CalculateOffset
(
make_tuple
(
0
,
n
,
0
));
constexpr
index_t
in
_offset
=
in
_thread_desc
.
CalculateOffset
(
make_tuple
(
0
,
n
,
0
));
auto
&
xdlops_out
=
c
_thread_buf
.
GetVectorTypeReference
(
Number
<
c
_offset
>
{});
auto
&
xdlops_out
=
in
_thread_buf
.
GetVectorTypeReference
(
Number
<
in
_offset
>
{});
static_for
<
0
,
c
_thread_desc
.
GetLength
(
I2
),
1
>
{}([
&
](
auto
iK
)
{
static_for
<
0
,
in
_thread_desc
.
GetLength
(
I2
),
1
>
{}([
&
](
auto
iK
)
{
xdlops_out
.
template
AsType
<
float
>()(
iK
)
=
xdlops_out
.
template
AsType
<
float
>()(
iK
)
=
xdlops_out
.
template
AsType
<
float
>()[
iK
]
/
accu_value_buf
(
I0
);
xdlops_out
.
template
AsType
<
float
>()[
iK
]
/
accu_value_buf
(
I0
);
});
});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment