Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
d4368d77
Commit
d4368d77
authored
Jul 27, 2022
by
ltqin
Browse files
Save max and sum values to registers
parent
2717e60d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
7 additions
and
27 deletions
+7
-27
include/ck/tensor_operation/gpu/block/blockwise_softmax_v1.hpp
...de/ck/tensor_operation/gpu/block/blockwise_softmax_v1.hpp
+5
-25
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
+2
-2
No files found.
include/ck/tensor_operation/gpu/block/blockwise_softmax_v1.hpp
View file @
d4368d77
...
...
@@ -43,9 +43,6 @@ struct BlockwiseSoftmax_V1
}
};
static
constexpr
auto
softmax_buf_desc_m_k
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MPerBlock
>
{},
Number
<
2
>
{}));
constexpr
static
auto
in_thread_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
Number
<
RegSizePerXdlops
>
{}));
...
...
@@ -91,19 +88,10 @@ struct BlockwiseSoftmax_V1
detail
::
AccumulateWithNanIgnore
<
reduce
::
Add
,
AccDataType
>>
;
template
<
typename
CThreadBuffer
>
__host__
__device__
static
void
Run
(
CThreadBuffer
&
in_thread_buf
,
void
*
__restrict__
p_reduce
,
void
*
__restrict__
p_
softmax
)
Run
(
CThreadBuffer
&
in_thread_buf
,
float
&
f_sum
,
float
&
f_max
,
void
*
__restrict__
p_
reduce
)
{
auto
reduce_work_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
AccDataType
*>
(
p_reduce
),
BlockSize
);
auto
softmax_lds_buffer
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
AccDataType
*>
(
p_softmax
),
MPerBlock
*
2
);
// thread id map to thread layout
const
index_t
thread_local_id
=
get_thread_local_1d_id
();
const
auto
thread_cluster_idx
=
BlockToMKMap_M0_K_M1Adapt
::
CalculateBottomIndex
(
make_multi_index
(
thread_local_id
));
const
auto
thread_m_cluster_id
=
thread_cluster_idx
[
Number
<
0
>
{}];
const
auto
thread_k_cluster_id
=
thread_cluster_idx
[
Number
<
1
>
{}];
//
// find max value
//
...
...
@@ -123,12 +111,8 @@ struct BlockwiseSoftmax_V1
// block reduce for max
BlockwiseMaxReduce
::
Reduce
(
reduce_work_buf
,
max_value_buf
(
I0
));
block_sync_lds
();
// save max value to lds
if
(
0
==
thread_k_cluster_id
)
{
softmax_lds_buffer
(
softmax_buf_desc_m_k
.
CalculateOffset
(
make_tuple
(
thread_m_cluster_id
,
1
)))
=
max_value_buf
(
I0
);
}
// save max
f_max
=
max_value_buf
(
I0
);
//
// softmax
...
...
@@ -158,12 +142,8 @@ struct BlockwiseSoftmax_V1
BlockwiseSumReduce
::
Reduce
(
reduce_work_buf
,
accu_value_buf
(
I0
));
block_sync_lds
();
// save sum to lds
if
(
0
==
thread_k_cluster_id
)
{
softmax_lds_buffer
(
softmax_buf_desc_m_k
.
CalculateOffset
(
make_tuple
(
thread_m_cluster_id
,
0
)))
=
accu_value_buf
(
I0
);
}
// save sum
f_sum
=
accu_value_buf
(
I0
);
}
};
// namespace ck
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
View file @
d4368d77
...
...
@@ -474,7 +474,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
num_k_block_main_loop
);
{
__shared__
AccDataType
p_reduce_work_buffer
[
BlockSize
];
__shared__
AccDataType
p_lex
[
MPerBlock
*
2
]
;
float
f_sum
,
f_max
;
using
BlockwiseSoftmax
=
BlockwiseSoftmax_V1
<
BlockSize
,
FloatAcc
,
...
...
@@ -484,7 +484,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
blockwise_gemm
.
GetRegSizePerXdlops
(),
MXdlPerWave
,
NXdlPerWave
>
;
BlockwiseSoftmax
::
Run
(
c_thread_buf
,
p_reduce_work_buffer
,
p_lex
);
BlockwiseSoftmax
::
Run
(
c_thread_buf
,
f_sum
,
f_max
,
p_reduce_work_buffer
);
}
// output: register to global memory
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment