Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
58188d46
Commit
58188d46
authored
Sep 15, 2022
by
Rocking
Browse files
Fuse sigmoid after groupnorm
parent
aea3b411
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
22 additions
and
6 deletions
+22
-6
example/42_groupnorm/CMakeLists.txt
example/42_groupnorm/CMakeLists.txt
+1
-1
example/42_groupnorm/groupnorm_sigmoid.cpp
example/42_groupnorm/groupnorm_sigmoid.cpp
+5
-5
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+15
-0
library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp
...ry/reference_tensor_operation/cpu/reference_groupnorm.hpp
+1
-0
No files found.
example/42_groupnorm/CMakeLists.txt
View file @
58188d46
add_example_executable
(
example_groupnorm_blockwise groupnorm_blockwise.cpp
)
add_example_executable
(
example_groupnorm_sigmoid groupnorm_sigmoid.cpp
)
\ No newline at end of file
\ No newline at end of file
example/42_groupnorm/groupnorm_
blockwise
.cpp
→
example/42_groupnorm/groupnorm_
sigmoid
.cpp
View file @
58188d46
...
@@ -24,7 +24,7 @@ using GammaDataType = ck::half_t;
...
@@ -24,7 +24,7 @@ using GammaDataType = ck::half_t;
using
BetaDataType
=
ck
::
half_t
;
using
BetaDataType
=
ck
::
half_t
;
using
YDataType
=
ck
::
half_t
;
using
YDataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
using
AccDataType
=
float
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Sigmoid
=
ck
::
tensor_operation
::
element_wise
::
Sigmoid
;
constexpr
int
Rank
=
5
;
constexpr
int
Rank
=
5
;
constexpr
int
NumReduceDim
=
3
;
constexpr
int
NumReduceDim
=
3
;
...
@@ -35,7 +35,7 @@ using DeviceInstance =
...
@@ -35,7 +35,7 @@ using DeviceInstance =
BetaDataType
,
BetaDataType
,
AccDataType
,
AccDataType
,
YDataType
,
YDataType
,
PassThrough
,
Sigmoid
,
Rank
,
Rank
,
NumReduceDim
,
NumReduceDim
,
256
,
// BlockSize
256
,
// BlockSize
...
@@ -91,7 +91,7 @@ int main()
...
@@ -91,7 +91,7 @@ int main()
gamma_dev
.
GetDeviceBuffer
(),
gamma_dev
.
GetDeviceBuffer
(),
beta_dev
.
GetDeviceBuffer
(),
beta_dev
.
GetDeviceBuffer
(),
y_dev
.
GetDeviceBuffer
(),
y_dev
.
GetDeviceBuffer
(),
PassThrough
{});
Sigmoid
{});
if
(
!
device_instance
.
IsSupportedArgument
(
argument_ptr
.
get
()))
if
(
!
device_instance
.
IsSupportedArgument
(
argument_ptr
.
get
()))
{
{
...
@@ -111,11 +111,11 @@ int main()
...
@@ -111,11 +111,11 @@ int main()
BetaDataType
,
BetaDataType
,
YDataType
,
YDataType
,
AccDataType
,
AccDataType
,
PassThrough
>
;
Sigmoid
>
;
ReferenceInstance
ref
;
ReferenceInstance
ref
;
auto
ref_argument
=
auto
ref_argument
=
ref
.
MakeArgument
(
x
,
gamma
,
beta
,
host_y
,
PassThrough
{},
{
N
,
H
,
W
,
G
,
C
},
1e-6
);
ref
.
MakeArgument
(
x
,
gamma
,
beta
,
host_y
,
Sigmoid
{},
{
N
,
H
,
W
,
G
,
C
},
1e-6
);
auto
ref_invoker
=
ref
.
MakeInvoker
();
auto
ref_invoker
=
ref
.
MakeInvoker
();
ref_invoker
.
Run
(
ref_argument
);
ref_invoker
.
Run
(
ref_argument
);
...
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
58188d46
...
@@ -232,6 +232,21 @@ struct Gelu
...
@@ -232,6 +232,21 @@ struct Gelu
}
}
};
};
/// Element-wise logistic sigmoid functor: y = 1 / (1 + exp(-x)).
struct Sigmoid
{
    /// Stores sigmoid(x) into y.
    ///
    /// @tparam T element type; must be float, double or ck::half_t, otherwise
    ///           the static_assert below rejects the instantiation.
    /// @param y  output, overwritten with the result.
    /// @param x  input value.
    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        constexpr bool is_supported_type = is_same<T, float>::value ||
                                           is_same<T, double>::value ||
                                           is_same<T, ck::half_t>::value;
        static_assert(is_supported_type, "Data type is not supported by this operation!");

        // 1 + exp(-x); the constant one is converted to T before the addition.
        const auto denominator = ck::type_convert<T>(1) + exp(-x);
        y                      = 1 / denominator;
    }

    // Not read by operator(); retained so the struct's public members are unchanged.
    int32_t divider_ = 1;
};
}
// namespace element_wise
}
// namespace element_wise
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp
View file @
58188d46
...
@@ -122,6 +122,7 @@ struct ReferenceGroupnorm : public device::BaseOperator
...
@@ -122,6 +122,7 @@ struct ReferenceGroupnorm : public device::BaseOperator
AccDataType
y
=
gamma
*
(
x
-
mean_val
)
/
AccDataType
y
=
gamma
*
(
x
-
mean_val
)
/
ck
::
math
::
sqrt
(
arg
.
epsilon_
+
var_val
)
+
ck
::
math
::
sqrt
(
arg
.
epsilon_
+
var_val
)
+
beta
;
beta
;
arg
.
acc_elementwise_op_
(
y
,
y
);
arg
.
y_
(
n
,
h
,
w
,
g
,
c
)
=
type_convert
<
YDataType
>
(
y
);
arg
.
y_
(
n
,
h
,
w
,
g
,
c
)
=
type_convert
<
YDataType
>
(
y
);
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment