Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
e1e01d8f
Commit
e1e01d8f
authored
Jul 21, 2022
by
ltqin
Browse files
add blockwsie softmax v1
parent
480d6219
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
241 additions
and
93 deletions
+241
-93
example/01_gemm/gemm_xdl_fp16_flash_attention.cpp
example/01_gemm/gemm_xdl_fp16_flash_attention.cpp
+3
-3
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
...e/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
+4
-1
include/ck/tensor_operation/gpu/block/blockwise_softmax_v1.hpp
...de/ck/tensor_operation/gpu/block/blockwise_softmax_v1.hpp
+150
-0
include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp
...sor_operation/gpu/block/reduction_functions_blockwise.hpp
+72
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
+12
-89
No files found.
example/01_gemm/gemm_xdl_fp16_flash_attention.cpp
View file @
e1e01d8f
...
@@ -51,7 +51,7 @@ using DeviceGemmInstance0 = ck::tensor_operation::device::DeviceGemmXdl
...
@@ -51,7 +51,7 @@ using DeviceGemmInstance0 = ck::tensor_operation::device::DeviceGemmXdl
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ALayout
,
BLayout
,
CLayout
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
64
,
32
,
32
,
4
,
8
,
32
,
32
,
1
,
1
,
S
<
2
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
2
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
;
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ALayout
,
BLayout
,
CLayout
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
64
,
16
,
16
,
4
,
8
,
16
,
16
,
1
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
7
,
1
>
;
// clang-format on
// clang-format on
using
DeviceGemmInstance1
=
ck
::
tensor_operation
::
device
::
DeviceGemm_Xdl_CShuffle
using
DeviceGemmInstance1
=
ck
::
tensor_operation
::
device
::
DeviceGemm_Xdl_CShuffle
...
@@ -92,8 +92,8 @@ int main(int argc, char* argv[])
...
@@ -92,8 +92,8 @@ int main(int argc, char* argv[])
bool
time_kernel
=
false
;
bool
time_kernel
=
false
;
// GEMM shape
// GEMM shape
ck
::
index_t
M
=
32
;
ck
::
index_t
M
=
16
;
ck
::
index_t
N
=
32
;
ck
::
index_t
N
=
16
;
ck
::
index_t
K
=
64
;
ck
::
index_t
K
=
64
;
ck
::
index_t
StrideA
=
K
;
ck
::
index_t
StrideA
=
K
;
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
View file @
e1e01d8f
...
@@ -263,7 +263,10 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -263,7 +263,10 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
make_tuple
(
Sequence
<
3
>
{},
Sequence
<
0
,
1
,
2
>
{}));
make_tuple
(
Sequence
<
3
>
{},
Sequence
<
0
,
1
,
2
>
{}));
}
}
__host__
__device__
static
constexpr
auto
GetCThreadDesc
()
{
return
c_thread_desc_
;
}
__host__
__device__
static
constexpr
index_t
GetRegSizePerXdlops
()
{
return
xdlops_gemm
.
GetRegSizePerXdlops
();
}
static
constexpr
auto
a_block_desc_m0_m1_m2_k
=
MakeABlockDescriptor_M0_M1_M2_K
();
static
constexpr
auto
a_block_desc_m0_m1_m2_k
=
MakeABlockDescriptor_M0_M1_M2_K
();
static
constexpr
auto
b_block_desc_n0_n1_n2_k
=
MakeBBlockDescriptor_N0_N1_N2_K
();
static
constexpr
auto
b_block_desc_n0_n1_n2_k
=
MakeBBlockDescriptor_N0_N1_N2_K
();
...
...
include/ck/tensor_operation/gpu/block/blockwise_softmax_v1.hpp
0 → 100644
View file @
e1e01d8f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/reduction_common.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp"
namespace
ck
{
template
<
index_t
BlockSize
,
typename
AccDataType
,
index_t
MPerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
RegSizePerXdlops
,
index_t
MRepeat
,
index_t
NRepeat
>
struct
BlockwiseSoftmax_V1
{
static_assert
(
MRepeat
==
1
,
"Now MRepeat must equal 1"
);
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
index_t
MThreadSliceSize
=
1
;
static
constexpr
index_t
WaveSize
=
64
;
static_assert
(
MPerBlock
==
MPerXDL
*
BlockSize
/
WaveSize
,
"wave is only m direction"
);
struct
BlockToMKMap_M0_K_M1Adapt
{
__host__
__device__
BlockToMKMap_M0_K_M1Adapt
()
=
default
;
template
<
typename
TopIdx
>
__host__
__device__
static
constexpr
auto
CalculateBottomIndex
(
const
TopIdx
&
idx_top
)
{
const
auto
index
=
idx_top
[
I0
];
const
auto
m
=
(
index
/
WaveSize
)
*
MPerXDL
+
index
%
MPerXDL
;
const
auto
k
=
(
index
%
WaveSize
)
/
MPerXDL
;
return
make_tuple
(
m
,
k
);
}
};
constexpr
static
auto
in_thread_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
Number
<
RegSizePerXdlops
>
{}));
using
ThreadReduceSrcDesc_M_K
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
1
>
{},
Number
<
RegSizePerXdlops
>
{})));
using
ThreadReduceDstDesc_M
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
1
>
{})));
using
ThreadwiseMaxReduce
=
ThreadwiseReduction
<
AccDataType
,
ThreadReduceSrcDesc_M_K
,
ThreadReduceDstDesc_M
,
reduce
::
Max
,
false
,
// param ignored
detail
::
AccumulateWithNanIgnore
<
reduce
::
Max
,
AccDataType
>>
;
using
ThreadClusterLengths_M_K
=
Sequence
<
MPerBlock
,
WaveSize
/
MPerXDL
>
;
using
BlockwiseMaxReduce
=
PartitionedBlockwiseReduction2
<
AccDataType
,
BlockSize
,
ThreadClusterLengths_M_K
,
BlockToMKMap_M0_K_M1Adapt
,
reduce
::
Max
,
false
,
// param ignored
detail
::
AccumulateWithNanIgnore
<
reduce
::
Max
,
AccDataType
>>
;
using
BlockwiseSumReduce
=
PartitionedBlockwiseReduction2
<
AccDataType
,
BlockSize
,
ThreadClusterLengths_M_K
,
BlockToMKMap_M0_K_M1Adapt
,
reduce
::
Add
,
false
,
// ignored
detail
::
AccumulateWithNanIgnore
<
reduce
::
Add
,
AccDataType
>>
;
using
ThreadwiseSumReduce
=
ThreadwiseReduction
<
AccDataType
,
ThreadReduceSrcDesc_M_K
,
ThreadReduceDstDesc_M
,
reduce
::
Add
,
false
,
// ignored
detail
::
AccumulateWithNanIgnore
<
reduce
::
Add
,
AccDataType
>>
;
template
<
typename
CThreadBuffer
>
__host__
__device__
static
void
Run
(
CThreadBuffer
&
in_thread_buf
,
float
&
f_sum
,
float
&
f_max
,
void
*
__restrict__
p_reduce
)
{
auto
reduce_work_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
AccDataType
*>
(
p_reduce
),
BlockSize
);
//
// find max value
//
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
max_value_buf
;
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
max_value_buf
(
I
)
=
reduce
::
Max
::
template
GetIdentityValue
<
AccDataType
>();
});
// max value for one thread
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n
)
{
constexpr
index_t
in_offset
=
in_thread_desc
.
CalculateOffset
(
make_tuple
(
0
,
n
,
0
));
auto
&
xdlops_out
=
in_thread_buf
.
GetVectorTypeReference
(
Number
<
in_offset
>
{});
ThreadwiseMaxReduce
::
Reduce
(
xdlops_out
.
template
AsType
<
float
>(),
max_value_buf
);
});
// block reduce for max
BlockwiseMaxReduce
::
Reduce
(
reduce_work_buf
,
max_value_buf
(
I0
));
block_sync_lds
();
// save max
f_max
=
max_value_buf
(
I0
);
//
// softmax
//
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
accu_value_buf
;
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
accu_value_buf
(
I
)
=
reduce
::
Add
::
template
GetIdentityValue
<
AccDataType
>();
});
// calculate exp for elements, P=exp(s-max)
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n
)
{
constexpr
index_t
in_offset
=
in_thread_desc
.
CalculateOffset
(
make_tuple
(
0
,
n
,
0
));
auto
&
xdlops_out
=
in_thread_buf
.
GetVectorTypeReference
(
Number
<
in_offset
>
{});
static_for
<
0
,
RegSizePerXdlops
,
1
>
{}([
&
](
auto
iK
)
{
xdlops_out
.
template
AsType
<
float
>()(
iK
)
=
math
::
exp
(
xdlops_out
.
template
AsType
<
float
>()[
iK
]
-
max_value_buf
(
I0
));
});
});
// sum data
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n
)
{
constexpr
index_t
in_offset
=
in_thread_desc
.
CalculateOffset
(
make_tuple
(
0
,
n
,
0
));
auto
&
xdlops_out
=
in_thread_buf
.
GetVectorTypeReference
(
Number
<
in_offset
>
{});
ThreadwiseSumReduce
::
Reduce
(
xdlops_out
.
template
AsType
<
float
>(),
accu_value_buf
);
block_sync_lds
();
});
BlockwiseSumReduce
::
Reduce
(
reduce_work_buf
,
accu_value_buf
(
I0
));
block_sync_lds
();
// save sum
f_sum
=
accu_value_buf
(
I0
);
}
};
// namespace ck
}
// namespace ck
include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp
View file @
e1e01d8f
...
@@ -82,6 +82,78 @@ struct PartitionedBlockwiseReduction
...
@@ -82,6 +82,78 @@ struct PartitionedBlockwiseReduction
};
};
};
};
// clang-format off
// Assume:
// 1) work_buffer is buffer (typically LDS) allocated outside as workspace, does not include any in/out data
// 2) work_buffer has AccDataType elements, and space size is no less than BlockSize
// 3) in_out_value is the input data in vgpr from each thread
// 4) in_out_value is the over-written reduced output in vgpr for each thread
// clang-format on
template
<
typename
AccDataType
,
index_t
BlockSize
,
typename
ThreadClusterLengths_M_K
,
typename
ThreadClusterDesc
,
typename
OpReduce
,
bool
PropagateNan
,
typename
Accumulation
=
detail
::
AccumulateWithNanCheck
<
PropagateNan
,
OpReduce
,
AccDataType
>
>
struct
PartitionedBlockwiseReduction2
{
static_assert
(
BlockSize
==
ThreadClusterLengths_M_K
::
At
(
0
)
*
ThreadClusterLengths_M_K
::
At
(
1
),
"The product of cluster lengths should be same as BlockSize!"
);
static
constexpr
auto
BufferLength_M
=
ThreadClusterLengths_M_K
::
At
(
0
);
static
constexpr
auto
BufferLength_K
=
ThreadClusterLengths_M_K
::
At
(
1
);
static_assert
(
BufferLength_K
>
1
,
"Parallel reduction need work on at least two elements"
);
static
constexpr
auto
block_buf_desc_m_k
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
BufferLength_M
>
{},
Number
<
BufferLength_K
>
{}));
static
constexpr
auto
thread_cluster_desc
=
ThreadClusterDesc
{};
template
<
typename
BufferType
>
__device__
static
void
Reduce
(
BufferType
&
work_buffer
,
AccDataType
&
in_out_value
)
{
static_assert
(
is_same
<
typename
BufferType
::
type
,
AccDataType
>
{},
"Buffer data type should be consistent as AccDataType!"
);
constexpr
auto
cluster_len_shift
=
get_shift
<
BufferLength_K
>
();
const
auto
thread_cluster_idx
=
thread_cluster_desc
.
CalculateBottomIndex
(
make_multi_index
(
get_thread_local_1d_id
()));
const
auto
thread_m_cluster_id
=
thread_cluster_idx
[
Number
<
0
>
{}];
const
auto
thread_k_cluster_id
=
thread_cluster_idx
[
Number
<
1
>
{}];
work_buffer
(
block_buf_desc_m_k
.
CalculateOffset
(
thread_cluster_idx
))
=
in_out_value
;
__syncthreads
();
static_for
<
0
,
cluster_len_shift
,
1
>
{}([
&
](
auto
I
)
{
constexpr
index_t
indOffset
=
1
<<
(
cluster_len_shift
-
1
-
I
());
if
(
thread_k_cluster_id
<
indOffset
)
{
index_t
offset1
=
block_buf_desc_m_k
.
CalculateOffset
(
thread_cluster_idx
);
index_t
offset2
=
block_buf_desc_m_k
.
CalculateOffset
(
thread_cluster_idx
+
make_tuple
(
0
,
indOffset
));
AccDataType
opData1
=
work_buffer
[
offset1
];
AccDataType
opData2
=
work_buffer
[
offset2
];
Accumulation
::
Calculate
(
opData1
,
opData2
);
work_buffer
(
offset1
)
=
opData1
;
}
__syncthreads
();
});
index_t
offset
=
block_buf_desc_m_k
.
CalculateOffset
(
make_tuple
(
thread_m_cluster_id
,
0
));
in_out_value
=
work_buffer
[
offset
];
};
};
// clang-format off
// clang-format off
// Assume:
// Assume:
// 1) work_val_buffer/work_idx_buffer is buffer (typically LDS) allocated outside as workspace, does not include any in/out data
// 1) work_val_buffer/work_idx_buffer is buffer (typically LDS) allocated outside as workspace, does not include any in/out data
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
View file @
e1e01d8f
...
@@ -13,12 +13,7 @@
...
@@ -13,12 +13,7 @@
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_softmax_v1.hpp"
#include "ck/utility/reduction_common.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -478,90 +473,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -478,90 +473,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_thread_buf
,
c_thread_buf
,
num_k_block_main_loop
);
num_k_block_main_loop
);
{
{
// LDS
__shared__
AccDataType
p_reduce_work_buffer
[
BlockSize
];
__shared__
AccDataType
p_reduce_work_buffer
[
BlockSize
];
float
f_sum
,
f_max
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
1
,
true
>
max_value_buf
;
static_for
<
0
,
1
,
1
>
{}([
&
](
auto
I
)
{
using
BlockwiseSoftmax
=
BlockwiseSoftmax_V1
<
BlockSize
,
max_value_buf
(
I
)
=
reduce
::
Max
::
template
GetIdentityValue
<
AccDataType
>();
FloatAcc
,
});
MPerBlock
,
MPerXDL
,
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
1
,
true
>
accu_value_buf
;
NPerXDL
,
static_for
<
0
,
1
,
1
>
{}([
&
](
auto
I
)
{
blockwise_gemm
.
GetRegSizePerXdlops
(),
accu_value_buf
(
I
)
=
reduce
::
Add
::
template
GetIdentityValue
<
AccDataType
>();
MXdlPerWave
,
});
NXdlPerWave
>
;
BlockwiseSoftmax
::
Run
(
c_thread_buf
,
f_sum
,
f_max
,
p_reduce_work_buffer
);
constexpr
auto
c_thread_desc
=
blockwise_gemm
.
GetCThreadDesc
();
// printf("c_thread_desc: {%d, %d, %d}", c_thread_desc.GetLength(I0).value,
// c_thread_desc.GetLength(I1).value, c_thread_desc.GetLength(I2));
constexpr
index_t
c_offset
=
c_thread_desc
.
CalculateOffset
(
make_tuple
(
0
,
0
,
0
));
auto
&
xdlops_out
=
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{});
using
ThreadReduceSrcDesc_M_K
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
1
>
{},
Number
<
c_thread_desc
.
GetLength
(
I2
)
>
{})));
using
ThreadReduceDstDesc_M
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
1
>
{})));
using
ThreadwiseMaxReduce
=
ThreadwiseReduction
<
AccDataType
,
ThreadReduceSrcDesc_M_K
,
ThreadReduceDstDesc_M
,
reduce
::
Max
,
false
,
// param ignored
detail
::
AccumulateWithNanIgnore
<
reduce
::
Max
,
AccDataType
>>
;
ThreadwiseMaxReduce
::
Reduce
(
xdlops_out
.
template
AsType
<
float
>(),
max_value_buf
);
// const index_t thread_local_id = get_thread_local_1d_id();
// printf("thread id: %d, Max: %f\t\t",thread_local_id,max_value_buf[I0]);
using
ThreadClusterLengths_M_K
=
Sequence
<
32
,
2
>
;
using
ThreadClusterArrangeOrder
=
Sequence
<
1
,
0
>
;
using
BlockwiseMaxReduce
=
PartitionedBlockwiseReduction
<
AccDataType
,
BlockSize
,
ThreadClusterLengths_M_K
,
ThreadClusterArrangeOrder
,
reduce
::
Max
,
false
,
// param ignored
detail
::
AccumulateWithNanIgnore
<
reduce
::
Max
,
AccDataType
>>
;
auto
reduce_work_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
p_reduce_work_buffer
,
BlockSize
);
block_sync_lds
();
BlockwiseMaxReduce
::
Reduce
(
reduce_work_buf
,
max_value_buf
(
I0
));
block_sync_lds
();
// printf("\n");
// printf("thread id: %d, Max: %f\t\t",thread_local_id,max_value_buf[I0]);
// softmax
using
BlockwiseSumReduce
=
PartitionedBlockwiseReduction
<
AccDataType
,
BlockSize
,
ThreadClusterLengths_M_K
,
ThreadClusterArrangeOrder
,
reduce
::
Add
,
false
,
// ignored
detail
::
AccumulateWithNanIgnore
<
reduce
::
Add
,
AccDataType
>>
;
using
ThreadwiseSumReduce
=
ThreadwiseReduction
<
AccDataType
,
ThreadReduceSrcDesc_M_K
,
ThreadReduceDstDesc_M
,
reduce
::
Add
,
false
,
// ignored
detail
::
AccumulateWithNanIgnore
<
reduce
::
Add
,
AccDataType
>>
;
static_for
<
0
,
c_thread_desc
.
GetLength
(
I2
),
1
>
{}([
&
](
auto
iK
)
{
xdlops_out
.
template
AsType
<
float
>()(
iK
)
=
math
::
exp
(
xdlops_out
.
template
AsType
<
float
>()[
iK
]
-
max_value_buf
(
I0
));
});
ThreadwiseSumReduce
::
Reduce
(
xdlops_out
.
template
AsType
<
float
>(),
accu_value_buf
);
block_sync_lds
();
BlockwiseSumReduce
::
Reduce
(
reduce_work_buf
,
accu_value_buf
(
I0
));
block_sync_lds
();
static_for
<
0
,
c_thread_desc
.
GetLength
(
I2
),
1
>
{}([
&
](
auto
iK
)
{
xdlops_out
.
template
AsType
<
float
>()(
iK
)
=
xdlops_out
.
template
AsType
<
float
>()[
iK
]
/
accu_value_buf
(
I0
);
});
}
}
// output: register to global memory
// output: register to global memory
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment