Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
ab8e0f28
Commit
ab8e0f28
authored
Jun 20, 2022
by
Anthony Chang
Browse files
keep up with recent changes in reduction API
parent
2c1ed8b2
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
9 additions
and
12 deletions
+9
-12
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp
...tion/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp
+6
-7
include/ck/utility/reduction_operator.hpp
include/ck/utility/reduction_operator.hpp
+3
-5
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp
View file @
ab8e0f28
...
@@ -923,17 +923,17 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -923,17 +923,17 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
ThreadwiseReduction
<
FloatReduceAcc
,
ThreadwiseReduction
<
FloatReduceAcc
,
decltype
(
c_reduce_thread_desc_mperblock_nperblock
),
decltype
(
c_reduce_thread_desc_mperblock_nperblock
),
decltype
(
d_reduce_thread_desc_mperblock
),
decltype
(
d_reduce_thread_desc_mperblock
),
reduce
::
Add
<
FloatReduceAcc
>
,
reduce
::
Add
,
false
>
;
false
>
;
using
ThreadwiseReduceD1
=
using
ThreadwiseReduceD1
=
ThreadwiseReduction
<
FloatReduceAcc
,
ThreadwiseReduction
<
FloatReduceAcc
,
decltype
(
c_reduce_thread_desc_mperblock_nperblock
),
decltype
(
c_reduce_thread_desc_mperblock_nperblock
),
decltype
(
d_reduce_thread_desc_mperblock
),
decltype
(
d_reduce_thread_desc_mperblock
),
reduce
::
SquaredAdd
<
FloatReduceAcc
>
,
reduce
::
SquaredAdd
,
false
>
;
false
>
;
const
auto
d0_zeroVal
=
ThreadwiseReduceD0
::
Op
::
GetIdentityValue
();
const
auto
d0_zeroVal
=
ThreadwiseReduceD0
::
Op
::
template
GetIdentityValue
<
FloatReduceAcc
>
();
const
auto
d1_zeroVal
=
ThreadwiseReduceD1
::
Op
::
GetIdentityValue
();
const
auto
d1_zeroVal
=
ThreadwiseReduceD1
::
Op
::
template
GetIdentityValue
<
FloatReduceAcc
>
();
static_for
<
0
,
mreduce_per_thread
,
1
>
{}(
static_for
<
0
,
mreduce_per_thread
,
1
>
{}(
[
&
](
auto
i
)
{
d0_thread_buf
(
i
)
=
d0_zeroVal
;
});
[
&
](
auto
i
)
{
d0_thread_buf
(
i
)
=
d0_zeroVal
;
});
static_for
<
0
,
mreduce_per_thread
,
1
>
{}(
static_for
<
0
,
mreduce_per_thread
,
1
>
{}(
...
@@ -951,7 +951,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -951,7 +951,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
BlockSize
,
BlockSize
,
CReduceThreadClusterLengths_MPerBlock_NPerBlock
,
// ThreadClusterLengths_M_K
CReduceThreadClusterLengths_MPerBlock_NPerBlock
,
// ThreadClusterLengths_M_K
Sequence
<
1
,
0
>
,
// ThreadClusterArrangeOrder
Sequence
<
1
,
0
>
,
// ThreadClusterArrangeOrder
reduce
::
Add
<
FloatReduceAcc
>
,
reduce
::
Add
,
false
>
;
false
>
;
static_for
<
0
,
mreduce_per_thread
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
mreduce_per_thread
,
1
>
{}([
&
](
auto
i
)
{
...
@@ -984,8 +984,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -984,8 +984,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
FloatReduceAcc
numerator
=
c_reduce_thread_buf
(
dst_offset
)
-
avg_sum
;
FloatReduceAcc
numerator
=
c_reduce_thread_buf
(
dst_offset
)
-
avg_sum
;
FloatReduceAcc
divisor
=
epsilon
+
avg_squared_sum
-
avg_sum
*
avg_sum
;
FloatReduceAcc
divisor
=
epsilon
+
avg_squared_sum
-
avg_sum
*
avg_sum
;
FloatReduceAcc
divisor_sqrt
;
FloatReduceAcc
divisor_sqrt
;
tensor_operation
::
element_wise
::
UnarySqrt
<
FloatReduceAcc
,
tensor_operation
::
element_wise
::
UnarySqrt
{}(
FloatReduceAcc
>
{}(
divisor_sqrt
,
divisor
);
divisor_sqrt
,
divisor
);
c_reduce_thread_buf
(
dst_offset
)
=
numerator
/
divisor_sqrt
;
c_reduce_thread_buf
(
dst_offset
)
=
numerator
/
divisor_sqrt
;
...
...
include/ck/utility/reduction_operator.hpp
View file @
ab8e0f28
...
@@ -81,20 +81,19 @@ struct Add
...
@@ -81,20 +81,19 @@ struct Add
}
}
};
};
template
<
class
T
>
struct
SquaredAdd
struct
SquaredAdd
{
{
using
dataType
=
T
;
template
<
class
T
>
__host__
__device__
static
constexpr
T
GetIdentityValue
()
{
return
type_convert
<
T
>
(
0.0
f
);
};
__host__
__device__
static
constexpr
T
GetIdentityValue
()
{
return
type_convert
<
T
>
(
0.0
f
);
};
__device__
static
constexpr
bool
__host__
__device__
static
constexpr
bool
IsCompatibleInMemoryDataOperation
(
InMemoryDataOperationEnum
operation
)
IsCompatibleInMemoryDataOperation
(
InMemoryDataOperationEnum
operation
)
{
{
return
operation
==
InMemoryDataOperationEnum
::
AtomicAdd
||
return
operation
==
InMemoryDataOperationEnum
::
AtomicAdd
||
operation
==
InMemoryDataOperationEnum
::
Set
;
operation
==
InMemoryDataOperationEnum
::
Set
;
};
};
template
<
class
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
...
@@ -106,7 +105,6 @@ struct SquaredAdd
...
@@ -106,7 +105,6 @@ struct SquaredAdd
}
}
};
};
template
<
class
T
>
struct
Mul
struct
Mul
{
{
template
<
typename
T
>
template
<
typename
T
>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment