Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
652728bc
Commit
652728bc
authored
Jul 25, 2022
by
ltqin
Browse files
regular code
parent
c4b19b18
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
6 additions
and
1 deletion
+6
-1
include/ck/tensor_operation/gpu/block/blockwise_softmax_v1.hpp
...de/ck/tensor_operation/gpu/block/blockwise_softmax_v1.hpp
+5
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
+1
-0
No files found.
include/ck/tensor_operation/gpu/block/blockwise_softmax_v1.hpp
View file @
652728bc
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
namespace
ck
{
namespace
ck
{
template
<
index_t
BlockSize
,
template
<
index_t
BlockSize
,
typename
AccDataType
,
typename
AccDataType
,
index_t
MPerBlock
,
index_t
MPerXDL
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
NPerXDL
,
index_t
RegSizePerXdlops
,
index_t
RegSizePerXdlops
,
...
@@ -22,12 +23,15 @@ struct BlockwiseSoftmax_V1
...
@@ -22,12 +23,15 @@ struct BlockwiseSoftmax_V1
{
{
static_assert
(
MRepeat
==
1
,
"Now MRepeat must equal 1"
);
static_assert
(
MRepeat
==
1
,
"Now MRepeat must equal 1"
);
static
__shared__
AccDataType
p_lex
[
MPerBlock
];
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
index_t
MThreadSliceSize
=
1
;
static
constexpr
index_t
MThreadSliceSize
=
1
;
static
constexpr
index_t
WaveSize
=
64
;
static
constexpr
index_t
WaveSize
=
64
;
static_assert
(
MPerBlock
==
MPerXDL
*
BlockSize
/
WaveSize
,
"wave is only m direction"
);
struct
BlockToMKMap_M0_K_M1Adapt
struct
BlockToMKMap_M0_K_M1Adapt
{
{
__host__
__device__
BlockToMKMap_M0_K_M1Adapt
()
=
default
;
__host__
__device__
BlockToMKMap_M0_K_M1Adapt
()
=
default
;
...
@@ -57,7 +61,7 @@ struct BlockwiseSoftmax_V1
...
@@ -57,7 +61,7 @@ struct BlockwiseSoftmax_V1
false
,
// param ignored
false
,
// param ignored
detail
::
AccumulateWithNanIgnore
<
reduce
::
Max
,
AccDataType
>>
;
detail
::
AccumulateWithNanIgnore
<
reduce
::
Max
,
AccDataType
>>
;
using
ThreadClusterLengths_M_K
=
Sequence
<
MPer
XDL
*
BlockSize
/
WaveSize
,
WaveSize
/
MPerXDL
>
;
using
ThreadClusterLengths_M_K
=
Sequence
<
MPer
Block
,
WaveSize
/
MPerXDL
>
;
using
BlockwiseMaxReduce
=
using
BlockwiseMaxReduce
=
PartitionedBlockwiseReduction2
<
AccDataType
,
PartitionedBlockwiseReduction2
<
AccDataType
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
View file @
652728bc
...
@@ -476,6 +476,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -476,6 +476,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
__shared__
AccDataType
p_reduce_work_buffer
[
BlockSize
];
__shared__
AccDataType
p_reduce_work_buffer
[
BlockSize
];
using
BlockwiseSoftmax
=
BlockwiseSoftmax_V1
<
BlockSize
,
using
BlockwiseSoftmax
=
BlockwiseSoftmax_V1
<
BlockSize
,
FloatAcc
,
FloatAcc
,
MPerBlock
,
MPerXDL
,
MPerXDL
,
NPerXDL
,
NPerXDL
,
blockwise_gemm
.
GetRegSizePerXdlops
(),
blockwise_gemm
.
GetRegSizePerXdlops
(),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment