Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
f8c44314
Commit
f8c44314
authored
Jun 02, 2022
by
Anthony Chang
Browse files
reflect reduction API's recent change
parent
7e610626
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
20 additions
and
21 deletions
+20
-21
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp
...tion/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp
+19
-20
include/ck/utility/reduction_operator.hpp
include/ck/utility/reduction_operator.hpp
+1
-1
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp
View file @
f8c44314
...
@@ -819,8 +819,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -819,8 +819,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
3
,
3
,
CReduceThreadCopySrcDstScalarPerVector_NPerBlock
,
CReduceThreadCopySrcDstScalarPerVector_NPerBlock
,
1
,
1
,
true
>
(
true
>
(
c0_grid_desc_mblock_mperblock_nblock_nperblock
,
c0_grid_desc_mblock_mperblock_nblock_nperblock
,
make_multi_index
(
block_work_idx
[
I0
],
make_multi_index
(
block_work_idx
[
I0
],
c_reduce_thread_data_idx_begin
[
I0
],
c_reduce_thread_data_idx_begin
[
I0
],
block_work_idx
[
I1
],
block_work_idx
[
I1
],
...
@@ -837,8 +836,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -837,8 +836,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
3
,
3
,
CReduceThreadCopySrcDstScalarPerVector_NPerBlock
,
CReduceThreadCopySrcDstScalarPerVector_NPerBlock
,
1
,
1
,
true
>
(
true
>
(
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
make_multi_index
(
block_work_idx
[
I0
],
make_multi_index
(
block_work_idx
[
I0
],
c_reduce_thread_data_idx_begin
[
I0
],
c_reduce_thread_data_idx_begin
[
I0
],
block_work_idx
[
I1
],
block_work_idx
[
I1
],
...
@@ -900,7 +898,8 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -900,7 +898,8 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
static_for
<
0
,
c_reduce_thread_desc_mperblock_nperblock
.
GetElementSize
(),
1
>
{}(
static_for
<
0
,
c_reduce_thread_desc_mperblock_nperblock
.
GetElementSize
(),
1
>
{}(
[
&
](
auto
i
)
{
[
&
](
auto
i
)
{
FloatReduceAcc
out
;
FloatReduceAcc
out
;
acc_element_op
(
out
,
c_reduce_thread_buf
(
i
)
+
acc_element_op
(
out
,
c_reduce_thread_buf
(
i
)
+
static_cast
<
FloatReduceAcc
>
(
c0_thread_buf
(
i
)));
static_cast
<
FloatReduceAcc
>
(
c0_thread_buf
(
i
)));
c_reduce_thread_buf
(
i
)
=
out
;
// acc_element_op(acc + bias)
c_reduce_thread_buf
(
i
)
=
out
;
// acc_element_op(acc + bias)
});
});
...
@@ -933,8 +932,8 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -933,8 +932,8 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
reduce
::
SquaredAdd
<
FloatReduceAcc
>
,
reduce
::
SquaredAdd
<
FloatReduceAcc
>
,
false
>
;
false
>
;
const
auto
d0_zeroVal
=
ThreadwiseReduceD0
::
Op
::
Get
ReductionZero
Val
();
const
auto
d0_zeroVal
=
ThreadwiseReduceD0
::
Op
::
Get
Identity
Val
ue
();
const
auto
d1_zeroVal
=
ThreadwiseReduceD1
::
Op
::
Get
ReductionZero
Val
();
const
auto
d1_zeroVal
=
ThreadwiseReduceD1
::
Op
::
Get
Identity
Val
ue
();
static_for
<
0
,
mreduce_per_thread
,
1
>
{}(
static_for
<
0
,
mreduce_per_thread
,
1
>
{}(
[
&
](
auto
i
)
{
d0_thread_buf
(
i
)
=
d0_zeroVal
;
});
[
&
](
auto
i
)
{
d0_thread_buf
(
i
)
=
d0_zeroVal
;
});
static_for
<
0
,
mreduce_per_thread
,
1
>
{}(
static_for
<
0
,
mreduce_per_thread
,
1
>
{}(
...
...
include/ck/utility/reduction_operator.hpp
View file @
f8c44314
...
@@ -76,7 +76,7 @@ struct SquaredAdd
...
@@ -76,7 +76,7 @@ struct SquaredAdd
{
{
using
dataType
=
T
;
using
dataType
=
T
;
__host__
__device__
static
constexpr
T
Get
ReductionZero
Val
()
{
return
static_cast
<
T
>
(
0.0
f
);
};
__host__
__device__
static
constexpr
T
Get
Identity
Val
ue
()
{
return
static_cast
<
T
>
(
0.0
f
);
};
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
a
=
a
+
b
*
b
;
}
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
a
=
a
+
b
*
b
;
}
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment