Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Oneflow
Commits
f262efc9
Commit
f262efc9
authored
Nov 21, 2022
by
yuguo
Browse files
Support profiler for DCU, support debug compiler
parent
3f56062c
Changes
17
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
2666 additions
and
2387 deletions
+2666
-2387
CMakeLists.txt
CMakeLists.txt
+2
-2
cmake/third_party.cmake
cmake/third_party.cmake
+2
-0
oneflow/core/profiler/event.cpp
oneflow/core/profiler/event.cpp
+90
-94
oneflow/core/profiler/event.h
oneflow/core/profiler/event.h
+186
-188
oneflow/core/profiler/event_recorder.cpp
oneflow/core/profiler/event_recorder.cpp
+2
-2
oneflow/core/profiler/event_recorder.h
oneflow/core/profiler/event_recorder.h
+60
-60
oneflow/core/profiler/kernel.cpp
oneflow/core/profiler/kernel.cpp
+64
-0
oneflow/core/profiler/kineto_shim.cpp
oneflow/core/profiler/kineto_shim.cpp
+1
-1
oneflow/core/profiler/kineto_shim.h
oneflow/core/profiler/kineto_shim.h
+1
-1
oneflow/core/profiler/profile_manager.cpp
oneflow/core/profiler/profile_manager.cpp
+5
-6
oneflow/core/profiler/profile_manager.h
oneflow/core/profiler/profile_manager.h
+1
-1
oneflow/core/profiler/profiler.cpp
oneflow/core/profiler/profiler.cpp
+39
-0
oneflow/user/kernels/math_unary_elementwise_func.h
oneflow/user/kernels/math_unary_elementwise_func.h
+983
-983
oneflow/user/kernels/nvtx_range_kernel.hip.cpp
oneflow/user/kernels/nvtx_range_kernel.hip.cpp
+138
-0
oneflow/user/kernels/stateful_opkernel.cpp
oneflow/user/kernels/stateful_opkernel.cpp
+901
-901
python/oneflow/test/modules/fused_dot_feature_interaction.py
python/oneflow/test/modules/fused_dot_feature_interaction.py
+43
-0
python/oneflow/test/profiler/test_profile_lenet.py
python/oneflow/test/profiler/test_profile_lenet.py
+148
-148
No files found.
CMakeLists.txt
View file @
f262efc9
...
...
@@ -265,9 +265,9 @@ set(ROBIN_HOOD_HASHING_URL
use_mirror
(
VARIABLE ROBIN_HOOD_HASHING_URL URL
${
ROBIN_HOOD_HASHING_URL
}
)
set
(
ROBIN_HOOD_HASHING_MD5 a78bd30a7582f25984f8592652836467
)
set
(
FMT_URL https://github.com/fmtlib/fmt/archive/
48b7e3dafb27ece02cd6addc8bd1041c79d59c2c
.zip
)
set
(
FMT_URL https://github.com/fmtlib/fmt/archive/
fc07217d85e6dcec52878807d6bbd89a9d9156a5
.zip
)
use_mirror
(
VARIABLE FMT_URL URL
${
FMT_URL
}
)
set
(
FMT_MD5
45925a979ed7195e0c88a70be691de09
)
set
(
FMT_MD5
7d9bb2ececc9ede29cd35bdc42a7e22c
)
set
(
KINETO_URL
https://github.com/pytorch/kineto/archive/ff8dba20499a660650632952be76450bd70a52a6.zip
)
...
...
cmake/third_party.cmake
View file @
f262efc9
...
...
@@ -175,6 +175,8 @@ if (BUILD_ROCM)
add_definitions
(
-D__HIP_PLATFORM_HCC__
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-D__HIP_PLATFORM_HCC__ --gpu-max-threads-per-block=1024"
)
set
(
CMAKE_C_FLAGS
"
${
CMAKE_C_FLAGS
}
-D__HIP_PLATFORM_HCC__ --gpu-max-threads-per-block=1024"
)
set
(
CMAKE_CXX_FLAGS_DEBUG
"
${
CMAKE_CXX_FLAGS_DEBUG
}
-mcmodel=large"
)
set
(
CMAKE_C_FLAGS_DEBUG
"
${
CMAKE_C_FLAGS_DEBUG
}
-mcmodel=large"
)
list
(
APPEND oneflow_third_party_libs hip::device
)
list
(
APPEND oneflow_third_party_libs roc::hipblas
)
list
(
APPEND oneflow_third_party_libs hip::hipcub
)
...
...
oneflow/core/profiler/event.cpp
View file @
f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// #include "fmt/core.h"
// #include "fmt/format.h"
#include "oneflow/core/profiler/event.h"
#include "oneflow/core/profiler/util.h"
using
json
=
nlohmann
::
json
;
namespace
oneflow
{
namespace
profiler
{
nlohmann
::
json
IEvent
::
ToJson
()
{
return
json
{{
"name"
,
name_
},
{
"time"
,
GetDuration
<
double
>
()},
{
"input_shapes"
,
"-"
}};
}
void
IEvent
::
SetStartedAt
(
double
t
)
{
started_at_
=
t
;
}
void
IEvent
::
SetFinishedAt
(
double
t
)
{
finished_at_
=
t
;
}
void
IEvent
::
Start
()
{
SetStartedAt
(
GetTimeNow
());
}
void
IEvent
::
Finish
()
{
SetFinishedAt
(
GetTimeNow
());
}
bool
IEvent
::
IsChildOf
(
const
IEvent
*
e
)
{
if
(
!
e
)
{
return
false
;
}
if
(
this
==
e
)
{
return
false
;
}
return
GetStartedAt
<
double
>
()
>=
e
->
GetStartedAt
<
double
>
()
&&
GetFinishedAt
<
double
>
()
<=
e
->
GetFinishedAt
<
double
>
();
}
const
std
::
string
&
IEvent
::
GetName
()
const
{
return
name_
;
}
std
::
string
CustomEvent
::
Key
()
{
return
name_
;
}
nlohmann
::
json
CustomEvent
::
ToJson
()
{
auto
j
=
IEvent
::
ToJson
();
j
[
"type"
]
=
EventType
::
kCustom
;
j
[
"custom_type"
]
=
type_
;
return
j
;
}
std
::
shared_ptr
<
CustomEvent
>
CustomEvent
::
Create
(
const
std
::
string
&
name
,
CustomEventType
type
)
{
return
std
::
shared_ptr
<
CustomEvent
>
(
new
CustomEvent
(
name
,
type
));
}
// std::string KernelEvent::Key() { return fmt::format("{}.{}", name_, GetFormatedInputShapes()); }
std
::
string
KernelEvent
::
Key
()
{
return
"yuguo"
;
}
nlohmann
::
json
KernelEvent
::
ToJson
()
{
auto
j
=
IEvent
::
ToJson
();
j
[
"type"
]
=
EventType
::
kOneflowKernel
;
j
[
"input_shapes"
]
=
GetFormatedInputShapes
();
#if defined(WITH_CUDA)
j
[
"memory_size"
]
=
memory_size_
;
if
(
!
children_
.
empty
())
{
j
[
"children"
]
=
children_
;
}
#endif // WITH_CUDA
return
j
;
}
std
::
shared_ptr
<
KernelEvent
>
KernelEvent
::
Create
(
const
std
::
string
&
name
,
const
std
::
function
<
std
::
vector
<
ShapeView
>
(
void
)
>&
shape_getter
)
{
return
std
::
shared_ptr
<
KernelEvent
>
(
new
KernelEvent
(
name
,
shape_getter
));
}
void
KernelEvent
::
RecordShape
(
const
ShapeView
&
shape
)
{
input_shapes_
.
emplace_back
(
shape
);
}
// Formats up to `max_num_to_format` input shapes as "[shape, shape, ...]".
// Returns "-" when no shapes were recorded; a "()" shape is rendered as
// "scalar"; an ellipsis entry is appended when shapes were truncated.
// (Previously returned the hard-coded debug placeholder "yuguo", discarding
// the formatted list entirely.)
std::string KernelEvent::GetFormatedInputShapes(size_t max_num_to_format) {
  if (input_shapes_.empty()) { return "-"; }
  std::vector<std::string> shapes_formated(std::min(input_shapes_.size(), max_num_to_format));
  // size_t index: avoids the signed/unsigned comparison of `auto i = 0`.
  for (size_t i = 0; i < shapes_formated.size(); ++i) {
    const std::string current_shape = input_shapes_[i].ToString();
    shapes_formated[i] = current_shape == "()" ? "scalar" : current_shape;
  }
  if (input_shapes_.size() > max_num_to_format) { shapes_formated.emplace_back("..."); }
  // Manual join: fmt is commented out in this translation unit.
  std::string joined;
  for (size_t i = 0; i < shapes_formated.size(); ++i) {
    if (i > 0) { joined += ", "; }
    joined += shapes_formated[i];
  }
  return "[" + joined + "]";
}
}
// namespace profiler
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "fmt/core.h"
#include "fmt/format.h"
#include "oneflow/core/profiler/event.h"
#include "oneflow/core/profiler/util.h"
using
json
=
nlohmann
::
json
;
namespace
oneflow
{
namespace
profiler
{
nlohmann
::
json
IEvent
::
ToJson
()
{
return
json
{{
"name"
,
name_
},
{
"time"
,
GetDuration
<
double
>
()},
{
"input_shapes"
,
"-"
}};
}
void
IEvent
::
SetStartedAt
(
double
t
)
{
started_at_
=
t
;
}
void
IEvent
::
SetFinishedAt
(
double
t
)
{
finished_at_
=
t
;
}
void
IEvent
::
Start
()
{
SetStartedAt
(
GetTimeNow
());
}
void
IEvent
::
Finish
()
{
SetFinishedAt
(
GetTimeNow
());
}
bool
IEvent
::
IsChildOf
(
const
IEvent
*
e
)
{
if
(
!
e
)
{
return
false
;
}
if
(
this
==
e
)
{
return
false
;
}
return
GetStartedAt
<
double
>
()
>=
e
->
GetStartedAt
<
double
>
()
&&
GetFinishedAt
<
double
>
()
<=
e
->
GetFinishedAt
<
double
>
();
}
const
std
::
string
&
IEvent
::
GetName
()
const
{
return
name_
;
}
std
::
string
CustomEvent
::
Key
()
{
return
name_
;
}
nlohmann
::
json
CustomEvent
::
ToJson
()
{
auto
j
=
IEvent
::
ToJson
();
j
[
"type"
]
=
EventType
::
kCustom
;
j
[
"custom_type"
]
=
type_
;
return
j
;
}
std
::
shared_ptr
<
CustomEvent
>
CustomEvent
::
Create
(
const
std
::
string
&
name
,
CustomEventType
type
)
{
return
std
::
shared_ptr
<
CustomEvent
>
(
new
CustomEvent
(
name
,
type
));
}
std
::
string
KernelEvent
::
Key
()
{
return
fmt
::
format
(
"{}.{}"
,
name_
,
GetFormatedInputShapes
());
}
// Extends the base-event json with kernel-specific fields: the event type,
// the formatted input shapes, and — when a GPU backend is compiled in —
// the recorded memory size and any nested child events.
nlohmann::json KernelEvent::ToJson() {
  auto j = IEvent::ToJson();
  j["type"] = EventType::kOneflowKernel;
  j["input_shapes"] = GetFormatedInputShapes();
#if defined(WITH_CUDA) || defined(WITH_ROCM)
  j["memory_size"] = memory_size_;
  if (!children_.empty()) { j["children"] = children_; }
#endif  // WITH_CUDA || WITH_ROCM  (comment updated to match the condition above)
  return j;
}
std
::
shared_ptr
<
KernelEvent
>
KernelEvent
::
Create
(
const
std
::
string
&
name
,
const
std
::
function
<
std
::
vector
<
Shape
>
(
void
)
>&
shape_getter
)
{
return
std
::
shared_ptr
<
KernelEvent
>
(
new
KernelEvent
(
name
,
shape_getter
));
}
// Formats up to `max_num_to_format` input shapes as "[shape, shape, ...]".
// Returns "-" when no shapes were recorded; a "()" shape is rendered as
// "scalar"; an ellipsis entry is appended when shapes were truncated.
std::string KernelEvent::GetFormatedInputShapes(size_t max_num_to_format) {
  if (input_shapes_.empty()) { return "-"; }
  std::vector<std::string> shapes_formated(std::min(input_shapes_.size(), max_num_to_format));
  // size_t index: `auto i = 0` deduced int and compared it against size(),
  // a signed/unsigned mismatch.
  for (size_t i = 0; i < shapes_formated.size(); ++i) {
    const std::string current_shape = input_shapes_[i].ToString();
    shapes_formated[i] = current_shape == "()" ? "scalar" : current_shape;
  }
  if (input_shapes_.size() > max_num_to_format) { shapes_formated.emplace_back("..."); }
  return fmt::format("[{}]", fmt::join(shapes_formated, ", "));
}
}
// namespace profiler
}
// namespace oneflow
\ No newline at end of file
oneflow/core/profiler/event.h
View file @
f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_PROFILER_EVENT_H_
#define ONEFLOW_CORE_PROFILER_EVENT_H_
#include <functional>
#include <memory>
#include <vector>
#include "nlohmann/json.hpp"
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/shape_view.h"
namespace
oneflow
{
namespace
profiler
{
class
ProfileManager
;
enum
class
EventType
{
kCustom
,
// has three kinds
kOneflowKernel
// OneFlow cpu/cuda kernel
};
enum
class
CustomEventType
{
kDefault
,
// for record_function
kCudaKernel
,
// cuda kernel
kCudaRuntime
// something like cudaLaunchKernel
};
enum
class
EventTimeUnit
{
kNS
,
kUS
};
class
IEvent
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
IEvent
);
IEvent
()
=
delete
;
IEvent
(
const
std
::
string
&
name
,
EventTimeUnit
time_unit
)
:
name_
(
name
),
time_unit_
(
time_unit
)
{}
virtual
std
::
string
Key
()
=
0
;
virtual
nlohmann
::
json
ToJson
();
virtual
~
IEvent
()
=
default
;
virtual
void
Start
();
virtual
void
Finish
();
bool
IsChildOf
(
const
IEvent
*
e
);
const
std
::
string
&
GetName
()
const
;
template
<
typename
T
>
const
T
GetDuration
(
EventTimeUnit
time_unit
=
EventTimeUnit
::
kUS
)
const
;
template
<
typename
T
>
const
T
GetStartedAt
(
EventTimeUnit
time_unit
=
EventTimeUnit
::
kUS
)
const
;
template
<
typename
T
>
const
T
GetFinishedAt
(
EventTimeUnit
time_unit
=
EventTimeUnit
::
kUS
)
const
;
protected:
virtual
void
SetStartedAt
(
double
t
);
virtual
void
SetFinishedAt
(
double
t
);
std
::
string
name_
;
EventTimeUnit
time_unit_
;
double
started_at_
=
0
;
double
finished_at_
=
0
;
};
inline
double
ConvertTime
(
double
time_
,
EventTimeUnit
src_time_unit
,
EventTimeUnit
dst_time_unit
)
{
if
(
src_time_unit
==
EventTimeUnit
::
kNS
&&
dst_time_unit
==
EventTimeUnit
::
kUS
)
{
return
time_
/
1000
;
}
if
(
src_time_unit
==
EventTimeUnit
::
kUS
&&
dst_time_unit
==
EventTimeUnit
::
kNS
)
{
return
time_
*
1000
;
}
return
time_
;
}
template
<
>
const
inline
double
IEvent
::
GetStartedAt
<
double
>
(
EventTimeUnit
time_unit
)
const
{
return
ConvertTime
(
started_at_
,
time_unit_
,
time_unit
);
}
template
<
>
const
inline
time_t
IEvent
::
GetStartedAt
<
time_t
>
(
EventTimeUnit
time_unit
)
const
{
return
static_cast
<
time_t
>
(
GetStartedAt
<
double
>
(
time_unit
));
}
template
<
>
const
inline
double
IEvent
::
GetFinishedAt
<
double
>
(
EventTimeUnit
time_unit
)
const
{
return
ConvertTime
(
finished_at_
,
time_unit_
,
time_unit
);
}
template
<
>
const
inline
time_t
IEvent
::
GetFinishedAt
<
time_t
>
(
EventTimeUnit
time_unit
)
const
{
return
static_cast
<
time_t
>
(
GetFinishedAt
<
double
>
(
time_unit
));
}
template
<
>
const
inline
double
IEvent
::
GetDuration
<
double
>
(
EventTimeUnit
time_unit
)
const
{
return
GetFinishedAt
<
double
>
(
time_unit
)
-
GetStartedAt
<
double
>
(
time_unit
);
}
template
<
>
const
inline
time_t
IEvent
::
GetDuration
<
time_t
>
(
EventTimeUnit
time_unit
)
const
{
return
static_cast
<
time_t
>
(
GetDuration
<
double
>
(
time_unit
));
}
class
CustomEvent
final
:
public
IEvent
{
public:
friend
class
ProfileManager
;
std
::
string
Key
()
override
;
nlohmann
::
json
ToJson
()
override
;
static
std
::
shared_ptr
<
CustomEvent
>
Create
(
const
std
::
string
&
name
,
CustomEventType
type
=
CustomEventType
::
kDefault
);
private:
CustomEventType
type_
;
CustomEvent
(
const
std
::
string
&
custom_name
,
CustomEventType
type
)
:
IEvent
(
custom_name
,
type
==
CustomEventType
::
kDefault
?
EventTimeUnit
::
kNS
:
EventTimeUnit
::
kUS
),
type_
(
type
)
{}
};
class
KernelEvent
final
:
public
IEvent
{
public:
std
::
string
Key
()
override
;
nlohmann
::
json
ToJson
()
override
;
static
std
::
shared_ptr
<
KernelEvent
>
Create
(
const
std
::
string
&
name
,
const
std
::
function
<
std
::
vector
<
ShapeView
>
(
void
)
>&
shape_getter
);
void
RecordShape
(
const
ShapeView
&
shape
);
#if defined(WITH_CUDA)
void
SetMemorySize
(
int64_t
memory_size
)
{
memory_size_
=
memory_size
;
}
void
AddChildEvent
(
const
std
::
shared_ptr
<
IEvent
>&
e
)
{
children_
.
emplace
(
e
);
}
bool
AddChildEventIfSo
(
const
std
::
shared_ptr
<
IEvent
>&
e
)
{
if
(
e
->
IsChildOf
(
dynamic_cast
<
IEvent
*>
(
this
)))
{
children_
.
emplace
(
e
);
return
true
;
}
return
false
;
}
bool
HasChildEvent
(
const
std
::
shared_ptr
<
IEvent
>&
e
)
{
return
children_
.
count
(
e
);
}
void
WalkAmongChildren
(
const
std
::
function
<
void
(
const
std
::
shared_ptr
<
IEvent
>&
e
)
>&
f
)
const
{
for
(
const
auto
&
x
:
children_
)
{
f
(
x
);
}
}
#endif // WITH_CUDA
private:
KernelEvent
(
const
std
::
string
&
kernel_name
,
const
std
::
function
<
std
::
vector
<
ShapeView
>
(
void
)
>&
shape_getter
)
:
IEvent
(
kernel_name
,
EventTimeUnit
::
kNS
)
{
if
(
shape_getter
)
{
input_shapes_
=
shape_getter
();
}
}
#if defined(WITH_CUDA)
int64_t
memory_size_
=
-
1
;
std
::
set
<
std
::
shared_ptr
<
IEvent
>>
children_
;
#endif // WITH_CUDA
std
::
vector
<
ShapeView
>
input_shapes_
;
std
::
string
GetFormatedInputShapes
(
size_t
max_num_to_format
=
4
);
};
}
// namespace profiler
}
// namespace oneflow
namespace
nlohmann
{
inline
void
to_json
(
json
&
j
,
const
std
::
shared_ptr
<::
oneflow
::
profiler
::
IEvent
>&
event
)
{
j
=
event
->
ToJson
();
}
}
// namespace nlohmann
#endif // ONEFLOW_CORE_PROFILER_EVENT_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_PROFILER_EVENT_H_
#define ONEFLOW_CORE_PROFILER_EVENT_H_
#include <functional>
#include <memory>
#include <vector>
#include "nlohmann/json.hpp"
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/shape_view.h"
namespace
oneflow
{
namespace
profiler
{
class
ProfileManager
;
enum
class
EventType
{
kCustom
,
// has three kinds
kOneflowKernel
// OneFlow cpu/cuda kernel
};
enum
class
CustomEventType
{
kDefault
,
// for record_function
kCudaKernel
,
// cuda kernel
kCudaRuntime
// something like cudaLaunchKernel
};
enum
class
EventTimeUnit
{
kNS
,
kUS
};
class
IEvent
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
IEvent
);
IEvent
()
=
delete
;
IEvent
(
const
std
::
string
&
name
,
EventTimeUnit
time_unit
)
:
name_
(
name
),
time_unit_
(
time_unit
)
{}
virtual
std
::
string
Key
()
=
0
;
virtual
nlohmann
::
json
ToJson
();
virtual
~
IEvent
()
=
default
;
virtual
void
Start
();
virtual
void
Finish
();
bool
IsChildOf
(
const
IEvent
*
e
);
const
std
::
string
&
GetName
()
const
;
template
<
typename
T
>
const
T
GetDuration
(
EventTimeUnit
time_unit
=
EventTimeUnit
::
kUS
)
const
;
template
<
typename
T
>
const
T
GetStartedAt
(
EventTimeUnit
time_unit
=
EventTimeUnit
::
kUS
)
const
;
template
<
typename
T
>
const
T
GetFinishedAt
(
EventTimeUnit
time_unit
=
EventTimeUnit
::
kUS
)
const
;
protected:
virtual
void
SetStartedAt
(
double
t
);
virtual
void
SetFinishedAt
(
double
t
);
std
::
string
name_
;
EventTimeUnit
time_unit_
;
double
started_at_
=
0
;
double
finished_at_
=
0
;
};
// Converts a duration between the profiler's time units.
// Only ns<->us conversions exist; identical units pass through unchanged.
inline double ConvertTime(double time_, EventTimeUnit src_time_unit, EventTimeUnit dst_time_unit) {
  const bool ns_to_us =
      src_time_unit == EventTimeUnit::kNS && dst_time_unit == EventTimeUnit::kUS;
  if (ns_to_us) { return time_ / 1000; }
  const bool us_to_ns =
      src_time_unit == EventTimeUnit::kUS && dst_time_unit == EventTimeUnit::kNS;
  if (us_to_ns) { return time_ * 1000; }
  return time_;
}
template
<
>
const
inline
double
IEvent
::
GetStartedAt
<
double
>
(
EventTimeUnit
time_unit
)
const
{
return
ConvertTime
(
started_at_
,
time_unit_
,
time_unit
);
}
template
<
>
const
inline
time_t
IEvent
::
GetStartedAt
<
time_t
>
(
EventTimeUnit
time_unit
)
const
{
return
static_cast
<
time_t
>
(
GetStartedAt
<
double
>
(
time_unit
));
}
template
<
>
const
inline
double
IEvent
::
GetFinishedAt
<
double
>
(
EventTimeUnit
time_unit
)
const
{
return
ConvertTime
(
finished_at_
,
time_unit_
,
time_unit
);
}
template
<
>
const
inline
time_t
IEvent
::
GetFinishedAt
<
time_t
>
(
EventTimeUnit
time_unit
)
const
{
return
static_cast
<
time_t
>
(
GetFinishedAt
<
double
>
(
time_unit
));
}
template
<
>
const
inline
double
IEvent
::
GetDuration
<
double
>
(
EventTimeUnit
time_unit
)
const
{
return
GetFinishedAt
<
double
>
(
time_unit
)
-
GetStartedAt
<
double
>
(
time_unit
);
}
template
<
>
const
inline
time_t
IEvent
::
GetDuration
<
time_t
>
(
EventTimeUnit
time_unit
)
const
{
return
static_cast
<
time_t
>
(
GetDuration
<
double
>
(
time_unit
));
}
class
CustomEvent
final
:
public
IEvent
{
public:
friend
class
ProfileManager
;
std
::
string
Key
()
override
;
nlohmann
::
json
ToJson
()
override
;
static
std
::
shared_ptr
<
CustomEvent
>
Create
(
const
std
::
string
&
name
,
CustomEventType
type
=
CustomEventType
::
kDefault
);
private:
CustomEventType
type_
;
CustomEvent
(
const
std
::
string
&
custom_name
,
CustomEventType
type
)
:
IEvent
(
custom_name
,
type
==
CustomEventType
::
kDefault
?
EventTimeUnit
::
kNS
:
EventTimeUnit
::
kUS
),
type_
(
type
)
{}
};
// Profiling event for a single OneFlow kernel invocation.
// Records the kernel name, its input shapes (captured eagerly via the
// shape_getter at construction), and — when a GPU backend is compiled in —
// the memory size touched and any nested child events (e.g. CUDA/HIP
// runtime activity that falls inside this event's time interval).
class KernelEvent final : public IEvent {
 public:
  std::string Key() override;
  nlohmann::json ToJson() override;
  static std::shared_ptr<KernelEvent> Create(
      const std::string& name, const std::function<std::vector<Shape>(void)>& shape_getter);

#if defined(WITH_CUDA) || defined(WITH_ROCM)
  void SetMemorySize(int64_t memory_size) { memory_size_ = memory_size; }
  void AddChildEvent(const std::shared_ptr<IEvent>& e) { children_.emplace(e); }
  // Adds e as a child only if its time interval nests inside this event's.
  bool AddChildEventIfSo(const std::shared_ptr<IEvent>& e) {
    if (e->IsChildOf(dynamic_cast<IEvent*>(this))) {
      children_.emplace(e);
      return true;
    }
    return false;
  }
  bool HasChildEvent(const std::shared_ptr<IEvent>& e) { return children_.count(e); }
  void WalkAmongChildren(const std::function<void(const std::shared_ptr<IEvent>& e)>& f) const {
    for (const auto& x : children_) { f(x); }
  }
#endif  // WITH_CUDA || WITH_ROCM  (comment updated to match the condition above)

 private:
  KernelEvent(const std::string& kernel_name,
              const std::function<std::vector<Shape>(void)>& shape_getter)
      : IEvent(kernel_name, EventTimeUnit::kNS) {
    if (shape_getter) { input_shapes_ = shape_getter(); }
  }

#if defined(WITH_CUDA) || defined(WITH_ROCM)
  int64_t memory_size_ = -1;  // -1 marks "not recorded"
  std::set<std::shared_ptr<IEvent>> children_;
#endif  // WITH_CUDA || WITH_ROCM  (comment updated to match the condition above)

  std::vector<Shape> input_shapes_;
  std::string GetFormatedInputShapes(size_t max_num_to_format = 4);
};
}
// namespace profiler
}
// namespace oneflow
namespace
nlohmann
{
inline
void
to_json
(
json
&
j
,
const
std
::
shared_ptr
<::
oneflow
::
profiler
::
IEvent
>&
event
)
{
j
=
event
->
ToJson
();
}
}
// namespace nlohmann
#endif // ONEFLOW_CORE_PROFILER_EVENT_H_
oneflow/core/profiler/event_recorder.cpp
View file @
f262efc9
...
...
@@ -32,13 +32,13 @@ std::shared_ptr<EventRecorder> EventRecorder::CreateCustomEventRecorder(const st
Maybe
<
EventRecorder
>
EventRecorder
::
CreateKernelEventRecorder
(
const
std
::
string
&
name
,
#if defined(WITH_CUDA)
#if defined(WITH_CUDA)
|| defined(WITH_ROCM)
const
std
::
function
<
int64_t
()
>&
memory_size_getter
,
#endif
const
ShapeGetterFuncType
&
shape_getter
)
{
auto
pmgr
=
Singleton
<
ProfileManager
>::
Get
();
if
(
pmgr
)
{
#if defined(WITH_CUDA)
#if defined(WITH_CUDA)
|| defined(WITH_ROCM)
if
(
pmgr
->
use_cpu_
||
pmgr
->
use_cuda_
)
{
auto
event
=
KernelEvent
::
Create
(
name
,
pmgr
->
record_shapes_
?
shape_getter
:
nullptr
);
if
(
pmgr
->
use_cuda_
)
{
...
...
oneflow/core/profiler/event_recorder.h
View file @
f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_PROFILER_EVENT_RECORDER_H_
#define ONEFLOW_CORE_PROFILER_EVENT_RECORDER_H_
#include "oneflow/core/common/util.h"
#include "oneflow/core/profiler/event.h"
namespace
oneflow
{
namespace
profiler
{
class
EventRecorder
{
public:
using
ShapeGetterFuncType
=
std
::
function
<
std
::
vector
<
Shape
View
>
(
void
)
>
;
OF_DISALLOW_COPY_AND_MOVE
(
EventRecorder
);
explicit
EventRecorder
(
const
std
::
shared_ptr
<
IEvent
>&
event
)
:
event_
(
event
)
{
CHECK_JUST
(
RegisterEventToProfileManager
(
event
));
event_
->
Start
();
}
Maybe
<
void
>
RegisterEventToProfileManager
(
const
std
::
shared_ptr
<
IEvent
>&
event
);
~
EventRecorder
()
{
if
(
event_
)
{
event_
->
Finish
();
event_
.
reset
();
}
}
static
std
::
shared_ptr
<
EventRecorder
>
CreateCustomEventRecorder
(
const
std
::
string
&
name
);
static
Maybe
<
EventRecorder
>
CreateKernelEventRecorder
(
const
std
::
string
&
name
,
#if defined(WITH_CUDA)
const
std
::
function
<
int64_t
()
>&
memory_size_getter
,
#endif
const
ShapeGetterFuncType
&
shape_getter
);
private:
std
::
shared_ptr
<
IEvent
>
event_
;
};
}
// namespace profiler
}
// namespace oneflow
#endif // ONEFLOW_CORE_PROFILER_EVENT_RECORDER_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_PROFILER_EVENT_RECORDER_H_
#define ONEFLOW_CORE_PROFILER_EVENT_RECORDER_H_
#include "oneflow/core/common/util.h"
#include "oneflow/core/profiler/event.h"
namespace
oneflow
{
namespace
profiler
{
// RAII recorder: registers an event with the ProfileManager and starts it on
// construction, finishes it on destruction.
// NOTE(review): the scraped diff splits the `#if defined(WITH_CUDA)` /
// `|| defined(WITH_ROCM)` condition across two lines; a preprocessor
// condition must be a single logical line, reconstructed as such here.
class EventRecorder {
 public:
  using ShapeGetterFuncType = std::function<std::vector<Shape>(void)>;

  OF_DISALLOW_COPY_AND_MOVE(EventRecorder);

  explicit EventRecorder(const std::shared_ptr<IEvent>& event) : event_(event) {
    CHECK_JUST(RegisterEventToProfileManager(event));
    event_->Start();
  }

  Maybe<void> RegisterEventToProfileManager(const std::shared_ptr<IEvent>& event);

  ~EventRecorder() {
    if (event_) {
      event_->Finish();
      event_.reset();
    }
  }

  static std::shared_ptr<EventRecorder> CreateCustomEventRecorder(const std::string& name);

  static Maybe<EventRecorder> CreateKernelEventRecorder(
      const std::string& name,
#if defined(WITH_CUDA) || defined(WITH_ROCM)
      const std::function<int64_t()>& memory_size_getter,
#endif
      const ShapeGetterFuncType& shape_getter);

 private:
  std::shared_ptr<IEvent> event_;
};
}
// namespace profiler
}
// namespace oneflow
#endif // ONEFLOW_CORE_PROFILER_EVENT_RECORDER_H_
oneflow/core/profiler/kernel.cpp
View file @
f262efc9
...
...
@@ -17,7 +17,11 @@ limitations under the License.
#include "oneflow/core/profiler/kernel.h"
#include "oneflow/core/profiler/profiler.h"
#include "oneflow/core/kernel/kernel.h"
#ifdef WITH_ROCM
#include "oneflow/core/ep/rocm/cuda_stream.h"
#else
#include "oneflow/core/ep/cuda/cuda_stream.h"
#endif
#include "oneflow/core/lazy/actor/actor_context.h"
namespace
oneflow
{
...
...
@@ -43,6 +47,11 @@ thread_local cudaEvent_t cuda_memory_bandwidth_profile_start_event = nullptr;
thread_local
cudaEvent_t
cuda_memory_bandwidth_profile_end_event
=
nullptr
;
#endif // WITH_CUDA
#if defined(WITH_ROCM)
thread_local
hipEvent_t
cuda_memory_bandwidth_profile_start_event
=
nullptr
;
thread_local
hipEvent_t
cuda_memory_bandwidth_profile_end_event
=
nullptr
;
#endif // WITH_ROCM
}
// namespace
void
TraceKernelForwardDataContentStart
(
KernelContext
*
kernel_ctx
,
const
Kernel
*
kernel
)
{
...
...
@@ -61,6 +70,22 @@ void TraceKernelForwardDataContentStart(KernelContext* kernel_ctx, const Kernel*
}
if
(
profile_kernel_forward_range
)
{
OF_PROFILER_RANGE_PUSH
(
kernel
->
op_conf
().
name
());
}
#endif // WITH_CUDA
#if defined(WITH_ROCM)
if
(
profile_cuda_memory_bandwidth
)
{
auto
*
actor_context_provider
=
dynamic_cast
<
ActorContextProvider
*>
(
kernel_ctx
);
auto
*
cuda_stream
=
dynamic_cast
<
ep
::
CudaStream
*>
(
kernel_ctx
->
stream
());
if
(
cuda_stream
!=
nullptr
&&
actor_context_provider
!=
nullptr
)
{
CHECK
(
cuda_memory_bandwidth_profile_start_event
==
nullptr
);
CHECK
(
cuda_memory_bandwidth_profile_end_event
==
nullptr
);
OF_CUDA_CHECK
(
hipEventCreate
(
&
cuda_memory_bandwidth_profile_start_event
));
OF_CUDA_CHECK
(
hipEventCreate
(
&
cuda_memory_bandwidth_profile_end_event
));
OF_CUDA_CHECK
(
hipEventRecord
(
cuda_memory_bandwidth_profile_start_event
,
cuda_stream
->
cuda_stream
()));
}
}
if
(
profile_kernel_forward_range
)
{
OF_PROFILER_RANGE_PUSH
(
kernel
->
op_conf
().
name
());
}
#endif // WITH_ROCM
}
void
TraceKernelForwardDataContentEnd
(
KernelContext
*
kernel_ctx
,
const
Kernel
*
kernel
)
{
...
...
@@ -103,6 +128,45 @@ void TraceKernelForwardDataContentEnd(KernelContext* kernel_ctx, const Kernel* k
}
}
#endif // WITH_CUDA
#if defined(WITH_ROCM)
if
(
profile_kernel_forward_range
)
{
OF_PROFILER_RANGE_POP
();
}
// The memory bandwidth profiler only works in lazy mode.
if
(
profile_cuda_memory_bandwidth
)
{
auto
*
cuda_stream
=
dynamic_cast
<
ep
::
CudaStream
*>
(
kernel_ctx
->
stream
());
auto
*
actor_context_provider
=
dynamic_cast
<
ActorContextProvider
*>
(
kernel_ctx
);
if
(
cuda_stream
!=
nullptr
&&
actor_context_provider
!=
nullptr
)
{
hipEvent_t
start_event
=
cuda_memory_bandwidth_profile_start_event
;
hipEvent_t
end_event
=
cuda_memory_bandwidth_profile_end_event
;
cuda_memory_bandwidth_profile_start_event
=
nullptr
;
cuda_memory_bandwidth_profile_end_event
=
nullptr
;
CHECK_NOTNULL
(
start_event
);
CHECK_NOTNULL
(
end_event
);
OF_CUDA_CHECK
(
hipEventRecord
(
end_event
,
cuda_stream
->
cuda_stream
()));
int64_t
memory_size
=
0
;
for
(
const
auto
&
bn
:
kernel
->
op_attribute
().
input_bns
())
{
const
Blob
*
blob
=
kernel_ctx
->
BnInOp2Blob
(
bn
);
if
(
blob
)
{
memory_size
+=
blob
->
ByteSizeOfBlobBody
();
}
}
for
(
const
auto
&
bn
:
kernel
->
op_attribute
().
output_bns
())
{
const
Blob
*
blob
=
kernel_ctx
->
BnInOp2Blob
(
bn
);
if
(
blob
)
{
memory_size
+=
blob
->
ByteSizeOfBlobBody
();
}
}
const
std
::
string
op_name
=
kernel
->
op_conf
().
name
();
actor_context_provider
->
GetActorContext
()
->
AddCallback
(
[
start_event
,
end_event
,
memory_size
,
op_name
]()
{
float
elapsed_ms
=
0
;
OF_CUDA_CHECK
(
hipEventElapsedTime
(
&
elapsed_ms
,
start_event
,
end_event
));
OF_CUDA_CHECK
(
hipEventDestroy
(
start_event
));
OF_CUDA_CHECK
(
hipEventDestroy
(
end_event
));
double
bandwidth
=
static_cast
<
double
>
(
memory_size
)
/
(
1024.0
*
1024.0
*
1024.0
)
/
(
elapsed_ms
/
1000
);
LOG
(
INFO
)
<<
"PROFILER::KERNEL::CUDA_MEMORY_BANDWIDTH op_name: "
<<
op_name
<<
" elapsed(ms): "
<<
elapsed_ms
<<
" memory_size(Byte): "
<<
memory_size
<<
" bandwidth(GB/s): "
<<
bandwidth
;
});
}
}
#endif // WITH_ROCM
}
}
// namespace profiler
...
...
oneflow/core/profiler/kineto_shim.cpp
View file @
f262efc9
...
...
@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#if defined(WITH_CUDA)
#if defined(WITH_CUDA)
|| defined(WITH_ROCM)
#include "oneflow/core/profiler/kineto_shim.h"
#include "libkineto.h"
...
...
oneflow/core/profiler/kineto_shim.h
View file @
f262efc9
...
...
@@ -16,7 +16,7 @@ limitations under the License.
#ifndef ONEFLOW_CORE_PROFILER_KINETO_SHIM_H_
#define ONEFLOW_CORE_PROFILER_KINETO_SHIM_H_
#if defined(WITH_CUDA)
#if defined(WITH_CUDA)
|| defined(WITH_ROCM)
#include <string>
#include <memory>
...
...
oneflow/core/profiler/profile_manager.cpp
View file @
f262efc9
...
...
@@ -15,12 +15,12 @@ limitations under the License.
*/
#include <memory>
#include <unordered_map>
//
#include "fmt/core.h"
#include "fmt/core.h"
#include "nlohmann/json.hpp"
#include "oneflow/core/profiler/kineto_shim.h"
#include "oneflow/core/profiler/profile_manager.h"
#include "oneflow/core/profiler/event.h"
#if defined(WITH_CUDA)
#if defined(WITH_CUDA)
|| defined(WITH_ROCM)
#include <libkineto.h>
#endif // WITH_CUDA
...
...
@@ -48,7 +48,7 @@ std::string ProfileManager::DumpResultsJson() {
}
std
::
vector
<
std
::
shared_ptr
<
IEvent
>>
ProfileManager
::
ExportEvents
()
{
#if defined(WITH_CUDA)
#if defined(WITH_CUDA)
|| defined(WITH_ROCM)
auto
trace
=
StopTrace
();
const
auto
&
kineto_events
=
*
(
trace
.
get
()
->
activities
());
std
::
set
<
std
::
shared_ptr
<
IEvent
>>
custom_events
;
...
...
@@ -77,7 +77,7 @@ std::vector<std::shared_ptr<IEvent>> ProfileManager::ExportEvents() {
while
(
!
events_
.
empty
())
{
auto
evt
=
events_
.
front
();
events_
.
pop
();
#if defined(WITH_CUDA)
#if defined(WITH_CUDA)
|| defined(WITH_ROCM)
auto
evt_kernel
=
std
::
dynamic_pointer_cast
<
KernelEvent
>
(
evt
);
if
(
evt_kernel
)
{
std
::
set
<
int64_t
>
current_corr_ids
;
...
...
@@ -106,8 +106,7 @@ std::string ProfileManager::GetNextEventRecorderKey(const std::string& name) {
}
else
{
event_recorders_last_id_
[
name
]
++
;
}
// return fmt::format("{}.{}", name, event_recorders_last_id_[name]);
return
"yuguo"
;
return
fmt
::
format
(
"{}.{}"
,
name
,
event_recorders_last_id_
[
name
]);
}
}
// namespace profiler
...
...
oneflow/core/profiler/profile_manager.h
View file @
f262efc9
...
...
@@ -37,7 +37,7 @@ class ProfileManager {
use_cuda_
(
use_cuda
),
record_shapes_
(
record_shapes
),
record_bandwidth_
(
record_bandwidth
)
{
#if defined(WITH_CUDA)
#if defined(WITH_CUDA)
|| defined(WITH_ROCM)
std
::
set
<
ActivityType
>
activities
{};
if
(
use_cpu
)
{
activities
.
insert
(
ActivityType
::
CPU
);
}
if
(
use_cuda
)
{
activities
.
insert
(
ActivityType
::
CUDA
);
}
...
...
oneflow/core/profiler/profiler.cpp
View file @
f262efc9
...
...
@@ -20,11 +20,20 @@ limitations under the License.
#include "oneflow/core/profiler/event_recorder.h"
#include "oneflow/core/vm/vm_util.h"
#ifdef OF_ENABLE_PROFILER
#ifdef WITH_ROCM
#include <hip/hip_runtime.h>
#include <hip/hip_profile.h>
#include <roctracer_roctx.h>
#include <sys/syscall.h>
#include <iostream>
#include "oneflow/core/device/cuda_util.h"
#else
#include <nvtx3/nvToolsExt.h>
#include <sys/syscall.h>
#include <iostream>
#include <cuda_profiler_api.h>
#include "oneflow/core/device/cuda_util.h"
#endif
#endif // OF_ENABLE_PROFILER
namespace
oneflow
{
...
...
@@ -33,6 +42,16 @@ namespace profiler {
void
NameThisHostThread
(
const
std
::
string
&
name
)
{
#ifdef OF_ENABLE_PROFILER
#ifdef WITH_ROCM
static
thread_local
std
::
unique_ptr
<
std
::
string
>
thread_name_prefix
;
if
(
!
thread_name_prefix
)
{
thread_name_prefix
.
reset
(
new
std
::
string
(
GetStringFromEnv
(
"ONEFLOW_PROFILER_HOST_THREAD_NAME_PREFIX"
,
""
)));
}
const
std
::
string
name_with_prefix
=
*
thread_name_prefix
+
name
;
// nvtxNameOsThreadA(syscall(SYS_gettid), name_with_prefix.c_str());
roctxMarkA
(
name_with_prefix
.
c_str
());
#else
static
thread_local
std
::
unique_ptr
<
std
::
string
>
thread_name_prefix
;
if
(
!
thread_name_prefix
)
{
thread_name_prefix
.
reset
(
...
...
@@ -40,18 +59,27 @@ void NameThisHostThread(const std::string& name) {
}
const
std
::
string
name_with_prefix
=
*
thread_name_prefix
+
name
;
nvtxNameOsThreadA
(
syscall
(
SYS_gettid
),
name_with_prefix
.
c_str
());
#endif
#endif // OF_ENABLE_PROFILER
}
void
RangePush
(
const
std
::
string
&
name
)
{
#ifdef OF_ENABLE_PROFILER
#ifdef WITH_ROCM
roctxRangePushA
(
name
.
c_str
());
#else
nvtxRangePushA
(
name
.
c_str
());
#endif
#endif // OF_ENABLE_PROFILER
}
void
RangePop
()
{
#ifdef OF_ENABLE_PROFILER
#ifdef WITH_ROCM
roctxRangePop
();
#else
nvtxRangePop
();
#endif
#endif // OF_ENABLE_PROFILER
}
...
...
@@ -82,13 +110,21 @@ void LogHostMemoryUsage(const std::string& name) {
void
ProfilerStart
()
{
#ifdef OF_ENABLE_PROFILER
#ifdef WITH_ROCM
OF_CUDA_CHECK
(
hipProfilerStart
());
#else
OF_CUDA_CHECK
(
cudaProfilerStart
());
#endif
#endif // OF_ENABLE_PROFILER
}
void
ProfilerStop
()
{
#ifdef OF_ENABLE_PROFILER
#ifdef WITH_ROCM
OF_CUDA_CHECK
(
hipProfilerStop
());
#else
OF_CUDA_CHECK
(
cudaProfilerStop
());
#endif
#endif // OF_ENABLE_PROFILER
}
...
...
@@ -105,6 +141,9 @@ Maybe<std::string> DisableProfilerAndReturnResult() {
#if defined(WITH_CUDA)
OF_CUDA_CHECK
(
cudaDeviceSynchronize
());
#endif // WITH_CUDA
#if defined(WITH_ROCM)
OF_CUDA_CHECK
(
hipDeviceSynchronize
());
#endif // WITH_ROCM
auto
*
pmgr
=
JUST
(
SingletonMaybe
<
ProfileManager
>
());
std
::
string
results
=
pmgr
->
DumpResultsJson
();
Singleton
<
ProfileManager
>::
Delete
();
...
...
oneflow/user/kernels/math_unary_elementwise_func.h
View file @
f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_USER_KERNELS_MATH_UNARY_ELEMENTWISE_FUNC_H_
#define ONEFLOW_USER_KERNELS_MATH_UNARY_ELEMENTWISE_FUNC_H_
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/data_type.h"
#include "oneflow/user/ops/math_unary_elementwise_seq.h"
#include "oneflow/core/device/cuda_pseudo_half.h"
#if defined(__CUDACC__)
#include <cuda_fp16.h>
#define MATH_FUNC_F(name, x) name##f(x)
#define MATH_FUNC_D(name, x) name(x)
#elif defined(__HIPCC__)
#include <cmath>
#include <hip/hip_fp16.h>
#if defined(__HIP_DEVICE_COMPILE__)
#define MATH_FUNC_F(name, x) name##f(x)
#define MATH_FUNC_D(name, x) name(x)
#else
#define MATH_FUNC_F(name, x) std::name(x)
#define MATH_FUNC_D(name, x) std::name(x)
#endif
#else
#include <cmath>
#define MATH_FUNC_F(name, x) std::name(x)
#define MATH_FUNC_D(name, x) std::name(x)
#endif
namespace
oneflow
{
#define DECLARE_UNARY_FUNCTOR(math_unary_elementwise_type, func_prefix) \
template<typename T> \
struct func_prefix##Functor;
OF_PP_FOR_EACH_TUPLE
(
DECLARE_UNARY_FUNCTOR
,
MATH_UNARY_ELEMENTWISE_FUNC_SEQ
)
// Element-wise absolute value: |x|, with the (sub)gradient at x == 0 defined as 0.
template<typename T>
struct AbsFunctor {
  // Returns |x|; x == 0 is handled explicitly so negative zero maps to +0.
  static OF_DEVICE_FUNC T Forward(const T x) {
    if (x == T(0))
      return T(0);
    else
      return x < T(0) ? -x : x;
  }
  // d|x|/dx = sign(x); at x == 0 the subgradient is taken to be 0.
  static OF_DEVICE_FUNC T Backward(const T x, const T dy) {
    if (x == T(0))
      return T(0);
    else
      return x < T(0) ? -dy : dy;
  }
};
// Element-wise sign: -1 for x < 0, 0 for x == 0, +1 for x > 0.
template<typename T>
struct SignFunctor {
  // (0 < x) - (x < 0) yields -1 / 0 / +1 without branching.
  static OF_DEVICE_FUNC T Forward(const T x) { return (T(0) < x) - (x < T(0)); }
  // sign(x) is piecewise constant, so its gradient is 0 everywhere.
  static OF_DEVICE_FUNC T Backward(const T x, const T dy) { return T(0); }
};
// Reciprocal square root (float specialization): 1 / sqrt(x).
template<>
struct RsqrtFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) {
    // CUDA and HIP device code both provide the fast rsqrtf intrinsic; the two
    // previously separate, identical branches are merged here. Host code falls
    // back to 1 / std::sqrt.
#if defined(__CUDACC__) || defined(__HIP_DEVICE_COMPILE__)
    return rsqrtf(x);
#else
    return 1.0f / std::sqrt(x);
#endif
  }
  // d(x^{-1/2})/dx = -1 / (2 * sqrt(x^3)).
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * (-1.0f / (2.0f * MATH_FUNC_F(sqrt, x * x * x)));
  }
};
// Reciprocal square root (double specialization): 1 / sqrt(x).
template<>
struct RsqrtFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) {
    // CUDA and HIP device code both provide the rsqrt intrinsic; the two
    // previously separate, identical branches are merged here. Host code falls
    // back to 1 / std::sqrt.
#if defined(__CUDACC__) || defined(__HIP_DEVICE_COMPILE__)
    return rsqrt(x);
#else
    return 1.0 / std::sqrt(x);
#endif
  }
  // d(x^{-1/2})/dx = -1 / (2 * sqrt(x^3)).
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * (-1.0 / (2.0 * MATH_FUNC_D(sqrt, x * x * x)));
  }
};
// float version
template
<
>
struct
AcosFunctor
<
float
>
{
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
MATH_FUNC_F
(
acos
,
x
);
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
return
dy
*
-
RsqrtFunctor
<
float
>::
Forward
(
1.0
f
-
x
*
x
);
}
};
template
<
>
struct
AcoshFunctor
<
float
>
{
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
MATH_FUNC_F
(
acosh
,
x
);
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
return
dy
*
RsqrtFunctor
<
float
>::
Forward
(
x
*
x
-
1.0
f
);
}
};
template
<
>
struct
AsinFunctor
<
float
>
{
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
MATH_FUNC_F
(
asin
,
x
);
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
return
dy
*
RsqrtFunctor
<
float
>::
Forward
(
1.0
f
-
x
*
x
);
}
};
template
<
>
struct
AsinhFunctor
<
float
>
{
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
MATH_FUNC_F
(
asinh
,
x
);
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
return
dy
*
RsqrtFunctor
<
float
>::
Forward
(
1.0
f
+
x
*
x
);
}
};
template
<
>
struct
AtanFunctor
<
float
>
{
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
MATH_FUNC_F
(
atan
,
x
);
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
return
dy
*
(
1.0
f
/
(
1.0
f
+
x
*
x
));
}
};
template
<
>
struct
AtanhFunctor
<
float
>
{
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
MATH_FUNC_F
(
atanh
,
x
);
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
return
dy
*
(
1.0
f
/
(
1.0
f
-
x
*
x
));
}
};
template
<
>
struct
NotEqualZeroFunctor
<
float
>
{
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
x
!=
0
;
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
return
0.0
f
;
}
};
template
<
>
struct
CeilFunctor
<
float
>
{
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
MATH_FUNC_F
(
ceil
,
x
);
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
return
0.0
f
;
}
};
template
<
>
struct
CosFunctor
<
float
>
{
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
MATH_FUNC_F
(
cos
,
x
);
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
return
dy
*
(
-
MATH_FUNC_F
(
sin
,
x
));
}
};
template
<
>
struct
CoshFunctor
<
float
>
{
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
MATH_FUNC_F
(
cosh
,
x
);
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
return
dy
*
MATH_FUNC_F
(
sinh
,
x
);
}
};
template
<
>
struct
ErfFunctor
<
float
>
{
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
MATH_FUNC_F
(
erf
,
x
);
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
return
dy
*
2.0
f
*
RsqrtFunctor
<
float
>::
Forward
(
M_PI
)
*
expf
(
-
x
*
x
);
}
};
template
<
>
struct
ErfcFunctor
<
float
>
{
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
MATH_FUNC_F
(
erfc
,
x
);
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
return
dy
*
-
2.0
f
*
RsqrtFunctor
<
float
>::
Forward
(
M_PI
)
*
expf
(
-
x
*
x
);
}
};
template
<
>
struct
ExpFunctor
<
float
>
{
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
MATH_FUNC_F
(
exp
,
x
);
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
return
dy
*
MATH_FUNC_F
(
exp
,
x
);
}
};
template
<
>
struct
Expm1Functor
<
float
>
{
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
MATH_FUNC_F
(
expm1
,
x
);
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
return
dy
*
MATH_FUNC_F
(
exp
,
x
);
}
};
template
<
>
struct
FloorFunctor
<
float
>
{
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
MATH_FUNC_F
(
floor
,
x
);
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
return
0.0
f
;
}
};
template
<
>
struct
LgammaFunctor
<
float
>
{
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
MATH_FUNC_F
(
lgamma
,
x
);
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
// TODO(chengcheng): return: dy * digamma(x)
assert
(
false
);
return
0.0
f
;
}
};
template
<
>
struct
LogFunctor
<
float
>
{
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
MATH_FUNC_F
(
log
,
x
);
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
return
dy
*
(
1.0
f
/
x
);
}
};
template
<
>
struct
Log2Functor
<
float
>
{
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
MATH_FUNC_F
(
log2
,
x
);
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
return
dy
*
(
1.0
f
/
(
x
*
MATH_FUNC_F
(
log
,
2.0
f
)));
}
};
template
<
>
struct
Log1pFunctor
<
float
>
{
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
MATH_FUNC_F
(
log1p
,
x
);
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
return
dy
*
(
1.0
f
/
(
x
+
1.0
f
));
}
};
template
<
>
struct
LogSigmoidFunctor
<
float
>
{
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
-
MATH_FUNC_F
(
log
,
(
1.0
f
+
MATH_FUNC_F
(
exp
,
-
x
)));
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
return
dy
*
(
1.0
f
/
(
MATH_FUNC_F
(
exp
,
x
)
+
1.0
f
));
}
};
template
<
>
struct
NegativeFunctor
<
float
>
{
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
-
x
;
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
return
-
dy
;
}
};
template
<
>
struct
ReciprocalFunctor
<
float
>
{
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
1.0
f
/
x
;
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
return
dy
*
(
-
1.0
f
/
(
x
*
x
));
}
};
template
<
>
struct
ReciprocalNoNanFunctor
<
float
>
{
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
if
(
fabsf
(
x
)
<=
0.0
f
)
{
return
0.0
f
;
}
return
1.0
f
/
x
;
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
if
(
fabsf
(
x
)
<=
0.0
f
)
{
return
0.0
f
;
}
return
dy
*
(
-
1.0
f
/
(
x
*
x
));
}
};
template
<
>
struct
RintFunctor
<
float
>
{
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
MATH_FUNC_F
(
rint
,
x
);
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
return
0.0
f
;
}
};
template
<
>
struct
RoundFunctor
<
float
>
{
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
MATH_FUNC_F
(
nearbyint
,
x
);
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
return
0.0
f
;
}
};
// Logistic sigmoid (float specialization): 1 / (1 + e^{-x}).
template<>
struct SigmoidFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) {
    const float neg_exp = MATH_FUNC_F(exp, -x);
    return 1.0f / (1.0f + neg_exp);
  }
  // d sigmoid(x)/dx = sigmoid(x) * (1 - sigmoid(x)); recompute the forward
  // value locally since only x is available here.
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    const float sig = 1.0f / (1.0f + MATH_FUNC_F(exp, -x));
    return dy * (sig * (1.0f - sig));
  }
};
template
<
>
struct
SinFunctor
<
float
>
{
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
MATH_FUNC_F
(
sin
,
x
);
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
return
dy
*
MATH_FUNC_F
(
cos
,
x
);
}
};
template
<
>
struct
SinhFunctor
<
float
>
{
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
MATH_FUNC_F
(
sinh
,
x
);
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
return
dy
*
MATH_FUNC_F
(
cosh
,
x
);
}
};
template
<
>
struct
SqrtFunctor
<
float
>
{
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
MATH_FUNC_F
(
sqrt
,
x
);
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
return
dy
*
0.5
f
/
MATH_FUNC_F
(
sqrt
,
x
);
}
};
template
<
>
struct
SquareFunctor
<
float
>
{
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
x
*
x
;
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
return
dy
*
2.0
f
*
x
;
}
};
template
<
>
struct
TanFunctor
<
float
>
{
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
MATH_FUNC_F
(
tan
,
x
);
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
return
dy
*
(
1.0
f
/
(
MATH_FUNC_F
(
cos
,
x
)
*
MATH_FUNC_F
(
cos
,
x
)));
}
};
// double version
template
<
>
struct
AcosFunctor
<
double
>
{
static
OF_DEVICE_FUNC
double
Forward
(
const
double
x
)
{
return
MATH_FUNC_D
(
acos
,
x
);
}
static
OF_DEVICE_FUNC
double
Backward
(
const
double
x
,
const
double
dy
)
{
return
dy
*
-
RsqrtFunctor
<
double
>::
Forward
(
1.0
-
x
*
x
);
}
};
// Inverse hyperbolic cosine (double specialization).
template<>
struct AcoshFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(acosh, x); }
  // d acosh(x)/dx = +1 / sqrt(x^2 - 1) for x > 1.
  // Bug fix: the previous implementation negated the gradient
  // (dy * -rsqrt(x*x - 1.0)), which is the derivative form of acos, not acosh;
  // the float specialization correctly uses the positive sign.
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * RsqrtFunctor<double>::Forward(x * x - 1.0);
  }
};
template
<
>
struct
AsinFunctor
<
double
>
{
static
OF_DEVICE_FUNC
double
Forward
(
const
double
x
)
{
return
MATH_FUNC_D
(
asin
,
x
);
}
static
OF_DEVICE_FUNC
double
Backward
(
const
double
x
,
const
double
dy
)
{
return
dy
*
RsqrtFunctor
<
double
>::
Forward
(
1.0
-
x
*
x
);
}
};
template
<
>
struct
AsinhFunctor
<
double
>
{
static
OF_DEVICE_FUNC
double
Forward
(
const
double
x
)
{
return
MATH_FUNC_D
(
asinh
,
x
);
}
static
OF_DEVICE_FUNC
double
Backward
(
const
double
x
,
const
double
dy
)
{
return
dy
*
RsqrtFunctor
<
double
>::
Forward
(
1.0
+
x
*
x
);
}
};
template
<
>
struct
AtanFunctor
<
double
>
{
static
OF_DEVICE_FUNC
double
Forward
(
const
double
x
)
{
return
MATH_FUNC_D
(
atan
,
x
);
}
static
OF_DEVICE_FUNC
double
Backward
(
const
double
x
,
const
double
dy
)
{
return
dy
*
(
1.0
/
(
1.0
+
x
*
x
));
}
};
template
<
>
struct
AtanhFunctor
<
double
>
{
static
OF_DEVICE_FUNC
double
Forward
(
const
double
x
)
{
return
MATH_FUNC_D
(
atanh
,
x
);
}
static
OF_DEVICE_FUNC
double
Backward
(
const
double
x
,
const
double
dy
)
{
return
dy
*
(
1.0
/
(
1.0
-
x
*
x
));
}
};
// Element-wise "x != 0" predicate (double specialization), returned as 0.0/1.0.
template<>
struct NotEqualZeroFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return x != 0; }
  // The predicate is piecewise constant, so the gradient is 0 everywhere.
  // (Fixed the stray float literal 0.0f to the double literal 0.0 for
  // consistency with the other double specializations.)
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return 0.0; }
};
template
<
>
struct
CeilFunctor
<
double
>
{
static
OF_DEVICE_FUNC
double
Forward
(
const
double
x
)
{
return
MATH_FUNC_D
(
ceil
,
x
);
}
static
OF_DEVICE_FUNC
double
Backward
(
const
double
x
,
const
double
dy
)
{
return
0.0
;
}
};
template
<
>
struct
CosFunctor
<
double
>
{
static
OF_DEVICE_FUNC
double
Forward
(
const
double
x
)
{
return
MATH_FUNC_D
(
cos
,
x
);
}
static
OF_DEVICE_FUNC
double
Backward
(
const
double
x
,
const
double
dy
)
{
return
dy
*
(
-
MATH_FUNC_D
(
sin
,
x
));
}
};
template
<
>
struct
CoshFunctor
<
double
>
{
static
OF_DEVICE_FUNC
double
Forward
(
const
double
x
)
{
return
MATH_FUNC_D
(
cosh
,
x
);
}
static
OF_DEVICE_FUNC
double
Backward
(
const
double
x
,
const
double
dy
)
{
return
dy
*
MATH_FUNC_D
(
sinh
,
x
);
}
};
// Gauss error function (double specialization).
template<>
struct ErfFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(erf, x); }
  // d erf(x)/dx = (2 / sqrt(pi)) * e^{-x^2}.
  // Bug fix: the exponential was previously computed with expf (single
  // precision), silently truncating the double argument; use the file's
  // MATH_FUNC_D dispatch so the double-precision exp is used.
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * 2.0 * RsqrtFunctor<double>::Forward(M_PI) * MATH_FUNC_D(exp, -x * x);
  }
};
// Complementary error function (double specialization): erfc(x) = 1 - erf(x).
template<>
struct ErfcFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(erfc, x); }
  // d erfc(x)/dx = -(2 / sqrt(pi)) * e^{-x^2}.
  // Bug fix: the exponential was previously computed with expf (single
  // precision), silently truncating the double argument; use the file's
  // MATH_FUNC_D dispatch so the double-precision exp is used.
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * -2.0 * RsqrtFunctor<double>::Forward(M_PI) * MATH_FUNC_D(exp, -x * x);
  }
};
template
<
>
struct
ExpFunctor
<
double
>
{
static
OF_DEVICE_FUNC
double
Forward
(
const
double
x
)
{
return
MATH_FUNC_D
(
exp
,
x
);
}
static
OF_DEVICE_FUNC
double
Backward
(
const
double
x
,
const
double
dy
)
{
return
dy
*
MATH_FUNC_D
(
exp
,
x
);
}
};
template
<
>
struct
Expm1Functor
<
double
>
{
static
OF_DEVICE_FUNC
double
Forward
(
const
double
x
)
{
return
MATH_FUNC_D
(
expm1
,
x
);
}
static
OF_DEVICE_FUNC
double
Backward
(
const
double
x
,
const
double
dy
)
{
return
dy
*
MATH_FUNC_D
(
exp
,
x
);
}
};
template
<
>
struct
FloorFunctor
<
double
>
{
static
OF_DEVICE_FUNC
double
Forward
(
const
double
x
)
{
return
MATH_FUNC_D
(
floor
,
x
);
}
static
OF_DEVICE_FUNC
double
Backward
(
const
double
x
,
const
double
dy
)
{
return
0.0
;
}
};
template
<
>
struct
LgammaFunctor
<
double
>
{
static
OF_DEVICE_FUNC
double
Forward
(
const
double
x
)
{
return
MATH_FUNC_D
(
lgamma
,
x
);
}
static
OF_DEVICE_FUNC
double
Backward
(
const
double
x
,
const
double
dy
)
{
// TODO(chengcheng): return: dy * digamma(x)
assert
(
false
);
return
0.0
;
}
};
template
<
>
struct
LogFunctor
<
double
>
{
static
OF_DEVICE_FUNC
double
Forward
(
const
double
x
)
{
return
MATH_FUNC_D
(
log
,
x
);
}
static
OF_DEVICE_FUNC
double
Backward
(
const
double
x
,
const
double
dy
)
{
return
dy
*
(
1.0
/
x
);
}
};
template
<
>
struct
Log2Functor
<
double
>
{
static
OF_DEVICE_FUNC
double
Forward
(
const
double
x
)
{
return
MATH_FUNC_D
(
log2
,
x
);
}
static
OF_DEVICE_FUNC
double
Backward
(
const
double
x
,
const
double
dy
)
{
return
dy
*
(
1.0
/
(
x
*
MATH_FUNC_D
(
log
,
2.0
)));
}
};
template
<
>
struct
Log1pFunctor
<
double
>
{
static
OF_DEVICE_FUNC
double
Forward
(
const
double
x
)
{
return
MATH_FUNC_D
(
log1p
,
x
);
}
static
OF_DEVICE_FUNC
double
Backward
(
const
double
x
,
const
double
dy
)
{
return
dy
*
(
1.0
/
(
x
+
1.0
));
}
};
template
<
>
struct
LogSigmoidFunctor
<
double
>
{
static
OF_DEVICE_FUNC
double
Forward
(
const
double
x
)
{
return
-
MATH_FUNC_D
(
log
,
(
1.0
+
MATH_FUNC_D
(
exp
,
-
x
)));
}
static
OF_DEVICE_FUNC
double
Backward
(
const
double
x
,
const
double
dy
)
{
return
dy
*
(
1.0
/
(
MATH_FUNC_D
(
exp
,
x
)
+
1.0
));
}
};
template
<
>
struct
NegativeFunctor
<
double
>
{
static
OF_DEVICE_FUNC
double
Forward
(
const
double
x
)
{
return
-
x
;
}
static
OF_DEVICE_FUNC
double
Backward
(
const
double
x
,
const
double
dy
)
{
return
-
dy
;
}
};
template
<
>
struct
ReciprocalFunctor
<
double
>
{
static
OF_DEVICE_FUNC
double
Forward
(
const
double
x
)
{
return
1.0
/
x
;
}
static
OF_DEVICE_FUNC
double
Backward
(
const
double
x
,
const
double
dy
)
{
return
dy
*
(
-
1.0
/
(
x
*
x
));
}
};
template
<
>
struct
ReciprocalNoNanFunctor
<
double
>
{
static
OF_DEVICE_FUNC
double
Forward
(
const
double
x
)
{
if
(
fabs
(
x
)
<=
0.0
)
{
return
0.0
;
}
return
1.0
/
x
;
}
static
OF_DEVICE_FUNC
double
Backward
(
const
double
x
,
const
double
dy
)
{
if
(
fabs
(
x
)
<=
0.0
)
{
return
0.0
;
}
return
dy
*
(
-
1.0
/
(
x
*
x
));
}
};
template
<
>
struct
RintFunctor
<
double
>
{
static
OF_DEVICE_FUNC
double
Forward
(
const
double
x
)
{
return
MATH_FUNC_D
(
rint
,
x
);
}
static
OF_DEVICE_FUNC
double
Backward
(
const
double
x
,
const
double
dy
)
{
return
0.0
;
}
};
template
<
>
struct
RoundFunctor
<
double
>
{
static
OF_DEVICE_FUNC
double
Forward
(
const
double
x
)
{
return
MATH_FUNC_D
(
nearbyint
,
x
);
}
static
OF_DEVICE_FUNC
double
Backward
(
const
double
x
,
const
double
dy
)
{
return
0.0
;
}
};
template
<
>
struct
SigmoidFunctor
<
double
>
{
static
OF_DEVICE_FUNC
double
Forward
(
const
double
x
)
{
return
1.0
/
(
1.0
+
MATH_FUNC_D
(
exp
,
-
x
));
}
static
OF_DEVICE_FUNC
double
Backward
(
const
double
x
,
const
double
dy
)
{
double
y
=
1.0
/
(
1.0
+
MATH_FUNC_D
(
exp
,
-
x
));
return
dy
*
(
y
*
(
1.0
-
y
));
}
};
template
<
>
struct
SinFunctor
<
double
>
{
static
OF_DEVICE_FUNC
double
Forward
(
const
double
x
)
{
return
MATH_FUNC_D
(
sin
,
x
);
}
static
OF_DEVICE_FUNC
double
Backward
(
const
double
x
,
const
double
dy
)
{
return
dy
*
MATH_FUNC_D
(
cos
,
x
);
}
};
template
<
>
struct
SinhFunctor
<
double
>
{
static
OF_DEVICE_FUNC
double
Forward
(
const
double
x
)
{
return
MATH_FUNC_D
(
sinh
,
x
);
}
static
OF_DEVICE_FUNC
double
Backward
(
const
double
x
,
const
double
dy
)
{
return
dy
*
MATH_FUNC_D
(
cosh
,
x
);
}
};
template
<
>
struct
SqrtFunctor
<
double
>
{
static
OF_DEVICE_FUNC
double
Forward
(
const
double
x
)
{
return
MATH_FUNC_D
(
sqrt
,
x
);
}
static
OF_DEVICE_FUNC
double
Backward
(
const
double
x
,
const
double
dy
)
{
return
dy
*
(
double
)
0.5
/
MATH_FUNC_D
(
sqrt
,
x
);
}
};
template
<
>
struct
SquareFunctor
<
double
>
{
static
OF_DEVICE_FUNC
double
Forward
(
const
double
x
)
{
return
x
*
x
;
}
static
OF_DEVICE_FUNC
double
Backward
(
const
double
x
,
const
double
dy
)
{
return
dy
*
2.0
*
x
;
}
};
template
<
>
struct
TanFunctor
<
double
>
{
static
OF_DEVICE_FUNC
double
Forward
(
const
double
x
)
{
return
MATH_FUNC_D
(
tan
,
x
);
}
static
OF_DEVICE_FUNC
double
Backward
(
const
double
x
,
const
double
dy
)
{
return
dy
*
(
1.0
/
(
MATH_FUNC_D
(
cos
,
x
)
*
MATH_FUNC_D
(
cos
,
x
)));
}
};
#if defined(__CUDACC__) || defined(__HIPCC__)
// half version
#define OF_HALF_FUNC __device__ __forceinline__
#define MATH_FUNC_H(name, x) __float2half(name##f(__half2float(x)))
#define HALF_VAL_HALF __float2half(0.5f)
#define HALF_VAL_TWO __float2half(2.0f)
#define HALF_VAL_2RSQRT_PI __float2half(1.1283791671f)
template
<
>
struct
AbsFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
__hlt
(
x
,
GetZeroVal
<
half
>
())
?
__hneg
(
x
)
:
x
;
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hlt
(
x
,
GetZeroVal
<
half
>
())
?
__hneg
(
dy
)
:
dy
;
}
};
template
<
>
struct
AcosFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
MATH_FUNC_H
(
acos
,
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
__hneg
(
hrsqrt
(
__hsub
(
GetOneVal
<
half
>
(),
__hmul
(
x
,
x
)))));
}
};
template
<
>
struct
AcoshFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
MATH_FUNC_H
(
acosh
,
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
hrsqrt
(
__hsub
(
__hmul
(
x
,
x
),
GetOneVal
<
half
>
())));
}
};
template
<
>
struct
AsinFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
MATH_FUNC_H
(
asin
,
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
hrsqrt
(
__hsub
(
GetOneVal
<
half
>
(),
__hmul
(
x
,
x
))));
}
};
template
<
>
struct
AsinhFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
MATH_FUNC_H
(
asinh
,
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
hrsqrt
(
__hadd
(
GetOneVal
<
half
>
(),
__hmul
(
x
,
x
))));
}
};
template
<
>
struct
AtanFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
MATH_FUNC_H
(
atan
,
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
__hdiv
(
GetOneVal
<
half
>
(),
__hadd
(
GetOneVal
<
half
>
(),
__hmul
(
x
,
x
))));
}
};
template
<
>
struct
AtanhFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
MATH_FUNC_H
(
atanh
,
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
__hdiv
(
GetOneVal
<
half
>
(),
__hsub
(
GetOneVal
<
half
>
(),
__hmul
(
x
,
x
))));
}
};
template
<
>
struct
CeilFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
hceil
(
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
GetZeroVal
<
half
>
();
}
};
template
<
>
struct
NotEqualZeroFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
x
!=
static_cast
<
half
>
(
0.0
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
GetZeroVal
<
half
>
();
}
};
template
<
>
struct
CosFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
hcos
(
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
__hneg
(
hsin
(
x
)));
}
};
template
<
>
struct
CoshFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
MATH_FUNC_H
(
cosh
,
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
MATH_FUNC_H
(
sinh
,
x
));
}
};
template
<
>
struct
ErfFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
MATH_FUNC_H
(
erf
,
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
__hmul
(
HALF_VAL_2RSQRT_PI
,
hexp
(
__hmul
(
__hneg
(
x
),
x
))));
}
};
template
<
>
struct
ErfcFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
MATH_FUNC_H
(
erfc
,
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
__hneg
(
__hmul
(
HALF_VAL_2RSQRT_PI
,
hexp
(
__hmul
(
__hneg
(
x
),
x
)))));
}
};
template
<
>
struct
ExpFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
hexp
(
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
hexp
(
x
));
}
};
template
<
>
struct
Expm1Functor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
MATH_FUNC_H
(
expm1
,
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
hexp
(
x
));
}
};
template
<
>
struct
FloorFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
hfloor
(
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
GetZeroVal
<
half
>
();
}
};
template
<
>
struct
LgammaFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
MATH_FUNC_H
(
lgamma
,
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
// TODO(chengcheng): return: dy * digamma(x)
assert
(
false
);
return
GetZeroVal
<
half
>
();
}
};
template
<
>
struct
LogFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
hlog
(
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
hrcp
(
x
));
}
};
template
<
>
struct
Log2Functor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
hlog2
(
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
hrcp
(
__hmul
(
x
,
hlog
(
HALF_VAL_TWO
))));
}
};
template
<
>
struct
Log1pFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
MATH_FUNC_H
(
log1p
,
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
hrcp
(
__hadd
(
x
,
GetOneVal
<
half
>
())));
}
};
template
<
>
struct
LogSigmoidFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
__hneg
(
hlog
(
__hadd
(
GetOneVal
<
half
>
(),
hexp
(
__hneg
(
x
)))));
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
hrcp
(
__hadd
(
hexp
(
x
),
GetOneVal
<
half
>
())));
}
};
template
<
>
struct
NegativeFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
__hneg
(
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hneg
(
dy
);
}
};
template
<
>
struct
ReciprocalFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
hrcp
(
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
__hneg
(
hrcp
(
__hmul
(
x
,
x
))));
}
};
template
<
>
struct
ReciprocalNoNanFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
if
(
__heq
(
GetZeroVal
<
half
>
(),
x
))
{
return
GetZeroVal
<
half
>
();
}
return
hrcp
(
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
if
(
__heq
(
GetZeroVal
<
half
>
(),
x
))
{
return
GetZeroVal
<
half
>
();
}
return
__hmul
(
dy
,
__hneg
(
hrcp
(
__hmul
(
x
,
x
))));
}
};
template
<
>
struct
RintFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
hrint
(
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
GetZeroVal
<
half
>
();
}
};
// ---- half specializations (continued) ----
// Each functor supplies Forward(x) = f(x) and Backward(x, dy) = dy * f'(x).
// Piecewise-constant ops (round, sign) have an almost-everywhere-zero gradient.

// round(x) via nearbyint (float round-trip through MATH_FUNC_H); zero gradient.
template<>
struct RoundFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(nearbyint, x); }
  static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal<half>(); }
};
// rsqrt(x) = x^(-1/2); d/dx rsqrt(x) = -1 / (2 * sqrt(x^3)).
template<>
struct RsqrtFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return hrsqrt(x); }
  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, __hneg(hrcp(__hmul(HALF_VAL_TWO, hsqrt(__hmul(x, __hmul(x, x)))))));
  }
};
// sigmoid(x) = 1 / (1 + exp(-x)); d/dx = y * (1 - y) with y = sigmoid(x)
// (Backward recomputes y from x rather than reusing the forward output).
template<>
struct SigmoidFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) {
    return hrcp(__hadd(GetOneVal<half>(), hexp(__hneg(x))));
  }
  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    half y = hrcp(__hadd(GetOneVal<half>(), hexp(__hneg(x))));
    return __hmul(dy, __hmul(y, __hsub(GetOneVal<half>(), y)));
  }
};
// sign(x): +1 for x > 0, -1 for x < 0, else 0 (both comparisons are false for
// NaN, so NaN also maps to 0); gradient is zero.
template<>
struct SignFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) {
    if (__hgt(x, GetZeroVal<half>())) { return GetOneVal<half>(); }
    if (__hlt(x, GetZeroVal<half>())) { return __hneg(GetOneVal<half>()); }
    return GetZeroVal<half>();
  }
  static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal<half>(); }
};
// d/dx sin(x) = cos(x).
template<>
struct SinFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return hsin(x); }
  static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, hcos(x)); }
};
// d/dx sinh(x) = cosh(x); no native half sinh/cosh, so float round-trip.
template<>
struct SinhFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(sinh, x); }
  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, MATH_FUNC_H(cosh, x));
  }
};
// d/dx sqrt(x) = 0.5 / sqrt(x).
template<>
struct SqrtFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return hsqrt(x); }
  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, __hdiv(HALF_VAL_HALF, hsqrt(x)));
  }
};
// d/dx x^2 = 2x.
template<>
struct SquareFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return __hmul(x, x); }
  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, __hmul(HALF_VAL_TWO, x));
  }
};
// tan(x) = sin(x)/cos(x); d/dx tan(x) = 1 / cos(x)^2.
template<>
struct TanFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return __hdiv(hsin(x), hcos(x)); }
  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, hrcp(__hmul(hcos(x), hcos(x))));
  }
};
#endif
}
// namespace oneflow
#endif // ONEFLOW_USER_KERNELS_MATH_UNARY_ELEMENTWISE_FUNC_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_USER_KERNELS_MATH_UNARY_ELEMENTWISE_FUNC_H_
#define ONEFLOW_USER_KERNELS_MATH_UNARY_ELEMENTWISE_FUNC_H_
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/data_type.h"
#include "oneflow/user/ops/math_unary_elementwise_seq.h"
#include "oneflow/core/device/cuda_pseudo_half.h"
#if defined(__CUDACC__)
#include <cuda_fp16.h>
#define MATH_FUNC_F(name, x) name##f(x)
#define MATH_FUNC_D(name, x) name(x)
#elif defined(__HIPCC__)
#include <cmath>
#include <hip/hip_fp16.h>
#if defined(__HIP_DEVICE_COMPILE__)
#define MATH_FUNC_F(name, x) name##f(x)
#define MATH_FUNC_D(name, x) name(x)
#else
#define MATH_FUNC_F(name, x) std::name(x)
#define MATH_FUNC_D(name, x) std::name(x)
#endif
#else
#include <cmath>
#define MATH_FUNC_F(name, x) std::name(x)
#define MATH_FUNC_D(name, x) std::name(x)
#endif
namespace
oneflow
{
// Forward-declare one primary template `<prefix>Functor<T>` per entry of
// MATH_UNARY_ELEMENTWISE_FUNC_SEQ; the specializations below define the
// actual float/double/half implementations.
#define DECLARE_UNARY_FUNCTOR(math_unary_elementwise_type, func_prefix) \
  template<typename T>                                                  \
  struct func_prefix##Functor;

OF_PP_FOR_EACH_TUPLE(DECLARE_UNARY_FUNCTOR, MATH_UNARY_ELEMENTWISE_FUNC_SEQ)
// Generic |x|: maps exact zero to T(0), otherwise flips the sign of negative
// inputs. Backward routes the upstream gradient dy through with the same sign
// selection (zero input yields a zero gradient).
template<typename T>
struct AbsFunctor {
  static OF_DEVICE_FUNC T Forward(const T x) {
    return x == T(0) ? T(0) : (x < T(0) ? -x : x);
  }
  static OF_DEVICE_FUNC T Backward(const T x, const T dy) {
    return x == T(0) ? T(0) : (x < T(0) ? -dy : dy);
  }
};
// Generic sign(x): +1 for positive, -1 for negative, 0 otherwise (including
// NaN, for which both comparisons are false). The gradient is zero everywhere.
template<typename T>
struct SignFunctor {
  static OF_DEVICE_FUNC T Forward(const T x) {
    if (x > T(0)) { return T(1); }
    if (x < T(0)) { return -T(1); }
    return T(0);
  }
  static OF_DEVICE_FUNC T Backward(const T x, const T dy) { return T(0); }
};
// rsqrt(x) = 1/sqrt(x). Uses the device intrinsic when compiling for CUDA or
// for the HIP device pass; host compilation falls back to std::sqrt.
template<>
struct RsqrtFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) {
#if defined(__CUDACC__)
    return rsqrtf(x);
#elif defined(__HIP_DEVICE_COMPILE__)
    return rsqrtf(x);
#else
    return 1.0f / std::sqrt(x);
#endif
  }
  // d/dx rsqrt(x) = -1 / (2 * sqrt(x^3)).
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * (-1.0f / (2.0f * MATH_FUNC_F(sqrt, x * x * x)));
  }
};
// Double-precision rsqrt; same intrinsic/host split as the float version.
template<>
struct RsqrtFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) {
#if defined(__CUDACC__)
    return rsqrt(x);
#elif defined(__HIP_DEVICE_COMPILE__)
    return rsqrt(x);
#else
    return 1.0 / std::sqrt(x);
#endif
  }
  // d/dx rsqrt(x) = -1 / (2 * sqrt(x^3)).
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * (-1.0 / (2.0 * MATH_FUNC_D(sqrt, x * x * x)));
  }
};
// float version
// Single-precision specializations. Forward dispatches through MATH_FUNC_F
// (device name##f or std::name depending on the compile pass); Backward
// returns dy * f'(x). Step-like ops (ceil/floor/rint/round/!=0) return 0.
template<>
struct AcosFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(acos, x); }
  // d/dx acos(x) = -1 / sqrt(1 - x^2).
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * -RsqrtFunctor<float>::Forward(1.0f - x * x);
  }
};
template<>
struct AcoshFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(acosh, x); }
  // d/dx acosh(x) = 1 / sqrt(x^2 - 1).
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * RsqrtFunctor<float>::Forward(x * x - 1.0f);
  }
};
template<>
struct AsinFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(asin, x); }
  // d/dx asin(x) = 1 / sqrt(1 - x^2).
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * RsqrtFunctor<float>::Forward(1.0f - x * x);
  }
};
template<>
struct AsinhFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(asinh, x); }
  // d/dx asinh(x) = 1 / sqrt(1 + x^2).
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * RsqrtFunctor<float>::Forward(1.0f + x * x);
  }
};
template<>
struct AtanFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(atan, x); }
  // d/dx atan(x) = 1 / (1 + x^2).
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * (1.0f / (1.0f + x * x));
  }
};
template<>
struct AtanhFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(atanh, x); }
  // d/dx atanh(x) = 1 / (1 - x^2).
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * (1.0f / (1.0f - x * x));
  }
};
// 1.0f when x != 0, else 0.0f (bool -> float conversion); zero gradient.
template<>
struct NotEqualZeroFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return x != 0; }
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; }
};
template<>
struct CeilFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(ceil, x); }
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; }
};
template<>
struct CosFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(cos, x); }
  // d/dx cos(x) = -sin(x).
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * (-MATH_FUNC_F(sin, x));
  }
};
template<>
struct CoshFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(cosh, x); }
  // d/dx cosh(x) = sinh(x).
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * MATH_FUNC_F(sinh, x);
  }
};
template<>
struct ErfFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(erf, x); }
  // d/dx erf(x) = (2 / sqrt(pi)) * exp(-x^2); 2*rsqrt(pi) is that prefactor.
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * 2.0f * RsqrtFunctor<float>::Forward(M_PI) * expf(-x * x);
  }
};
template<>
struct ErfcFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(erfc, x); }
  // d/dx erfc(x) = -(2 / sqrt(pi)) * exp(-x^2).
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * -2.0f * RsqrtFunctor<float>::Forward(M_PI) * expf(-x * x);
  }
};
template<>
struct ExpFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(exp, x); }
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * MATH_FUNC_F(exp, x);
  }
};
template<>
struct Expm1Functor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(expm1, x); }
  // d/dx (exp(x) - 1) = exp(x).
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * MATH_FUNC_F(exp, x);
  }
};
template<>
struct FloorFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(floor, x); }
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; }
};
template<>
struct LgammaFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(lgamma, x); }
  // Gradient intentionally unimplemented (needs digamma); traps in debug.
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    // TODO(chengcheng): return: dy * digamma(x)
    assert(false);
    return 0.0f;
  }
};
template<>
struct LogFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(log, x); }
  // d/dx ln(x) = 1/x.
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * (1.0f / x);
  }
};
template<>
struct Log2Functor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(log2, x); }
  // d/dx log2(x) = 1 / (x * ln 2).
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * (1.0f / (x * MATH_FUNC_F(log, 2.0f)));
  }
};
template<>
struct Log1pFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(log1p, x); }
  // d/dx ln(1 + x) = 1 / (x + 1).
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * (1.0f / (x + 1.0f));
  }
};
template<>
struct LogSigmoidFunctor<float> {
  // log(sigmoid(x)) = -log(1 + exp(-x)).
  static OF_DEVICE_FUNC float Forward(const float x) {
    return -MATH_FUNC_F(log, (1.0f + MATH_FUNC_F(exp, -x)));
  }
  // d/dx log(sigmoid(x)) = 1 / (exp(x) + 1).
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * (1.0f / (MATH_FUNC_F(exp, x) + 1.0f));
  }
};
template<>
struct NegativeFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return -x; }
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return -dy; }
};
template<>
struct ReciprocalFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return 1.0f / x; }
  // d/dx (1/x) = -1 / x^2.
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * (-1.0f / (x * x));
  }
};
// Like Reciprocal but returns 0 (instead of inf/NaN) for a zero input,
// and a matching zero gradient there.
template<>
struct ReciprocalNoNanFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) {
    if (fabsf(x) <= 0.0f) { return 0.0f; }
    return 1.0f / x;
  }
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    if (fabsf(x) <= 0.0f) { return 0.0f; }
    return dy * (-1.0f / (x * x));
  }
};
template<>
struct RintFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(rint, x); }
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; }
};
template<>
struct RoundFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(nearbyint, x); }
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; }
};
template<>
struct SigmoidFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) {
    return 1.0f / (1.0f + MATH_FUNC_F(exp, -x));
  }
  // d/dx sigmoid(x) = y * (1 - y); y is recomputed from x.
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    float y = 1.0f / (1.0f + MATH_FUNC_F(exp, -x));
    return dy * (y * (1.0f - y));
  }
};
template<>
struct SinFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(sin, x); }
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * MATH_FUNC_F(cos, x);
  }
};
template<>
struct SinhFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(sinh, x); }
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * MATH_FUNC_F(cosh, x);
  }
};
template<>
struct SqrtFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(sqrt, x); }
  // d/dx sqrt(x) = 0.5 / sqrt(x).
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * 0.5f / MATH_FUNC_F(sqrt, x);
  }
};
template<>
struct SquareFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return x * x; }
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * 2.0f * x;
  }
};
template<>
struct TanFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(tan, x); }
  // d/dx tan(x) = 1 / cos(x)^2.
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * (1.0f / (MATH_FUNC_F(cos, x) * MATH_FUNC_F(cos, x)));
  }
};
// double version
// Double-precision specializations; same contract as the float versions.
template<>
struct AcosFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(acos, x); }
  // d/dx acos(x) = -1 / sqrt(1 - x^2).
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * -RsqrtFunctor<double>::Forward(1.0 - x * x);
  }
};
template<>
struct AcoshFunctor<double> {
  // acosh(x), defined for x >= 1.
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(acosh, x); }
  // d/dx acosh(x) = 1 / sqrt(x^2 - 1), which is positive on the domain.
  // Bug fix: the previous implementation negated the gradient
  // (dy * -rsqrt(x^2 - 1)), disagreeing with both the mathematical
  // derivative and AcoshFunctor<float>::Backward above.
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * RsqrtFunctor<double>::Forward(x * x - 1.0);
  }
};
template<>
struct AsinFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(asin, x); }
  // d/dx asin(x) = 1 / sqrt(1 - x^2).
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * RsqrtFunctor<double>::Forward(1.0 - x * x);
  }
};
template<>
struct AsinhFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(asinh, x); }
  // d/dx asinh(x) = 1 / sqrt(1 + x^2).
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * RsqrtFunctor<double>::Forward(1.0 + x * x);
  }
};
template<>
struct AtanFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(atan, x); }
  // d/dx atan(x) = 1 / (1 + x^2).
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * (1.0 / (1.0 + x * x));
  }
};
template<>
struct AtanhFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(atanh, x); }
  // d/dx atanh(x) = 1 / (1 - x^2).
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * (1.0 / (1.0 - x * x));
  }
};
template<>
struct NotEqualZeroFunctor<double> {
  // 1.0 when x != 0, else 0.0 (bool -> double conversion).
  static OF_DEVICE_FUNC double Forward(const double x) { return x != 0; }
  // Piecewise-constant forward => zero gradient. Consistency fix: previously
  // returned the float literal 0.0f from a double functor; use a double
  // literal like every other <double> specialization in this file.
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return 0.0; }
};
template<>
struct CeilFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(ceil, x); }
  // Step function: zero gradient.
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return 0.0; }
};
template<>
struct CosFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(cos, x); }
  // d/dx cos(x) = -sin(x).
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * (-MATH_FUNC_D(sin, x));
  }
};
template<>
struct CoshFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(cosh, x); }
  // d/dx cosh(x) = sinh(x).
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * MATH_FUNC_D(sinh, x);
  }
};
template<>
struct ErfFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(erf, x); }
  // d/dx erf(x) = (2 / sqrt(pi)) * exp(-x^2).
  // Precision fix: the exponential was computed with the single-precision
  // `expf`, silently truncating the double argument and result; dispatch
  // through MATH_FUNC_D like the rest of the <double> specializations.
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * 2.0 * RsqrtFunctor<double>::Forward(M_PI) * MATH_FUNC_D(exp, -x * x);
  }
};
template<>
struct ErfcFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(erfc, x); }
  // d/dx erfc(x) = -(2 / sqrt(pi)) * exp(-x^2).
  // Precision fix: replaced single-precision `expf` with the double-precision
  // dispatch macro (same issue as ErfFunctor<double>::Backward).
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * -2.0 * RsqrtFunctor<double>::Forward(M_PI) * MATH_FUNC_D(exp, -x * x);
  }
};
template<>
struct ExpFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(exp, x); }
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * MATH_FUNC_D(exp, x);
  }
};
template<>
struct Expm1Functor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(expm1, x); }
  // d/dx (exp(x) - 1) = exp(x).
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * MATH_FUNC_D(exp, x);
  }
};
template<>
struct FloorFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(floor, x); }
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return 0.0; }
};
template<>
struct LgammaFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(lgamma, x); }
  // Gradient intentionally unimplemented (needs digamma); traps in debug.
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    // TODO(chengcheng): return: dy * digamma(x)
    assert(false);
    return 0.0;
  }
};
template<>
struct LogFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(log, x); }
  // d/dx ln(x) = 1/x.
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * (1.0 / x);
  }
};
template<>
struct Log2Functor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(log2, x); }
  // d/dx log2(x) = 1 / (x * ln 2).
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * (1.0 / (x * MATH_FUNC_D(log, 2.0)));
  }
};
template<>
struct Log1pFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(log1p, x); }
  // d/dx ln(1 + x) = 1 / (x + 1).
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * (1.0 / (x + 1.0));
  }
};
template<>
struct LogSigmoidFunctor<double> {
  // log(sigmoid(x)) = -log(1 + exp(-x)).
  static OF_DEVICE_FUNC double Forward(const double x) {
    return -MATH_FUNC_D(log, (1.0 + MATH_FUNC_D(exp, -x)));
  }
  // d/dx log(sigmoid(x)) = 1 / (exp(x) + 1).
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * (1.0 / (MATH_FUNC_D(exp, x) + 1.0));
  }
};
template<>
struct NegativeFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return -x; }
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return -dy; }
};
template<>
struct ReciprocalFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return 1.0 / x; }
  // d/dx (1/x) = -1 / x^2.
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * (-1.0 / (x * x));
  }
};
// Like Reciprocal but returns 0 (instead of inf/NaN) for a zero input,
// and a matching zero gradient there.
template<>
struct ReciprocalNoNanFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) {
    if (fabs(x) <= 0.0) { return 0.0; }
    return 1.0 / x;
  }
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    if (fabs(x) <= 0.0) { return 0.0; }
    return dy * (-1.0 / (x * x));
  }
};
template<>
struct RintFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(rint, x); }
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return 0.0; }
};
template<>
struct RoundFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(nearbyint, x); }
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return 0.0; }
};
template<>
struct SigmoidFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) {
    return 1.0 / (1.0 + MATH_FUNC_D(exp, -x));
  }
  // d/dx sigmoid(x) = y * (1 - y); y is recomputed from x.
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    double y = 1.0 / (1.0 + MATH_FUNC_D(exp, -x));
    return dy * (y * (1.0 - y));
  }
};
template<>
struct SinFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(sin, x); }
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * MATH_FUNC_D(cos, x);
  }
};
template<>
struct SinhFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(sinh, x); }
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * MATH_FUNC_D(cosh, x);
  }
};
template<>
struct SqrtFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(sqrt, x); }
  // d/dx sqrt(x) = 0.5 / sqrt(x).
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * (double)0.5 / MATH_FUNC_D(sqrt, x);
  }
};
template<>
struct SquareFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return x * x; }
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * 2.0 * x;
  }
};
template<>
struct TanFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(tan, x); }
  // d/dx tan(x) = 1 / cos(x)^2.
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * (1.0 / (MATH_FUNC_D(cos, x) * MATH_FUNC_D(cos, x)));
  }
};
#if defined(__CUDACC__) || defined(__HIPCC__)
// half version
#define OF_HALF_FUNC __device__ __forceinline__
#define MATH_FUNC_H(name, x) __float2half(name##f(__half2float(x)))
#define HALF_VAL_HALF __float2half(0.5f)
#define HALF_VAL_TWO __float2half(2.0f)
#define HALF_VAL_2RSQRT_PI __float2half(1.1283791671f)
template
<
>
struct
AbsFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
__hlt
(
x
,
GetZeroVal
<
half
>
())
?
__hneg
(
x
)
:
x
;
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hlt
(
x
,
GetZeroVal
<
half
>
())
?
__hneg
(
dy
)
:
dy
;
}
};
template
<
>
struct
AcosFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
MATH_FUNC_H
(
acos
,
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
__hneg
(
hrsqrt
(
__hsub
(
GetOneVal
<
half
>
(),
__hmul
(
x
,
x
)))));
}
};
template
<
>
struct
AcoshFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
MATH_FUNC_H
(
acosh
,
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
hrsqrt
(
__hsub
(
__hmul
(
x
,
x
),
GetOneVal
<
half
>
())));
}
};
template
<
>
struct
AsinFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
MATH_FUNC_H
(
asin
,
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
hrsqrt
(
__hsub
(
GetOneVal
<
half
>
(),
__hmul
(
x
,
x
))));
}
};
template
<
>
struct
AsinhFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
MATH_FUNC_H
(
asinh
,
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
hrsqrt
(
__hadd
(
GetOneVal
<
half
>
(),
__hmul
(
x
,
x
))));
}
};
template
<
>
struct
AtanFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
MATH_FUNC_H
(
atan
,
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
__hdiv
(
GetOneVal
<
half
>
(),
__hadd
(
GetOneVal
<
half
>
(),
__hmul
(
x
,
x
))));
}
};
template
<
>
struct
AtanhFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
MATH_FUNC_H
(
atanh
,
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
__hdiv
(
GetOneVal
<
half
>
(),
__hsub
(
GetOneVal
<
half
>
(),
__hmul
(
x
,
x
))));
}
};
template
<
>
struct
CeilFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
hceil
(
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
GetZeroVal
<
half
>
();
}
};
template
<
>
struct
NotEqualZeroFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
x
!=
static_cast
<
half
>
(
0.0
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
GetZeroVal
<
half
>
();
}
};
template
<
>
struct
CosFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
hcos
(
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
__hneg
(
hsin
(
x
)));
}
};
template
<
>
struct
CoshFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
MATH_FUNC_H
(
cosh
,
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
MATH_FUNC_H
(
sinh
,
x
));
}
};
template
<
>
struct
ErfFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
MATH_FUNC_H
(
erf
,
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
__hmul
(
HALF_VAL_2RSQRT_PI
,
hexp
(
__hmul
(
__hneg
(
x
),
x
))));
}
};
template
<
>
struct
ErfcFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
MATH_FUNC_H
(
erfc
,
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
__hneg
(
__hmul
(
HALF_VAL_2RSQRT_PI
,
hexp
(
__hmul
(
__hneg
(
x
),
x
)))));
}
};
template
<
>
struct
ExpFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
hexp
(
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
hexp
(
x
));
}
};
template
<
>
struct
Expm1Functor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
MATH_FUNC_H
(
expm1
,
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
hexp
(
x
));
}
};
template
<
>
struct
FloorFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
hfloor
(
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
GetZeroVal
<
half
>
();
}
};
template
<
>
struct
LgammaFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
MATH_FUNC_H
(
lgamma
,
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
// TODO(chengcheng): return: dy * digamma(x)
//
assert(false);
return
GetZeroVal
<
half
>
();
}
};
template
<
>
struct
LogFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
hlog
(
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
hrcp
(
x
));
}
};
template
<
>
struct
Log2Functor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
hlog2
(
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
hrcp
(
__hmul
(
x
,
hlog
(
HALF_VAL_TWO
))));
}
};
template
<
>
struct
Log1pFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
MATH_FUNC_H
(
log1p
,
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
hrcp
(
__hadd
(
x
,
GetOneVal
<
half
>
())));
}
};
template
<
>
struct
LogSigmoidFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
__hneg
(
hlog
(
__hadd
(
GetOneVal
<
half
>
(),
hexp
(
__hneg
(
x
)))));
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
hrcp
(
__hadd
(
hexp
(
x
),
GetOneVal
<
half
>
())));
}
};
template
<
>
struct
NegativeFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
__hneg
(
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hneg
(
dy
);
}
};
template
<
>
struct
ReciprocalFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
hrcp
(
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
__hneg
(
hrcp
(
__hmul
(
x
,
x
))));
}
};
template
<
>
struct
ReciprocalNoNanFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
if
(
__heq
(
GetZeroVal
<
half
>
(),
x
))
{
return
GetZeroVal
<
half
>
();
}
return
hrcp
(
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
if
(
__heq
(
GetZeroVal
<
half
>
(),
x
))
{
return
GetZeroVal
<
half
>
();
}
return
__hmul
(
dy
,
__hneg
(
hrcp
(
__hmul
(
x
,
x
))));
}
};
template
<
>
struct
RintFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
hrint
(
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
GetZeroVal
<
half
>
();
}
};
template
<
>
struct
RoundFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
MATH_FUNC_H
(
nearbyint
,
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
GetZeroVal
<
half
>
();
}
};
template
<
>
struct
RsqrtFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
hrsqrt
(
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
__hneg
(
hrcp
(
__hmul
(
HALF_VAL_TWO
,
hsqrt
(
__hmul
(
x
,
__hmul
(
x
,
x
)))))));
}
};
template
<
>
struct
SigmoidFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
hrcp
(
__hadd
(
GetOneVal
<
half
>
(),
hexp
(
__hneg
(
x
))));
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
half
y
=
hrcp
(
__hadd
(
GetOneVal
<
half
>
(),
hexp
(
__hneg
(
x
))));
return
__hmul
(
dy
,
__hmul
(
y
,
__hsub
(
GetOneVal
<
half
>
(),
y
)));
}
};
template
<
>
struct
SignFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
if
(
__hgt
(
x
,
GetZeroVal
<
half
>
()))
{
return
GetOneVal
<
half
>
();
}
if
(
__hlt
(
x
,
GetZeroVal
<
half
>
()))
{
return
__hneg
(
GetOneVal
<
half
>
());
}
return
GetZeroVal
<
half
>
();
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
GetZeroVal
<
half
>
();
}
};
template
<
>
struct
SinFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
hsin
(
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
hcos
(
x
));
}
};
template
<
>
struct
SinhFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
MATH_FUNC_H
(
sinh
,
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
MATH_FUNC_H
(
cosh
,
x
));
}
};
template
<
>
struct
SqrtFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
hsqrt
(
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
__hdiv
(
HALF_VAL_HALF
,
hsqrt
(
x
)));
}
};
template
<
>
struct
SquareFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
__hmul
(
x
,
x
);
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
__hmul
(
HALF_VAL_TWO
,
x
));
}
};
template
<
>
struct
TanFunctor
<
half
>
{
static
OF_HALF_FUNC
half
Forward
(
const
half
x
)
{
return
__hdiv
(
hsin
(
x
),
hcos
(
x
));
}
static
OF_HALF_FUNC
half
Backward
(
const
half
x
,
const
half
dy
)
{
return
__hmul
(
dy
,
hrcp
(
__hmul
(
hcos
(
x
),
hcos
(
x
))));
}
};
#endif
}
// namespace oneflow
#endif // ONEFLOW_USER_KERNELS_MATH_UNARY_ELEMENTWISE_FUNC_H_
oneflow/user/kernels/nvtx_range_kernel.hip.cpp
0 → 100644
View file @
f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/kernel/new_kernel_util.h"
#ifdef OF_ENABLE_PROFILER
#include <roctracer_roctx.h>
#endif // OF_ENABLE_PROFILER
namespace
oneflow
{
namespace
{
#ifdef OF_ENABLE_PROFILER
static
thread_local
HashMap
<
std
::
string
,
roctx_range_id_t
>
mark2range_id
;
#endif
}
// namespace
class
NvtxOpKernelState
final
:
public
user_op
::
OpKernelState
{
public:
NvtxOpKernelState
()
:
counter_
(
0
)
{
#ifndef OF_ENABLE_PROFILER
LOG
(
WARNING
)
<<
"To use NVTX, run cmake with -DBUILD_PROFILER=ON"
;
#endif
}
~
NvtxOpKernelState
()
override
=
default
;
int64_t
counter
()
const
{
return
counter_
;
}
void
IncreaseCount
()
{
counter_
+=
1
;
}
private:
int64_t
counter_
;
};
// Kernel for the "nvtx_start" op: forwards its input to its output unchanged
// and, when profiling is enabled, opens a roctx profiler range named
// "<mark_prefix>-<invocation index>".
class NvtxStartKernel final : public user_op::OpKernel {
 public:
  NvtxStartKernel() = default;
  ~NvtxStartKernel() override = default;

  // Fresh counter state per kernel instance (one counter per op).
  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(
      user_op::KernelInitContext* ctx) const override {
    return std::make_shared<NvtxOpKernelState>();
  }

 private:
  using user_op::OpKernel::Compute;
  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,
               const user_op::OpKernelCache*) const override {
    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0);
    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0);
    const ShapeView& in_shape = in->shape_view();
    CHECK_EQ(out->shape_view(), in_shape);
    const DataType in_data_type = in->data_type();
    CHECK_EQ(out->data_type(), in_data_type);
    // Pass-through copy of the input tensor. Uses DeviceType::kCUDA even in
    // this .hip.cpp build — presumably the HIP/DCU backend is registered
    // under the CUDA device type; verify against the device registry.
    Memcpy<DeviceType::kCUDA>(ctx->stream(), out->mut_dptr<void>(), in->dptr<void>(),
                              in_shape.elem_cnt() * GetSizeOfDataType(in_data_type));
#ifdef OF_ENABLE_PROFILER
    // NOTE(review): dynamic_cast result is not null-checked before use —
    // relies on the scheduler always passing the state created above.
    auto* kernel_state = dynamic_cast<NvtxOpKernelState*>(state);
    const std::string mark_prefix = ctx->Attr<std::string>("mark_prefix");
    // Unique mark per invocation: "<prefix>-<count>".
    const std::string mark = mark_prefix + "-" + std::to_string(kernel_state->counter());
    // Open the profiler range and remember its id so the end kernel can
    // close it; emplace must succeed (duplicate marks are a bug).
    roctx_range_id_t range_id = roctxRangeStartA(mark.c_str());
    CHECK(mark2range_id.emplace(mark, range_id).second);
    kernel_state->IncreaseCount();
#endif  // OF_ENABLE_PROFILER
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
// Register nvtx_start for CUDA-tagged devices (this HIP build maps DCU/ROCm
// devices onto DeviceType::kCUDA) and propose -- but do not require -- that
// "out" reuse the buffer of "in".
REGISTER_USER_KERNEL("nvtx_start")
    .SetCreateFn<NvtxStartKernel>()
    .SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCUDA)
    .SetInplaceProposalFn([](const user_op::InferContext&,
                             user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> {
      OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "in", 0, false));
      return Maybe<void>::Ok();
    });
// Kernel that closes the roctx profiling range opened by the paired
// nvtx_start kernel and forwards its input tensor to its output unchanged.
class NvtxEndKernel final : public user_op::OpKernel {
 public:
  NvtxEndKernel() = default;
  ~NvtxEndKernel() override = default;

  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(
      user_op::KernelInitContext* ctx) const override {
    return std::make_shared<NvtxOpKernelState>();
  }

 private:
  using user_op::OpKernel::Compute;
  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,
               const user_op::OpKernelCache*) const override {
    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0);
    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0);
    const ShapeView& in_shape = in->shape_view();
    CHECK_EQ(out->shape_view(), in_shape);
    const DataType in_data_type = in->data_type();
    CHECK_EQ(out->data_type(), in_data_type);
#ifdef OF_ENABLE_PROFILER
    // CHECK_NOTNULL: the state is always the NvtxOpKernelState created in
    // CreateOpKernelState; fail loudly instead of dereferencing null.
    auto* kernel_state = CHECK_NOTNULL(dynamic_cast<NvtxOpKernelState*>(state));
    const std::string mark_prefix = ctx->Attr<std::string>("mark_prefix");
    // Reconstruct the mark used by the paired nvtx_start: "<prefix>-<count>".
    const std::string mark = mark_prefix + "-" + std::to_string(kernel_state->counter());
    // Look up with the std::string key directly; the original
    // find(mark.c_str()) forced construction of a redundant temporary string.
    auto it = mark2range_id.find(mark);
    CHECK(it != mark2range_id.end());
    roctx_range_id_t range_id = it->second;
    mark2range_id.erase(it);
    roctxRangeStop(range_id);
#endif  // OF_ENABLE_PROFILER
    // Identity copy: always forward `in` to `out`, matching NvtxStartKernel.
    // (The original only copied under OF_ENABLE_PROFILER, which left `out`
    // unwritten in non-profiler builds unless the inplace proposal was taken.)
    Memcpy<DeviceType::kCUDA>(ctx->stream(), out->mut_dptr<void>(), in->dptr<void>(),
                              in_shape.elem_cnt() * GetSizeOfDataType(in_data_type));
#ifdef OF_ENABLE_PROFILER
    kernel_state->IncreaseCount();
#endif  // OF_ENABLE_PROFILER
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
// Register nvtx_end for CUDA-tagged devices (this HIP build maps DCU/ROCm
// devices onto DeviceType::kCUDA) and propose -- but do not require -- that
// "out" reuse the buffer of "in".
REGISTER_USER_KERNEL("nvtx_end")
    .SetCreateFn<NvtxEndKernel>()
    .SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCUDA)
    .SetInplaceProposalFn([](const user_op::InferContext&,
                             user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> {
      OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "in", 0, false));
      return Maybe<void>::Ok();
    });
}
// namespace oneflow
oneflow/user/kernels/stateful_opkernel.cpp
View file @
f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/user/kernels/stateful_opkernel.h"
#include "oneflow/core/framework/attr_value_accessor.h"
#include "oneflow/core/framework/user_op_conf.h"
#include "oneflow/core/framework/user_op_registry_manager.h"
#include "oneflow/core/eager/eager_blob_object.h"
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/rpc/include/global_process_ctx.h"
#include "oneflow/core/framework/consistent_tensor_infer_cache.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/profiler/profiler.h"
#include "oneflow/core/profiler/profile_manager.h"
#include "oneflow/core/profiler/event_recorder.h"
#include "oneflow/core/eager/call_context.h"
namespace
oneflow
{
namespace
one
{
class ConsistentTensorInferResult;

// (arg_name, index) pairs describing a kernel's indexed inputs or outputs.
using ArgVec = std::vector<std::pair<std::string, int32_t>>;

// Non-owning pointer to a list of eager blob objects owned by the caller.
using EagerBlobObjectListRawPtr = const std::vector<std::shared_ptr<vm::EagerBlobObject>>*;

// Non-owning pointer to a consistent-tensor inference result.
using ConsistentTensorInferResultRawPtr = const ConsistentTensorInferResult*;
// Stateless helper shared by the per-call context adapters below. It resolves
// (arg_name, index) pairs against the op's input/output ArgTuples and fetches
// the corresponding entry from the eager::CallContext passed to each call, so
// a single helper instance can serve many concurrent calls (hence "zero copy"
// of per-call state into the context object).
class ZeroCopyBaseContextHelper {
 public:
  ZeroCopyBaseContextHelper(const std::shared_ptr<const ArgTuple>& input_arg_tuple,
                            const std::shared_ptr<const ArgTuple>& output_arg_tuple)
      : input_arg_tuple_(input_arg_tuple), output_arg_tuple_(output_arg_tuple) {}

// Look up arg_name/index first among inputs, then among outputs; on a hit,
// return the element with `post_action` applied. Falls through when the name
// is in neither tuple. Relies on `arg_name` and `index` being in scope at the
// expansion site.
#define RETURN_IF_FOUND(inputs, outputs, post_action) \
  int32_t i = TryGetTensorTupleIndex(input_arg_tuple_->arg_name2bn_index2tensor_tuple_index(), \
                                     arg_name, index); \
  if (i >= 0) { return (inputs).at(i) post_action; } \
  i = TryGetTensorTupleIndex(output_arg_tuple_->arg_name2bn_index2tensor_tuple_index(), arg_name, \
                             index); \
  if (i >= 0) { return (outputs).at(i) post_action; }

  // Returns the TensorDesc of the named arg, or nullptr when absent.
  user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
                                                  const std::string& arg_name,
                                                  const int32_t index) const {
    RETURN_IF_FOUND(*call_ctx->inputs(), *call_ctx->outputs(), .get());
    return nullptr;
  }

  // Returns the Tensor of the named arg; ("tmp_buffer", 0) maps to the call's
  // scratch tensor. Returns nullptr when absent.
  user_op::Tensor* Tensor4ArgNameAndIndex(eager::CallContext* call_ctx,
                                          const std::string& arg_name,
                                          const int32_t index) const {
    RETURN_IF_FOUND(*call_ctx->inputs(), *call_ctx->outputs(), .get());
    if (arg_name == "tmp_buffer" && index == 0) { return call_ctx->mut_tmp_tensor(); }
    return nullptr;
  }

  // Returns the consistent (global) tensor meta of the named arg, or nullptr.
  const ConsistentTensorMeta* ConsistentTensorMeta4ArgNameAndIndex(eager::CallContext* call_ctx,
                                                                   const std::string& arg_name,
                                                                   const int32_t index) const {
    const auto& consistent_tensor_infer_result = call_ctx->consistent_tensor_infer_result();
    RETURN_IF_FOUND(consistent_tensor_infer_result->input_tensor_metas(),
                    consistent_tensor_infer_result->output_tensor_metas(),
                    .shared_from_symbol().get());
    return nullptr;
  }

  // Parallel desc taken from the first input meta, else the first output
  // meta; empty Optional when there is no consistent-tensor infer result
  // (i.e. a local/mirrored call).
  Optional<Symbol<ParallelDesc>> parallel_desc(eager::CallContext* call_ctx) const {
    const auto& consistent_tensor_infer_result = call_ctx->consistent_tensor_infer_result();
    if (!consistent_tensor_infer_result) { return Optional<Symbol<ParallelDesc>>(); }
    if (!consistent_tensor_infer_result->input_tensor_metas().empty()) {
      return consistent_tensor_infer_result->input_tensor_metas().at(0)->parallel_desc();
    } else if (!consistent_tensor_infer_result->output_tensor_metas().empty()) {
      return consistent_tensor_infer_result->output_tensor_metas().at(0)->parallel_desc();
    } else {
      // An op with neither inputs nor outputs has no parallel desc to report.
      UNIMPLEMENTED();
      return Optional<Symbol<ParallelDesc>>();
    }
  }

  // Parallel context for this process; falls back to a shared single-device
  // context (parallel_id=0, parallel_num=1) for local calls.
  const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const {
    const auto& parallel_desc = this->parallel_desc(call_ctx);
    if (parallel_desc.has_value()) {
      const auto& parallel_desc_symbol = CHECK_JUST(parallel_desc);
      return *CHECK_JUST(GetParallelContext4CurrentProcessCtx(parallel_desc_symbol));
    } else {
      static ParallelContext single_device_parallel_ctx(MakeSingleDeviceParallelCtx());
      return single_device_parallel_ctx;
    }
  }

  const ArgVec& inputs() const { return input_arg_tuple_->indexed_arg_name_and_index(); }
  const ArgVec& outputs() const { return output_arg_tuple_->indexed_arg_name_and_index(); }

 private:
  // Translates (arg_name, arg_index) into a flat tensor-tuple index; -1 when
  // arg_name is not present in the map.
  static int32_t TryGetTensorTupleIndex(const std::unordered_map<std::string, std::vector<int32_t>>&
                                            arg_name2bn_index2tensor_tuple_index,
                                        const std::string& arg_name, const int32_t arg_index) {
    auto it = arg_name2bn_index2tensor_tuple_index.find(arg_name);
    if (it != arg_name2bn_index2tensor_tuple_index.end()) { return it->second.at(arg_index); }
    return -1;
  }

  // Builds the ParallelContext used for single-device (non-consistent) calls.
  static ParallelContext MakeSingleDeviceParallelCtx() {
    ParallelContext single_device_parallel_ctx;
    single_device_parallel_ctx.set_parallel_id(0);
    single_device_parallel_ctx.set_parallel_num(1);
    return single_device_parallel_ctx;
  }

  std::shared_ptr<const ArgTuple> input_arg_tuple_;
  std::shared_ptr<const ArgTuple> output_arg_tuple_;
};
// ZeroCopyBaseContextHelper plus a fixed device type; base for the kernel
// context helpers below.
class UserKernelBaseContextHelper final : public ZeroCopyBaseContextHelper {
 public:
  UserKernelBaseContextHelper(DeviceType device_type,
                              const std::shared_ptr<const ArgTuple>& input_arg_tuple,
                              const std::shared_ptr<const ArgTuple>& output_arg_tuple)
      : ZeroCopyBaseContextHelper(input_arg_tuple, output_arg_tuple), device_type_(device_type) {}
  ~UserKernelBaseContextHelper() = default;

  DeviceType device_type() const { return device_type_; }
  // Not supported in the eager path. UNIMPLEMENTED() aborts, so the null
  // dereference below is never reached; it only satisfies the return type.
  const JobDesc& job_desc() const {
    UNIMPLEMENTED();
    return *(const JobDesc*)nullptr;
  }

 private:
  const DeviceType device_type_;
};
// Call-context-free implementation of user_op::InferContext: every accessor
// takes the eager::CallContext explicitly so one helper instance (owned by
// the StatefulOpKernel) can serve all calls of the op. UserOpInferContext
// below binds a concrete call context and forwards here.
class UserOpInferContextHelper final {
 public:
  UserOpInferContextHelper(const user_op::UserOpConfWrapper* user_op_conf,
                           const std::shared_ptr<const ArgTuple>& input_arg_tuple,
                           const std::shared_ptr<const ArgTuple>& output_arg_tuple)
      : user_op_conf_(user_op_conf),
        zero_copy_base_ctx_helper_(input_arg_tuple, output_arg_tuple) {}
  ~UserOpInferContextHelper() = default;

  // Logical (global) tensor descs are not available during eager infer.
  const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
                                                               const std::string& arg_name,
                                                               int32_t index) const {
    UNIMPLEMENTED();
    return nullptr;
  }
  // Input desc must exist; CHECK-fails on a missing arg.
  const user_op::TensorDesc& InputTensorDesc(eager::CallContext* call_ctx,
                                             const std::string& arg_name, int32_t index) const {
    return *CHECK_NOTNULL(TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index));
  }
  user_op::TensorDesc* OutputTensorDesc(eager::CallContext* call_ctx,
                                        const std::string& arg_name, int32_t index) const {
    return TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
                                                  const std::string& arg_name,
                                                  int32_t index) const {
    return zero_copy_base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);
  }

  // Shape / stride / dtype / is_dynamic accessors: the Input* variants
  // dereference (and therefore require the arg to exist); the *4ArgNameAndIndex
  // variants return mutable pointers and LOG(FATAL) via
  // NonNullTensorDesc4ArgNameAndIndex when the arg is missing.
  const Shape& InputShape(eager::CallContext* call_ctx, const std::string& arg_name,
                          int32_t index) const {
    return *Shape4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  Shape* OutputShape(eager::CallContext* call_ctx, const std::string& arg_name,
                     int32_t index) const {
    return Shape4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  Shape* Shape4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
                               int32_t index) const {
    return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_shape();
  }
  const Stride& InputStride(eager::CallContext* call_ctx, const std::string& arg_name,
                            int32_t index) const {
    return *Stride4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  Stride* OutputStride(eager::CallContext* call_ctx, const std::string& arg_name,
                       int32_t index) const {
    return Stride4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  Stride* Stride4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
                                 int32_t index) const {
    return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_stride();
  }
  const DataType& InputDType(eager::CallContext* call_ctx, const std::string& arg_name,
                             int32_t index) const {
    return *Dtype4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  DataType* OutputDType(eager::CallContext* call_ctx, const std::string& arg_name,
                        int32_t index) const {
    return Dtype4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  DataType* Dtype4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
                                  int32_t index) const {
    return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_data_type();
  }
  bool InputIsDynamic(eager::CallContext* call_ctx, const std::string& arg_name,
                      int32_t index) const {
    return *IsDynamic4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  bool* OutputIsDynamic(eager::CallContext* call_ctx, const std::string& arg_name,
                        int32_t index) const {
    return IsDynamic4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  bool* IsDynamic4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
                                  int32_t index) const {
    return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_is_dynamic();
  }

  const ArgVec& inputs() const { return zero_copy_base_ctx_helper_.inputs(); }
  const ArgVec& outputs() const { return zero_copy_base_ctx_helper_.outputs(); }
  // Not supported in the eager path.
  const JobDesc* job_desc() const {
    UNIMPLEMENTED();
    return nullptr;
  }
  const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const {
    return zero_copy_base_ctx_helper_.parallel_ctx(call_ctx);
  }
  // CHECK-fails for local calls where no parallel desc is available.
  const ParallelDesc& parallel_desc(eager::CallContext* call_ctx) const {
    return *CHECK_JUST(zero_copy_base_ctx_helper_.parallel_desc(call_ctx));
  }
  // Only valid for 1-D SBP; CHECK-fails otherwise.
  const SbpParallel& SbpParallel4ArgNameAndIndex(eager::CallContext* call_ctx,
                                                 const std::string& arg_name,
                                                 int32_t index) const {
    const auto& nd_sbp = NdSbp4ArgNameAndIndex(call_ctx, arg_name, index);
    CHECK_EQ(nd_sbp.sbp_parallel_size(), 1);
    return nd_sbp.sbp_parallel(0);
  }
  const NdSbp& NdSbp4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
                                     int32_t index) const {
    return *CHECK_NOTNULL(zero_copy_base_ctx_helper_.ConsistentTensorMeta4ArgNameAndIndex(
                              call_ctx, arg_name, index))
                ->nd_sbp();
  }
  int64_t parallel_num(eager::CallContext* call_ctx) const {
    return parallel_ctx(call_ctx).parallel_num();
  }

  // Static op-conf lookups (no call context needed).
  const std::string& input(const std::string& arg_name, int32_t index) const {
    return user_op_conf().input(arg_name, index);
  }
  const std::string& output(const std::string& arg_name, int32_t index) const {
    return user_op_conf().output(arg_name, index);
  }
  bool has_input(const std::string& arg_name, int32_t index) const {
    return user_op_conf().has_input(arg_name, index);
  }
  bool has_output(const std::string& arg_name, int32_t index) const {
    return user_op_conf().has_output(arg_name, index);
  }
  int32_t input_size(const std::string& arg_name) const {
    return user_op_conf().input_size(arg_name);
  }
  int32_t output_size(const std::string& arg_name) const {
    return user_op_conf().output_size(arg_name);
  }
  const std::string& op_name() const { return user_op_conf().op_name(); }
  const std::string& op_type_name() const { return user_op_conf().op_type_name(); }
  const std::string& op_loc() const { return user_op_conf_->op_conf().loc(); }

  const user_op::UserOpConfWrapper& user_op_conf() const { return *user_op_conf_; }

  // Attributes come from the call's composed attrs (base attrs overlaid with
  // per-call attrs).
  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(eager::CallContext* call_ctx,
                                                           const std::string& attr_name) const {
    return call_ctx->composed_attrs().Attr4Name(attr_name);
  }

 private:
  // Like TensorDesc4ArgNameAndIndex, but LOG(FATAL)s on a missing arg so
  // callers can dereference the result unconditionally.
  user_op::TensorDesc* NonNullTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
                                                         const std::string& arg_name,
                                                         int32_t index) const {
    user_op::TensorDesc* tensor_desc = TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);
    if (!tensor_desc) { LOG(FATAL) << "Arg (" << arg_name << "," << index << ") is not found"; }
    return tensor_desc;
  }

  const user_op::UserOpConfWrapper* user_op_conf_;
  ZeroCopyBaseContextHelper zero_copy_base_ctx_helper_;
};
// Thin per-call adapter: implements user_op::InferContext by binding a
// concrete eager::CallContext and forwarding every query to the shared
// UserOpInferContextHelper. Holds no state of its own beyond the two
// non-owning pointers, so it is cheap to construct per call.
class UserOpInferContext : public user_op::InferContext {
 public:
  UserOpInferContext(const UserOpInferContextHelper* helper, eager::CallContext* call_ctx)
      : helper_(helper), call_ctx_(call_ctx) {}
  ~UserOpInferContext() override = default;

  const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& arg_name,
                                                               int32_t index) const override {
    return helper_->LogicalTensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  const user_op::TensorDesc& InputTensorDesc(const std::string& arg_name,
                                             int32_t index) const override {
    return helper_->InputTensorDesc(call_ctx_, arg_name, index);
  }
  user_op::TensorDesc* OutputTensorDesc(const std::string& arg_name, int32_t index) override {
    return helper_->OutputTensorDesc(call_ctx_, arg_name, index);
  }
  // NOTE(review): intentionally not marked `override` in the original --
  // presumably not part of the InferContext interface; kept as-is.
  user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) {
    return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  const Shape& InputShape(const std::string& arg_name, int32_t index) const override {
    return helper_->InputShape(call_ctx_, arg_name, index);
  }
  Shape* OutputShape(const std::string& arg_name, int32_t index) override {
    return helper_->OutputShape(call_ctx_, arg_name, index);
  }
  Shape* Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) override {
    return helper_->Shape4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  const Stride& InputStride(const std::string& arg_name, int32_t index) const override {
    return helper_->InputStride(call_ctx_, arg_name, index);
  }
  Stride* OutputStride(const std::string& arg_name, int32_t index) override {
    return helper_->OutputStride(call_ctx_, arg_name, index);
  }
  Stride* Stride4ArgNameAndIndex(const std::string& arg_name, int32_t index) override {
    return helper_->Stride4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  const DataType& InputDType(const std::string& arg_name, int32_t index) const override {
    return helper_->InputDType(call_ctx_, arg_name, index);
  }
  DataType* OutputDType(const std::string& arg_name, int32_t index) override {
    return helper_->OutputDType(call_ctx_, arg_name, index);
  }
  DataType* Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) override {
    return helper_->Dtype4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  bool InputIsDynamic(const std::string& arg_name, int32_t index) const override {
    return helper_->InputIsDynamic(call_ctx_, arg_name, index);
  }
  bool* OutputIsDynamic(const std::string& arg_name, int32_t index) override {
    return helper_->OutputIsDynamic(call_ctx_, arg_name, index);
  }
  bool* IsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) override {
    return helper_->IsDynamic4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  const ArgVec& inputs() const override { return helper_->inputs(); }
  const ArgVec& outputs() const override { return helper_->outputs(); }
  const JobDesc* job_desc() const override { return helper_->job_desc(); }
  const ParallelContext& parallel_ctx() const override { return helper_->parallel_ctx(call_ctx_); }
  const ParallelDesc& parallel_desc() const override { return helper_->parallel_desc(call_ctx_); }
  const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string& arg_name,
                                                 int32_t index) const override {
    return helper_->SbpParallel4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  const NdSbp& NdSbp4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {
    return helper_->NdSbp4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  int64_t parallel_num() const override { return helper_->parallel_num(call_ctx_); }
  const std::string& input(const std::string& arg_name, int32_t index) const override {
    return helper_->input(arg_name, index);
  }
  const std::string& output(const std::string& arg_name, int32_t index) const override {
    return helper_->output(arg_name, index);
  }
  bool has_input(const std::string& arg_name, int32_t index) const override {
    return helper_->has_input(arg_name, index);
  }
  bool has_output(const std::string& arg_name, int32_t index) const override {
    return helper_->has_output(arg_name, index);
  }
  int32_t input_size(const std::string& arg_name) const override {
    return helper_->input_size(arg_name);
  }
  int32_t output_size(const std::string& arg_name) const override {
    return helper_->output_size(arg_name);
  }
  const std::string& op_name() const override { return helper_->op_name(); }
  const std::string& op_type_name() const override { return helper_->op_type_name(); }
  const std::string& op_loc() const override { return helper_->op_loc(); }

 private:
  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(
      const std::string& attr_name) const override {
    return helper_->Attr4Name(call_ctx_, attr_name);
  }

  const UserOpInferContextHelper* helper_;  // not owned
  eager::CallContext* call_ctx_;            // not owned
};
// Call-context-free backend for UserKernelComputeContext: resolves tensors,
// stream, device type, and attrs given the per-call eager::CallContext /
// DeviceCtx, so one helper serves all calls of the op.
class UserKernelComputeContextHelper final {
 public:
  UserKernelComputeContextHelper(DeviceType device_type,
                                 const user_op::UserOpConfWrapper* user_op_conf,
                                 const std::shared_ptr<const ArgTuple>& input_arg_tuple,
                                 const std::shared_ptr<const ArgTuple>& output_arg_tuple)
      : user_op_conf_(user_op_conf),
        base_ctx_helper_(device_type, input_arg_tuple, output_arg_tuple) {}
  ~UserKernelComputeContextHelper() = default;

  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
                                                        const std::string& arg_name,
                                                        int32_t index) const {
    return base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  user_op::Tensor* Tensor4ArgNameAndIndex(eager::CallContext* call_ctx,
                                          const std::string& arg_name, int32_t index) const {
    return base_ctx_helper_.Tensor4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  // CHECK-fails on a null DeviceCtx rather than returning a null stream.
  ep::Stream* stream(DeviceCtx* device_ctx) const {
    CHECK(device_ctx);
    return device_ctx->stream();
  }

  DeviceType device_type() const { return base_ctx_helper_.device_type(); }
  const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const {
    return base_ctx_helper_.parallel_ctx(call_ctx);
  }

  const ArgVec& inputs() const { return base_ctx_helper_.inputs(); }
  const ArgVec& outputs() const { return base_ctx_helper_.outputs(); }

  const user_op::UserOpConfWrapper& user_op_conf() const { return *user_op_conf_; }
  // Attributes come from the call's composed attrs.
  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(eager::CallContext* call_ctx,
                                                           const std::string& attr_name) const {
    return call_ctx->composed_attrs().Attr4Name(attr_name);
  }

 private:
  const user_op::UserOpConfWrapper* user_op_conf_;
  UserKernelBaseContextHelper base_ctx_helper_;
};
// Per-call adapter implementing user_op::KernelComputeContext by binding a
// call context and device context and forwarding to the shared helper.
class UserKernelComputeContext final : public user_op::KernelComputeContext {
 public:
  UserKernelComputeContext(const UserKernelComputeContextHelper* helper,
                           eager::CallContext* call_ctx, DeviceCtx* device_ctx)
      : helper_(helper), call_ctx_(call_ctx), device_ctx_(device_ctx) {}
  ~UserKernelComputeContext() = default;

  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,
                                                        int32_t index) const override {
    return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  user_op::Tensor* Tensor4ArgNameAndIndex(const std::string& arg_name, int32_t index) override {
    return helper_->Tensor4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  ep::Stream* stream() override { return helper_->stream(device_ctx_); }
  DeviceType device_type() const override { return helper_->device_type(); }
  const ParallelContext& parallel_ctx() const override { return helper_->parallel_ctx(call_ctx_); }
  const ArgVec& inputs() const override { return helper_->inputs(); }
  const ArgVec& outputs() const override { return helper_->outputs(); }

 private:
  const user_op::UserOpConfWrapper& user_op_conf() const override {
    return helper_->user_op_conf();
  }
  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(
      const std::string& attr_name) const override {
    return helper_->Attr4Name(call_ctx_, attr_name);
  }

  const UserKernelComputeContextHelper* helper_;  // not owned
  eager::CallContext* call_ctx_;                  // not owned
  DeviceCtx* device_ctx_;                         // not owned
};
// Call-context-free backend for UserKernelRegContext (kernel registration /
// selection queries: device type, arg descs, op conf, attrs).
class UserKernelRegContextHelper final {
 public:
  UserKernelRegContextHelper(DeviceType device_type,
                             const user_op::UserOpConfWrapper* user_op_conf,
                             const std::shared_ptr<const ArgTuple>& input_arg_tuple,
                             const std::shared_ptr<const ArgTuple>& output_arg_tuple)
      : user_op_conf_(user_op_conf),
        base_ctx_helper_(device_type, input_arg_tuple, output_arg_tuple) {}
  ~UserKernelRegContextHelper() = default;

  DeviceType device_type() const { return base_ctx_helper_.device_type(); }
  const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const {
    return base_ctx_helper_.parallel_ctx(call_ctx);
  }
  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
                                                        const std::string& arg_name,
                                                        int32_t index) const {
    return base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  const ArgVec& inputs() const { return base_ctx_helper_.inputs(); }
  const ArgVec& outputs() const { return base_ctx_helper_.outputs(); }

  const user_op::UserOpConfWrapper& user_op_conf() const { return *user_op_conf_; }
  // Attributes come from the call's composed attrs.
  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(eager::CallContext* call_ctx,
                                                           const std::string& attr_name) const {
    return call_ctx->composed_attrs().Attr4Name(attr_name);
  }

 private:
  const user_op::UserOpConfWrapper* user_op_conf_;
  UserKernelBaseContextHelper base_ctx_helper_;
};
// Per-call adapter implementing user_op::KernelRegContext by binding a call
// context and forwarding to the shared helper.
class UserKernelRegContext final : public user_op::KernelRegContext {
 public:
  UserKernelRegContext(const UserKernelRegContextHelper* helper, eager::CallContext* call_ctx)
      : helper_(helper), call_ctx_(call_ctx) {}
  ~UserKernelRegContext() = default;

  DeviceType device_type() const override { return helper_->device_type(); }
  const ParallelContext& parallel_ctx() const override { return helper_->parallel_ctx(call_ctx_); }
  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,
                                                        int32_t index) const override {
    return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  const ArgVec& inputs() const override { return helper_->inputs(); }
  const ArgVec& outputs() const override { return helper_->outputs(); }
  const user_op::UserOpConfWrapper& user_op_conf() const override {
    return helper_->user_op_conf();
  }

 private:
  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(
      const std::string& attr_name) const override {
    return helper_->Attr4Name(call_ctx_, attr_name);
  }

  const UserKernelRegContextHelper* helper_;  // not owned
  eager::CallContext* call_ctx_;              // not owned
};
// Call-context-free backend for UserKernelInitAndCacheContext (kernel
// state/cache creation): tensor descs, stream, SBP, parallel desc, attrs.
class UserKernelInitAndCacheContextHelper final {
 public:
  UserKernelInitAndCacheContextHelper(DeviceType device_type,
                                      const user_op::UserOpConfWrapper* user_op_conf,
                                      const std::shared_ptr<const ArgTuple>& input_arg_tuple,
                                      const std::shared_ptr<const ArgTuple>& output_arg_tuple)
      : user_op_conf_(user_op_conf),
        base_ctx_helper_(device_type, input_arg_tuple, output_arg_tuple) {}
  ~UserKernelInitAndCacheContextHelper() = default;

  // CHECK-fails on a null DeviceCtx rather than returning a null stream.
  ep::Stream* stream(DeviceCtx* device_ctx) const {
    CHECK(device_ctx);
    return device_ctx->stream();
  }

  DeviceType device_type() const { return base_ctx_helper_.device_type(); }
  const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const {
    return base_ctx_helper_.parallel_ctx(call_ctx);
  }
  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
                                                        const std::string& arg_name,
                                                        int32_t index) const {
    return base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  // Logical desc of a consistent (global) tensor.
  // NOTE(review): returns the ConsistentTensorMeta as a user_op::TensorDesc --
  // presumably ConsistentTensorMeta derives from TensorDesc; confirm in its
  // declaration.
  const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
                                                               const std::string& arg_name,
                                                               int32_t index) const {
    return base_ctx_helper_.ConsistentTensorMeta4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  // Only valid for 1-D SBP; CHECK-fails otherwise.
  const SbpParallel& SbpParallel4ArgNameAndIndex(eager::CallContext* call_ctx,
                                                 const std::string& arg_name,
                                                 int32_t index) const {
    const auto& nd_sbp = NdSbp4ArgNameAndIndex(call_ctx, arg_name, index);
    CHECK_EQ(nd_sbp.sbp_parallel_size(), 1);
    return nd_sbp.sbp_parallel(0);
  }
  const NdSbp& NdSbp4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
                                     int32_t index) const {
    return *CHECK_NOTNULL(base_ctx_helper_.ConsistentTensorMeta4ArgNameAndIndex(
                              call_ctx, arg_name, index))
                ->nd_sbp();
  }

  const ArgVec& inputs() const { return base_ctx_helper_.inputs(); }
  const ArgVec& outputs() const { return base_ctx_helper_.outputs(); }
  // CHECK-fails for local calls where no parallel desc is available.
  const ParallelDesc& parallel_desc(eager::CallContext* call_ctx) const {
    return *CHECK_JUST(base_ctx_helper_.parallel_desc(call_ctx));
  }
  // Attributes come from the call's composed attrs.
  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(eager::CallContext* call_ctx,
                                                           const std::string& attr_name) const {
    return call_ctx->composed_attrs().Attr4Name(attr_name);
  }
  const user_op::UserOpConfWrapper& user_op_conf() const { return *user_op_conf_; }

 private:
  const user_op::UserOpConfWrapper* user_op_conf_;
  UserKernelBaseContextHelper base_ctx_helper_;
};
// Per-call adapter implementing BOTH user_op::KernelInitContext and
// user_op::KernelCacheContext (the two interfaces expose the same queries
// here), forwarding everything to the shared helper.
class UserKernelInitAndCacheContext final : public user_op::KernelInitContext,
                                            public user_op::KernelCacheContext {
 public:
  UserKernelInitAndCacheContext(const UserKernelInitAndCacheContextHelper* helper,
                                eager::CallContext* call_ctx, DeviceCtx* device_ctx)
      : helper_(helper), call_ctx_(call_ctx), device_ctx_(device_ctx) {}
  ~UserKernelInitAndCacheContext() override = default;

  ep::Stream* stream() override { return helper_->stream(device_ctx_); }
  DeviceType device_type() const override { return helper_->device_type(); }
  const ParallelContext& parallel_ctx() const override { return helper_->parallel_ctx(call_ctx_); }
  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,
                                                        int32_t index) const override {
    return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& arg_name,
                                                               int32_t index) const override {
    return helper_->LogicalTensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string& arg_name,
                                                 int32_t index) const override {
    return helper_->SbpParallel4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  const NdSbp& NdSbp4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {
    return helper_->NdSbp4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  const ArgVec& inputs() const override { return helper_->inputs(); }
  const ArgVec& outputs() const override { return helper_->outputs(); }
  const ParallelDesc& parallel_desc() const override { return helper_->parallel_desc(call_ctx_); }

 private:
  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(
      const std::string& attr_name) const override {
    return helper_->Attr4Name(call_ctx_, attr_name);
  }
  const user_op::UserOpConfWrapper& user_op_conf() const override {
    return helper_->user_op_conf();
  }

  const UserKernelInitAndCacheContextHelper* helper_;  // not owned
  eager::CallContext* call_ctx_;                       // not owned
  DeviceCtx* device_ctx_;                              // not owned
};
namespace {

// Classifies each indexed input/output of a user op for StatefulOpKernel:
//  - inputs  -> const vs. mutable (via the op's registered InputArgModifier);
//  - outputs -> header-inferred-before-compute ("mut") vs. not ("mut2").
// The classification is written into the four output index vectors, each
// holding positions into the corresponding ArgVec.
// Returns an error if the op type is not registered; JUST() propagates errors
// from the op's arg-modify callbacks.
Maybe<void> InitTensorTupleIndexes4Bns(const std::shared_ptr<const OperatorConf>& op_conf,
                                       const ArgVec& indexed_input_pairs,
                                       const ArgVec& indexed_output_pairs,
                                       std::vector<int64_t>* input_tuple_indexes4const_ibns,
                                       std::vector<int64_t>* input_tuple_indexes4mut_ibns,
                                       std::vector<int64_t>* output_tuple_indexes4mut_obns,
                                       std::vector<int64_t>* output_tuple_indexes4mut2_obns) {
  const auto* op_reg_val = user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(
      op_conf->user_conf().op_type_name());
  CHECK_NOTNULL_OR_RETURN(op_reg_val);

  // Seed the signature with a default modifier for every blob name, so the
  // op's modify callbacks can look any of them up by name below.
  ArgModifierSignature arg_modifier_signature;
  for (const auto& pair : indexed_input_pairs) {
    const std::string ibn = GenRepeatedBn(pair.first, pair.second);
    arg_modifier_signature.mutable_ibn2input_blob_modifier()->insert(
        {ibn, user_op::InputArgModifier()});
  }
  for (const auto& pair : indexed_output_pairs) {
    const std::string obn = GenRepeatedBn(pair.first, pair.second);
    arg_modifier_signature.mutable_obn2output_blob_modifier()->insert(
        {obn, user_op::OutputArgModifier()});
  }
  user_op::UserOpConfWrapper op_conf_wrapper(op_conf);
  // Let the op's registered callbacks edit the modifiers in place.
  if (op_reg_val->input_arg_modify_fn) {
    user_op::GetInputArgModifier GetInputArgModifierFn =
        [&arg_modifier_signature](const std::string& in_arg_name,
                                  int32_t in_arg_index) -> user_op::InputArgModifier* {
      const std::string ibn = GenRepeatedBn(in_arg_name, in_arg_index);
      auto* map = arg_modifier_signature.mutable_ibn2input_blob_modifier();
      return &map->at(ibn);
    };
    JUST(op_reg_val->input_arg_modify_fn(GetInputArgModifierFn, op_conf_wrapper));
  }
  if (op_reg_val->output_arg_modify_fn) {
    user_op::GetOutputArgModifier GetOutputArgModifierFn =
        [&arg_modifier_signature](const std::string& in_arg_name,
                                  int32_t in_arg_index) -> user_op::OutputArgModifier* {
      const std::string obn = GenRepeatedBn(in_arg_name, in_arg_index);
      auto* map = arg_modifier_signature.mutable_obn2output_blob_modifier();
      return &map->at(obn);
    };
    JUST(op_reg_val->output_arg_modify_fn(GetOutputArgModifierFn, op_conf_wrapper));
  }
  // Partition input indices by mutability.
  for (int i = 0; i < indexed_input_pairs.size(); i++) {
    const auto& pair = indexed_input_pairs.at(i);
    const std::string ibn = GenRepeatedBn(pair.first, pair.second);
    if (arg_modifier_signature.ibn2input_blob_modifier().at(ibn).is_mutable()) {
      input_tuple_indexes4mut_ibns->emplace_back(i);
    } else {
      input_tuple_indexes4const_ibns->emplace_back(i);
    }
  }
  // Partition output indices by whether the header is inferred before compute.
  for (int i = 0; i < indexed_output_pairs.size(); i++) {
    const auto& pair = indexed_output_pairs.at(i);
    const std::string obn = GenRepeatedBn(pair.first, pair.second);
    if (arg_modifier_signature.obn2output_blob_modifier().at(obn).header_infered_before_compute()) {
      output_tuple_indexes4mut_obns->emplace_back(i);
    } else {
      output_tuple_indexes4mut2_obns->emplace_back(i);
    }
  }
  return Maybe<void>::Ok();
}

}  // namespace
/* static */ Maybe<StatefulOpKernel> StatefulOpKernel::New(
    const std::shared_ptr<OperatorConf>& op_conf, const Symbol<Stream>& stream,
    const AttrMap& base_attrs, const std::shared_ptr<const ParallelDesc>& parallel_desc,
    const std::shared_ptr<const ArgTuple>& input_arg_tuple,
    const std::shared_ptr<const ArgTuple>& output_arg_tuple) {
  // Factory for a fully wired StatefulOpKernel. Fails (returns an error Maybe) when
  // the op type is not registered or lacks a logical tensor-desc infer fn.
  // NOTE: `parallel_desc` is accepted but not stored here — presumably consumed by
  // callers/other overloads; TODO confirm it is intentionally unused.
  auto opkernel = std::shared_ptr<StatefulOpKernel>(new StatefulOpKernel());
  opkernel->base_attrs_ = base_attrs;
  opkernel->op_conf_ = op_conf;
  opkernel->user_op_conf_.reset(new user_op::UserOpConfWrapper(op_conf));
  opkernel->stream_ = stream;
  opkernel->input_arg_tuple_ = input_arg_tuple;
  opkernel->output_arg_tuple_ = output_arg_tuple;
  opkernel->need_check_mem_case_ = true;

  const DeviceType device_type = CHECK_JUST(DeviceType4DeviceTag(op_conf->device_tag()));
  const user_op::UserOpConfWrapper* user_op_conf = opkernel->user_op_conf_.get();
  // Context helpers are immutable per-op and shared by every later call; the
  // per-call state (eager::CallContext) is bound at use time instead.
  opkernel->op_infer_ctx_helper_.reset(
      new UserOpInferContextHelper(user_op_conf, input_arg_tuple, output_arg_tuple));
  opkernel->init_and_cache_ctx_helper_.reset(new UserKernelInitAndCacheContextHelper(
      device_type, opkernel->user_op_conf_.get(), opkernel->input_arg_tuple_,
      opkernel->output_arg_tuple_));
  opkernel->compute_ctx_helper_.reset(new UserKernelComputeContextHelper(
      device_type, user_op_conf, input_arg_tuple, output_arg_tuple));
  opkernel->reg_ctx_helper_.reset(new UserKernelRegContextHelper(
      device_type, user_op_conf, input_arg_tuple, output_arg_tuple));

  const auto* op_reg_val =
      user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(user_op_conf->op_type_name());
  CHECK_NOTNULL_OR_RETURN(op_reg_val);
  if (op_reg_val->logical_tensor_desc_infer_fn) {
    opkernel->tensor_desc_infer_fn_ = op_reg_val->logical_tensor_desc_infer_fn;
  } else {
    // Ops without a logical tensor-desc infer fn cannot be used eagerly.
    return Error::UnimplementedError();
  }
  opkernel->data_type_infer_fn_ = op_reg_val->data_type_infer_fn;

  // Classify each input/output blob name into const/mut (and mut/mut2) index sets.
  JUST(InitTensorTupleIndexes4Bns(
      op_conf, input_arg_tuple->indexed_arg_name_and_index(),
      output_arg_tuple->indexed_arg_name_and_index(),
      &opkernel->input_tuple_indexes4const_ibns_, &opkernel->input_tuple_indexes4mut_ibns_,
      &opkernel->output_tuple_indexes4mut_obns_, &opkernel->output_tuple_indexes4mut2_obns_));
  return opkernel;
}
// Out-of-line defaulted destructor: keeps unique_ptr members' complete types
// confined to this translation unit.
StatefulOpKernel::~StatefulOpKernel() = default;
// Returns the number of bytes of temporary buffer the chosen kernel needs for
// this call, by querying the InferTmpSizeFn registered for `user_opkernel`
// against an inference context built over `call_ctx`.
size_t StatefulOpKernel::InferTmpSize(eager::CallContext* call_ctx,
                                      const user_op::OpKernel* user_opkernel) const {
  UserOpInferContext infer_ctx(op_infer_ctx_helper_.get(), call_ctx);
  return GetInferTmpSizeFn(user_opkernel)(&infer_ctx);
}
// Selects (and memoizes) the user_op::OpKernel implementation matching the
// current call. Kernels are cached per primary data type; on cache miss the
// registry is consulted and the created kernel is cached for reuse.
// NOTE(review): dtype2cached_kernels_/infer_tmp_size_fn_map_ are mutated
// without visible locking — presumably single-threaded per opkernel; confirm.
Maybe<void> StatefulOpKernel::ChooseOpKernel(eager::CallContext* call_ctx,
                                             const user_op::OpKernel** user_opkernel,
                                             bool* need_temp_storage) {
  OF_PROFILER_RANGE_GUARD("ChooseOpKernel");
  // Cache key: data type of the first input, else first output, else invalid.
  DataType primary_dtype = kInvalidDataType;
  const auto& inputs = call_ctx->inputs();
  const auto& outputs = call_ctx->outputs();
  if (likely(!inputs->empty())) {
    primary_dtype = (*inputs)[0]->data_type();
  } else if (likely(!outputs->empty())) {
    primary_dtype = (*outputs)[0]->data_type();
  } else {
    // do nothing
  }

  UserKernelRegContext reg_ctx(reg_ctx_helper_.get(), call_ctx);
  // Fast path: reuse the first cached kernel whose registration predicate matches.
  for (const auto& pair : dtype2cached_kernels_[primary_dtype]) {
    if (likely(pair.first->is_matched_hob->get(reg_ctx))) {
      *need_temp_storage = pair.first->need_temp_storage;
      *user_opkernel = pair.second.get();
      return Maybe<void>::Ok();
    }
  }

  // Slow path: look the kernel up in the registry, instantiate it, and cache it.
  OF_PROFILER_RANGE_GUARD("fallback");
  const auto& op_type_name = user_op_conf_->op_type_name();
  const auto* kernel_reg_val = JUST(
      user_op::UserOpRegistryMgr::Get().GetOpKernelRegistryResult(op_type_name, reg_ctx));
  CHECK_NOTNULL(kernel_reg_val);
  auto* kernel = kernel_reg_val->create_fn();
  dtype2cached_kernels_[primary_dtype].push_back(
      {kernel_reg_val, std::shared_ptr<const user_op::OpKernel>(kernel)});
  // Remember the registry's tmp-size fn so InferTmpSize can find it later.
  infer_tmp_size_fn_map_.emplace(kernel, &kernel_reg_val->infer_tmp_size_fn);

  *need_temp_storage = kernel_reg_val->need_temp_storage;
  *user_opkernel = kernel;
  return Maybe<void>::Ok();
}
// Looks up (or lazily creates and memoizes) the per-kernel OpKernelState, and
// always (re)initializes the per-kernel OpKernelCache entry, writing both
// results through the output pointers. `state` may be null when the caller
// does not need kernel state; `cache` is always populated.
void StatefulOpKernel::TryInitOpKernelStateAndCache(eager::CallContext* call_ctx,
                                                    DeviceCtx* device_ctx,
                                                    const user_op::OpKernel* op_kernel,
                                                    user_op::OpKernelState** state,
                                                    user_op::OpKernelCache** cache) {
  // One context object serves both state creation and cache initialization.
  UserKernelInitAndCacheContext ctx(init_and_cache_ctx_helper_.get(), call_ctx, device_ctx);

  if (state != nullptr) {
    auto found = op_kernel_state_map_.find(op_kernel);
    if (found == op_kernel_state_map_.end()) {
      // First use of this kernel: create its state once and remember it.
      auto fresh_state = op_kernel->CreateOpKernelState(&ctx);
      op_kernel_state_map_.emplace(op_kernel, fresh_state);
      *state = fresh_state.get();
    } else {
      *state = found->second.get();
    }
  }

  // The cache slot (default-created on first access) is handed back to the
  // kernel each call; kAllMayChanged tells it nothing may be assumed stable.
  auto& cached = op_kernel_cache_map_[op_kernel];
  op_kernel->InitOpKernelCacheWithFlags(&ctx, user_op::OpKernelCache::kAllMayChanged, &cached);
  *cache = cached.get();
}
// Returns the tmp-size inference fn previously registered for `op_kernel` in
// ChooseOpKernel. `.at` throws/fails if the kernel was never chosen — callers
// must pick the kernel first.
const user_op::InferTmpSizeFn& StatefulOpKernel::GetInferTmpSizeFn(
    const user_op::OpKernel* op_kernel) const {
  return *infer_tmp_size_fn_map_.at(op_kernel);
}
// Accessor: the logical tensor-desc inference fn captured from the op registry.
user_op::TensorDescInferFn StatefulOpKernel::TensorDescInferFn() const {
  return tensor_desc_infer_fn_;
}

// Accessor: the data-type inference fn captured from the op registry.
user_op::DataTypeInferFn StatefulOpKernel::DataTypeInferFn() const {
  return data_type_infer_fn_;
}
// Runs the chosen kernel. When the profiler singleton is active, the compute
// call is wrapped in a kernel event recorder that lazily reports input/output
// memory footprint (CUDA builds only) and input shapes.
void StatefulOpKernel::Compute(eager::CallContext* call_ctx, DeviceCtx* device_ctx,
                               const user_op::OpKernel* user_opkernel,
                               user_op::OpKernelState* state,
                               const user_op::OpKernelCache* cache) const {
  UserKernelComputeContext compute_context(compute_ctx_helper_.get(), call_ctx, device_ctx);
  auto* compute_ctx = &compute_context;
  OF_PROFILER_RANGE_GUARD("Compute");
  if (Singleton<profiler::ProfileManager>::Get()) {
#if defined(WITH_CUDA)
    // Sums elem_cnt * dtype-size over the given args (inputs or outputs), in bytes.
    const auto CalMemorySize = [compute_ctx](const one::ArgVec& args) -> int64_t {
      const auto Func = [compute_ctx](int64_t mem_size, const auto& pair) {
        const auto tensor = compute_ctx->Tensor4ArgNameAndIndex(pair.first, pair.second);
        return mem_size
               + tensor->shape_view().elem_cnt() * GetSizeOfDataType(tensor->data_type());
      };
      return std::accumulate(args.begin(), args.end(), static_cast<int64_t>(0), Func);
    };
#endif
    // The guard must stay alive across the Compute call below so the event
    // brackets the kernel execution — hence the duplicated Compute in the
    // else branch instead of a single hoisted call.
    auto er_guard = CHECK_JUST(profiler::EventRecorder::CreateKernelEventRecorder(
        op_type_name(),
#if defined(WITH_CUDA)
        [compute_ctx, CalMemorySize]() -> int64_t {
          return CalMemorySize(compute_ctx->inputs()) + CalMemorySize(compute_ctx->outputs());
        },
#endif
        [compute_ctx]() -> std::vector<ShapeView> {
          std::vector<ShapeView> shapes;
          for (const auto& pair : compute_ctx->inputs()) {
            shapes.emplace_back(
                compute_ctx->TensorDesc4ArgNameAndIndex(pair.first, pair.second)->shape());
          }
          return shapes;
        }));
    user_opkernel->Compute(compute_ctx, state, cache);
  } else {
    user_opkernel->Compute(compute_ctx, state, cache);
  }
}
}
// namespace one
}
// namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/user/kernels/stateful_opkernel.h"
#include "oneflow/core/framework/attr_value_accessor.h"
#include "oneflow/core/framework/user_op_conf.h"
#include "oneflow/core/framework/user_op_registry_manager.h"
#include "oneflow/core/eager/eager_blob_object.h"
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/rpc/include/global_process_ctx.h"
#include "oneflow/core/framework/consistent_tensor_infer_cache.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/profiler/profiler.h"
#include "oneflow/core/profiler/profile_manager.h"
#include "oneflow/core/profiler/event_recorder.h"
#include "oneflow/core/eager/call_context.h"
namespace
oneflow
{
namespace
one
{
class ConsistentTensorInferResult;

// (arg_name, index) pairs identifying each indexed input/output of a user op.
using ArgVec = std::vector<std::pair<std::string, int32_t>>;

// Non-owning views over call-owned containers; lifetimes are managed by the caller.
using EagerBlobObjectListRawPtr = const std::vector<std::shared_ptr<vm::EagerBlobObject>>*;
using ConsistentTensorInferResultRawPtr = const ConsistentTensorInferResult*;
// Resolves (arg_name, index) pairs to tensors / tensor descs / consistent
// metas inside a given eager::CallContext, without copying any data. Holds
// only the immutable arg tuples; the per-call context is passed to each query.
class ZeroCopyBaseContextHelper {
 public:
  ZeroCopyBaseContextHelper(const std::shared_ptr<const ArgTuple>& input_arg_tuple,
                            const std::shared_ptr<const ArgTuple>& output_arg_tuple)
      : input_arg_tuple_(input_arg_tuple), output_arg_tuple_(output_arg_tuple) {}

// Expands to: try the inputs mapping first, then the outputs mapping; on a hit
// return `container.at(i) post_action`. Relies on `arg_name` and `index` being
// in scope at the expansion site.
// NOTE(review): the macro declares `int32_t i` unguarded (no do/while) and is
// never #undef'd, so it leaks past this class — safe only while each method
// expands it once and nothing else in the TU reuses the name.
#define RETURN_IF_FOUND(inputs, outputs, post_action)                                            \
  int32_t i = TryGetTensorTupleIndex(input_arg_tuple_->arg_name2bn_index2tensor_tuple_index(),   \
                                     arg_name, index);                                           \
  if (i >= 0) { return (inputs).at(i) post_action; }                                             \
  i = TryGetTensorTupleIndex(output_arg_tuple_->arg_name2bn_index2tensor_tuple_index(),          \
                             arg_name, index);                                                   \
  if (i >= 0) { return (outputs).at(i) post_action; }

  // Tensor desc lookup; nullptr when the arg is neither an input nor an output.
  user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
                                                  const std::string& arg_name,
                                                  const int32_t index) const {
    RETURN_IF_FOUND(*call_ctx->inputs(), *call_ctx->outputs(), .get());
    return nullptr;
  }

  // Tensor lookup; additionally resolves the special ("tmp_buffer", 0) arg to
  // the call's scratch tensor. nullptr when not found.
  user_op::Tensor* Tensor4ArgNameAndIndex(eager::CallContext* call_ctx,
                                          const std::string& arg_name,
                                          const int32_t index) const {
    RETURN_IF_FOUND(*call_ctx->inputs(), *call_ctx->outputs(), .get());
    if (arg_name == "tmp_buffer" && index == 0) { return call_ctx->mut_tmp_tensor(); }
    return nullptr;
  }

  // Consistent (global) tensor meta lookup from the call's infer result;
  // nullptr when the arg is unknown.
  const ConsistentTensorMeta* ConsistentTensorMeta4ArgNameAndIndex(
      eager::CallContext* call_ctx, const std::string& arg_name, const int32_t index) const {
    const auto& consistent_tensor_infer_result = call_ctx->consistent_tensor_infer_result();
    RETURN_IF_FOUND(consistent_tensor_infer_result->input_tensor_metas(),
                    consistent_tensor_infer_result->output_tensor_metas(),
                    .shared_from_symbol().get());
    return nullptr;
  }

  // Parallel desc taken from the first input meta, else first output meta;
  // empty Optional when no consistent infer result exists (local/mirrored call).
  Optional<Symbol<ParallelDesc>> parallel_desc(eager::CallContext* call_ctx) const {
    const auto& consistent_tensor_infer_result = call_ctx->consistent_tensor_infer_result();
    if (!consistent_tensor_infer_result) { return Optional<Symbol<ParallelDesc>>(); }
    if (!consistent_tensor_infer_result->input_tensor_metas().empty()) {
      return consistent_tensor_infer_result->input_tensor_metas().at(0)->parallel_desc();
    } else if (!consistent_tensor_infer_result->output_tensor_metas().empty()) {
      return consistent_tensor_infer_result->output_tensor_metas().at(0)->parallel_desc();
    } else {
      UNIMPLEMENTED();
      return Optional<Symbol<ParallelDesc>>();
    }
  }

  // Parallel ctx for the current process; falls back to a shared static
  // single-device ctx (id 0 of 1) when there is no parallel desc.
  const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const {
    const auto& parallel_desc = this->parallel_desc(call_ctx);
    if (parallel_desc.has_value()) {
      const auto& parallel_desc_symbol = CHECK_JUST(parallel_desc);
      return *CHECK_JUST(GetParallelContext4CurrentProcessCtx(parallel_desc_symbol));
    } else {
      static ParallelContext single_device_parallel_ctx(MakeSingleDeviceParallelCtx());
      return single_device_parallel_ctx;
    }
  }

  const ArgVec& inputs() const { return input_arg_tuple_->indexed_arg_name_and_index(); }
  const ArgVec& outputs() const { return output_arg_tuple_->indexed_arg_name_and_index(); }

 private:
  // Maps (arg_name, arg_index) to its flat tensor-tuple position; -1 if the
  // arg name is unknown.
  static int32_t TryGetTensorTupleIndex(
      const std::unordered_map<std::string, std::vector<int32_t>>&
          arg_name2bn_index2tensor_tuple_index,
      const std::string& arg_name, const int32_t arg_index) {
    auto it = arg_name2bn_index2tensor_tuple_index.find(arg_name);
    if (it != arg_name2bn_index2tensor_tuple_index.end()) { return it->second.at(arg_index); }
    return -1;
  }

  static ParallelContext MakeSingleDeviceParallelCtx() {
    ParallelContext single_device_parallel_ctx;
    single_device_parallel_ctx.set_parallel_id(0);
    single_device_parallel_ctx.set_parallel_num(1);
    return single_device_parallel_ctx;
  }

  std::shared_ptr<const ArgTuple> input_arg_tuple_;
  std::shared_ptr<const ArgTuple> output_arg_tuple_;
};
// Zero-copy lookup helper plus the device type the kernel was registered for.
class UserKernelBaseContextHelper final : public ZeroCopyBaseContextHelper {
 public:
  UserKernelBaseContextHelper(DeviceType device_type,
                              const std::shared_ptr<const ArgTuple>& input_arg_tuple,
                              const std::shared_ptr<const ArgTuple>& output_arg_tuple)
      : ZeroCopyBaseContextHelper(input_arg_tuple, output_arg_tuple),
        device_type_(device_type) {}
  ~UserKernelBaseContextHelper() = default;

  DeviceType device_type() const { return device_type_; }
  // Unsupported in the eager path. NOTE(review): if UNIMPLEMENTED() ever
  // returns, the null-pointer dereference below is UB — the statement exists
  // only to satisfy the return type.
  const JobDesc& job_desc() const {
    UNIMPLEMENTED();
    return *(const JobDesc*)nullptr;
  }

 private:
  const DeviceType device_type_;
};
// Stateless (per-op) backend for UserOpInferContext: every accessor takes the
// per-call eager::CallContext explicitly, so one helper instance serves all
// calls of the op. All tensor lookups delegate to the zero-copy helper.
class UserOpInferContextHelper final {
 public:
  UserOpInferContextHelper(const user_op::UserOpConfWrapper* user_op_conf,
                           const std::shared_ptr<const ArgTuple>& input_arg_tuple,
                           const std::shared_ptr<const ArgTuple>& output_arg_tuple)
      : user_op_conf_(user_op_conf),
        zero_copy_base_ctx_helper_(input_arg_tuple, output_arg_tuple) {}
  ~UserOpInferContextHelper() = default;

  // Logical descs are unavailable in this (physical) infer path.
  const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
                                                               const std::string& arg_name,
                                                               int32_t index) const {
    UNIMPLEMENTED();
    return nullptr;
  }
  const user_op::TensorDesc& InputTensorDesc(eager::CallContext* call_ctx,
                                             const std::string& arg_name, int32_t index) const {
    return *CHECK_NOTNULL(TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index));
  }
  user_op::TensorDesc* OutputTensorDesc(eager::CallContext* call_ctx,
                                        const std::string& arg_name, int32_t index) const {
    return TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
                                                  const std::string& arg_name,
                                                  int32_t index) const {
    return zero_copy_base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);
  }

  // Shape / stride / dtype / is_dynamic accessors: the Input* variants
  // dereference, the Output* variants return the mutable pointer; all fail
  // fast (LOG(FATAL)) on an unknown arg via NonNullTensorDesc4ArgNameAndIndex.
  const Shape& InputShape(eager::CallContext* call_ctx, const std::string& arg_name,
                          int32_t index) const {
    return *Shape4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  Shape* OutputShape(eager::CallContext* call_ctx, const std::string& arg_name,
                     int32_t index) const {
    return Shape4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  Shape* Shape4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
                               int32_t index) const {
    return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_shape();
  }
  const Stride& InputStride(eager::CallContext* call_ctx, const std::string& arg_name,
                            int32_t index) const {
    return *Stride4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  Stride* OutputStride(eager::CallContext* call_ctx, const std::string& arg_name,
                       int32_t index) const {
    return Stride4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  Stride* Stride4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
                                 int32_t index) const {
    return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_stride();
  }
  const DataType& InputDType(eager::CallContext* call_ctx, const std::string& arg_name,
                             int32_t index) const {
    return *Dtype4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  DataType* OutputDType(eager::CallContext* call_ctx, const std::string& arg_name,
                        int32_t index) const {
    return Dtype4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  DataType* Dtype4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
                                  int32_t index) const {
    return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_data_type();
  }
  bool InputIsDynamic(eager::CallContext* call_ctx, const std::string& arg_name,
                      int32_t index) const {
    return *IsDynamic4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  bool* OutputIsDynamic(eager::CallContext* call_ctx, const std::string& arg_name,
                        int32_t index) const {
    return IsDynamic4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  bool* IsDynamic4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
                                  int32_t index) const {
    return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_is_dynamic();
  }

  const ArgVec& inputs() const { return zero_copy_base_ctx_helper_.inputs(); }
  const ArgVec& outputs() const { return zero_copy_base_ctx_helper_.outputs(); }
  // Job descs do not exist in the eager path.
  const JobDesc* job_desc() const {
    UNIMPLEMENTED();
    return nullptr;
  }
  const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const {
    return zero_copy_base_ctx_helper_.parallel_ctx(call_ctx);
  }
  const ParallelDesc& parallel_desc(eager::CallContext* call_ctx) const {
    return *CHECK_JUST(zero_copy_base_ctx_helper_.parallel_desc(call_ctx));
  }
  // 1-D SBP accessor: asserts the nd-sbp is rank 1 and returns its only entry.
  const SbpParallel& SbpParallel4ArgNameAndIndex(eager::CallContext* call_ctx,
                                                 const std::string& arg_name,
                                                 int32_t index) const {
    const auto& nd_sbp = NdSbp4ArgNameAndIndex(call_ctx, arg_name, index);
    CHECK_EQ(nd_sbp.sbp_parallel_size(), 1);
    return nd_sbp.sbp_parallel(0);
  }
  const NdSbp& NdSbp4ArgNameAndIndex(eager::CallContext* call_ctx,
                                     const std::string& arg_name, int32_t index) const {
    return *CHECK_NOTNULL(zero_copy_base_ctx_helper_.ConsistentTensorMeta4ArgNameAndIndex(
                              call_ctx, arg_name, index))
                ->nd_sbp();
  }
  int64_t parallel_num(eager::CallContext* call_ctx) const {
    return parallel_ctx(call_ctx).parallel_num();
  }

  // Conf-level queries forwarded to the op conf wrapper (no call ctx needed).
  const std::string& input(const std::string& arg_name, int32_t index) const {
    return user_op_conf().input(arg_name, index);
  }
  const std::string& output(const std::string& arg_name, int32_t index) const {
    return user_op_conf().output(arg_name, index);
  }
  bool has_input(const std::string& arg_name, int32_t index) const {
    return user_op_conf().has_input(arg_name, index);
  }
  bool has_output(const std::string& arg_name, int32_t index) const {
    return user_op_conf().has_output(arg_name, index);
  }
  int32_t input_size(const std::string& arg_name) const {
    return user_op_conf().input_size(arg_name);
  }
  int32_t output_size(const std::string& arg_name) const {
    return user_op_conf().output_size(arg_name);
  }
  const std::string& op_name() const { return user_op_conf().op_name(); }
  const std::string& op_type_name() const { return user_op_conf().op_type_name(); }
  const std::string& op_loc() const { return user_op_conf_->op_conf().loc(); }

  const user_op::UserOpConfWrapper& user_op_conf() const { return *user_op_conf_; }
  // Attrs come from the call (composed base + call-specific attrs).
  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(eager::CallContext* call_ctx,
                                                           const std::string& attr_name) const {
    return call_ctx->composed_attrs().Attr4Name(attr_name);
  }

 private:
  // Like TensorDesc4ArgNameAndIndex but LOG(FATAL)s on an unknown arg.
  user_op::TensorDesc* NonNullTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
                                                         const std::string& arg_name,
                                                         int32_t index) const {
    user_op::TensorDesc* tensor_desc = TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);
    if (!tensor_desc) { LOG(FATAL) << "Arg (" << arg_name << "," << index << ") is not found"; }
    return tensor_desc;
  }

  const user_op::UserOpConfWrapper* user_op_conf_;
  ZeroCopyBaseContextHelper zero_copy_base_ctx_helper_;
};
// Thin per-call adapter implementing user_op::InferContext: binds one
// eager::CallContext to the shared UserOpInferContextHelper and forwards
// every query to it. Owns nothing; both pointers must outlive this object.
class UserOpInferContext : public user_op::InferContext {
 public:
  UserOpInferContext(const UserOpInferContextHelper* helper, eager::CallContext* call_ctx)
      : helper_(helper), call_ctx_(call_ctx) {}
  ~UserOpInferContext() override = default;

  const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& arg_name,
                                                               int32_t index) const override {
    return helper_->LogicalTensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  const user_op::TensorDesc& InputTensorDesc(const std::string& arg_name,
                                             int32_t index) const override {
    return helper_->InputTensorDesc(call_ctx_, arg_name, index);
  }
  user_op::TensorDesc* OutputTensorDesc(const std::string& arg_name, int32_t index) override {
    return helper_->OutputTensorDesc(call_ctx_, arg_name, index);
  }
  // Not part of the InferContext interface (no override): direct desc lookup.
  user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) {
    return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  const Shape& InputShape(const std::string& arg_name, int32_t index) const override {
    return helper_->InputShape(call_ctx_, arg_name, index);
  }
  Shape* OutputShape(const std::string& arg_name, int32_t index) override {
    return helper_->OutputShape(call_ctx_, arg_name, index);
  }
  Shape* Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) override {
    return helper_->Shape4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  const Stride& InputStride(const std::string& arg_name, int32_t index) const override {
    return helper_->InputStride(call_ctx_, arg_name, index);
  }
  Stride* OutputStride(const std::string& arg_name, int32_t index) override {
    return helper_->OutputStride(call_ctx_, arg_name, index);
  }
  Stride* Stride4ArgNameAndIndex(const std::string& arg_name, int32_t index) override {
    return helper_->Stride4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  const DataType& InputDType(const std::string& arg_name, int32_t index) const override {
    return helper_->InputDType(call_ctx_, arg_name, index);
  }
  DataType* OutputDType(const std::string& arg_name, int32_t index) override {
    return helper_->OutputDType(call_ctx_, arg_name, index);
  }
  DataType* Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) override {
    return helper_->Dtype4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  bool InputIsDynamic(const std::string& arg_name, int32_t index) const override {
    return helper_->InputIsDynamic(call_ctx_, arg_name, index);
  }
  bool* OutputIsDynamic(const std::string& arg_name, int32_t index) override {
    return helper_->OutputIsDynamic(call_ctx_, arg_name, index);
  }
  bool* IsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) override {
    return helper_->IsDynamic4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  const ArgVec& inputs() const override { return helper_->inputs(); }
  const ArgVec& outputs() const override { return helper_->outputs(); }
  const JobDesc* job_desc() const override { return helper_->job_desc(); }
  const ParallelContext& parallel_ctx() const override {
    return helper_->parallel_ctx(call_ctx_);
  }
  const ParallelDesc& parallel_desc() const override {
    return helper_->parallel_desc(call_ctx_);
  }
  const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string& arg_name,
                                                 int32_t index) const override {
    return helper_->SbpParallel4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  const NdSbp& NdSbp4ArgNameAndIndex(const std::string& arg_name,
                                     int32_t index) const override {
    return helper_->NdSbp4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  int64_t parallel_num() const override { return helper_->parallel_num(call_ctx_); }
  const std::string& input(const std::string& arg_name, int32_t index) const override {
    return helper_->input(arg_name, index);
  }
  const std::string& output(const std::string& arg_name, int32_t index) const override {
    return helper_->output(arg_name, index);
  }
  bool has_input(const std::string& arg_name, int32_t index) const override {
    return helper_->has_input(arg_name, index);
  }
  bool has_output(const std::string& arg_name, int32_t index) const override {
    return helper_->has_output(arg_name, index);
  }
  int32_t input_size(const std::string& arg_name) const override {
    return helper_->input_size(arg_name);
  }
  int32_t output_size(const std::string& arg_name) const override {
    return helper_->output_size(arg_name);
  }
  const std::string& op_name() const override { return helper_->op_name(); }
  const std::string& op_type_name() const override { return helper_->op_type_name(); }
  const std::string& op_loc() const override { return helper_->op_loc(); }

 private:
  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(
      const std::string& attr_name) const override {
    return helper_->Attr4Name(call_ctx_, attr_name);
  }

  const UserOpInferContextHelper* helper_;
  eager::CallContext* call_ctx_;
};
// Shared (per-op) backend for UserKernelComputeContext; the per-call
// CallContext / DeviceCtx are supplied by each accessor's caller.
class UserKernelComputeContextHelper final {
 public:
  UserKernelComputeContextHelper(DeviceType device_type,
                                 const user_op::UserOpConfWrapper* user_op_conf,
                                 const std::shared_ptr<const ArgTuple>& input_arg_tuple,
                                 const std::shared_ptr<const ArgTuple>& output_arg_tuple)
      : user_op_conf_(user_op_conf),
        base_ctx_helper_(device_type, input_arg_tuple, output_arg_tuple) {}
  ~UserKernelComputeContextHelper() = default;

  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
                                                        const std::string& arg_name,
                                                        int32_t index) const {
    return base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  user_op::Tensor* Tensor4ArgNameAndIndex(eager::CallContext* call_ctx,
                                          const std::string& arg_name, int32_t index) const {
    return base_ctx_helper_.Tensor4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  // The execution stream lives on the per-call DeviceCtx, which must be non-null.
  ep::Stream* stream(DeviceCtx* device_ctx) const {
    CHECK(device_ctx);
    return device_ctx->stream();
  }
  DeviceType device_type() const { return base_ctx_helper_.device_type(); }
  const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const {
    return base_ctx_helper_.parallel_ctx(call_ctx);
  }
  const ArgVec& inputs() const { return base_ctx_helper_.inputs(); }
  const ArgVec& outputs() const { return base_ctx_helper_.outputs(); }
  const user_op::UserOpConfWrapper& user_op_conf() const { return *user_op_conf_; }
  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(eager::CallContext* call_ctx,
                                                           const std::string& attr_name) const {
    return call_ctx->composed_attrs().Attr4Name(attr_name);
  }

 private:
  const user_op::UserOpConfWrapper* user_op_conf_;
  UserKernelBaseContextHelper base_ctx_helper_;
};
// Per-call user_op::KernelComputeContext: binds one CallContext and one
// DeviceCtx to the shared helper and forwards all queries. Non-owning.
class UserKernelComputeContext final : public user_op::KernelComputeContext {
 public:
  UserKernelComputeContext(const UserKernelComputeContextHelper* helper,
                           eager::CallContext* call_ctx, DeviceCtx* device_ctx)
      : helper_(helper), call_ctx_(call_ctx), device_ctx_(device_ctx) {}
  ~UserKernelComputeContext() = default;

  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,
                                                        int32_t index) const override {
    return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  user_op::Tensor* Tensor4ArgNameAndIndex(const std::string& arg_name, int32_t index) override {
    return helper_->Tensor4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  ep::Stream* stream() override { return helper_->stream(device_ctx_); }
  DeviceType device_type() const override { return helper_->device_type(); }
  const ParallelContext& parallel_ctx() const override {
    return helper_->parallel_ctx(call_ctx_);
  }
  const ArgVec& inputs() const override { return helper_->inputs(); }
  const ArgVec& outputs() const override { return helper_->outputs(); }

 private:
  const user_op::UserOpConfWrapper& user_op_conf() const override {
    return helper_->user_op_conf();
  }
  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(
      const std::string& attr_name) const override {
    return helper_->Attr4Name(call_ctx_, attr_name);
  }

  const UserKernelComputeContextHelper* helper_;
  eager::CallContext* call_ctx_;
  DeviceCtx* device_ctx_;
};
// Shared (per-op) backend for UserKernelRegContext, used during kernel
// registration-predicate matching and registry lookup.
class UserKernelRegContextHelper final {
 public:
  UserKernelRegContextHelper(DeviceType device_type,
                             const user_op::UserOpConfWrapper* user_op_conf,
                             const std::shared_ptr<const ArgTuple>& input_arg_tuple,
                             const std::shared_ptr<const ArgTuple>& output_arg_tuple)
      : user_op_conf_(user_op_conf),
        base_ctx_helper_(device_type, input_arg_tuple, output_arg_tuple) {}
  ~UserKernelRegContextHelper() = default;

  DeviceType device_type() const { return base_ctx_helper_.device_type(); }
  const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const {
    return base_ctx_helper_.parallel_ctx(call_ctx);
  }
  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
                                                        const std::string& arg_name,
                                                        int32_t index) const {
    return base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  const ArgVec& inputs() const { return base_ctx_helper_.inputs(); }
  const ArgVec& outputs() const { return base_ctx_helper_.outputs(); }
  const user_op::UserOpConfWrapper& user_op_conf() const { return *user_op_conf_; }
  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(eager::CallContext* call_ctx,
                                                           const std::string& attr_name) const {
    return call_ctx->composed_attrs().Attr4Name(attr_name);
  }

 private:
  const user_op::UserOpConfWrapper* user_op_conf_;
  UserKernelBaseContextHelper base_ctx_helper_;
};
// Per-call user_op::KernelRegContext: binds one CallContext to the shared
// helper and forwards all queries. Non-owning.
class UserKernelRegContext final : public user_op::KernelRegContext {
 public:
  UserKernelRegContext(const UserKernelRegContextHelper* helper, eager::CallContext* call_ctx)
      : helper_(helper), call_ctx_(call_ctx) {}
  ~UserKernelRegContext() = default;

  DeviceType device_type() const override { return helper_->device_type(); }
  const ParallelContext& parallel_ctx() const override {
    return helper_->parallel_ctx(call_ctx_);
  }
  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,
                                                        int32_t index) const override {
    return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  const ArgVec& inputs() const override { return helper_->inputs(); }
  const ArgVec& outputs() const override { return helper_->outputs(); }
  const user_op::UserOpConfWrapper& user_op_conf() const override {
    return helper_->user_op_conf();
  }

 private:
  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(
      const std::string& attr_name) const override {
    return helper_->Attr4Name(call_ctx_, attr_name);
  }

  const UserKernelRegContextHelper* helper_;
  eager::CallContext* call_ctx_;
};
// Shared (per-op) backend for UserKernelInitAndCacheContext, which serves
// kernel state creation and cache (re)initialization.
class UserKernelInitAndCacheContextHelper final {
 public:
  UserKernelInitAndCacheContextHelper(DeviceType device_type,
                                      const user_op::UserOpConfWrapper* user_op_conf,
                                      const std::shared_ptr<const ArgTuple>& input_arg_tuple,
                                      const std::shared_ptr<const ArgTuple>& output_arg_tuple)
      : user_op_conf_(user_op_conf),
        base_ctx_helper_(device_type, input_arg_tuple, output_arg_tuple) {}
  ~UserKernelInitAndCacheContextHelper() = default;

  // The execution stream lives on the per-call DeviceCtx, which must be non-null.
  ep::Stream* stream(DeviceCtx* device_ctx) const {
    CHECK(device_ctx);
    return device_ctx->stream();
  }
  DeviceType device_type() const { return base_ctx_helper_.device_type(); }
  const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const {
    return base_ctx_helper_.parallel_ctx(call_ctx);
  }
  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
                                                        const std::string& arg_name,
                                                        int32_t index) const {
    return base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  // In this context the "logical" desc is the consistent tensor meta (it is
  // returned through the TensorDesc interface); nullptr when not found.
  const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
                                                               const std::string& arg_name,
                                                               int32_t index) const {
    return base_ctx_helper_.ConsistentTensorMeta4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  // 1-D SBP accessor: asserts the nd-sbp is rank 1 and returns its only entry.
  const SbpParallel& SbpParallel4ArgNameAndIndex(eager::CallContext* call_ctx,
                                                 const std::string& arg_name,
                                                 int32_t index) const {
    const auto& nd_sbp = NdSbp4ArgNameAndIndex(call_ctx, arg_name, index);
    CHECK_EQ(nd_sbp.sbp_parallel_size(), 1);
    return nd_sbp.sbp_parallel(0);
  }
  const NdSbp& NdSbp4ArgNameAndIndex(eager::CallContext* call_ctx,
                                     const std::string& arg_name, int32_t index) const {
    return *CHECK_NOTNULL(base_ctx_helper_.ConsistentTensorMeta4ArgNameAndIndex(
                              call_ctx, arg_name, index))
                ->nd_sbp();
  }
  const ArgVec& inputs() const { return base_ctx_helper_.inputs(); }
  const ArgVec& outputs() const { return base_ctx_helper_.outputs(); }
  const ParallelDesc& parallel_desc(eager::CallContext* call_ctx) const {
    return *CHECK_JUST(base_ctx_helper_.parallel_desc(call_ctx));
  }
  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(eager::CallContext* call_ctx,
                                                           const std::string& attr_name) const {
    return call_ctx->composed_attrs().Attr4Name(attr_name);
  }
  const user_op::UserOpConfWrapper& user_op_conf() const { return *user_op_conf_; }

 private:
  const user_op::UserOpConfWrapper* user_op_conf_;
  UserKernelBaseContextHelper base_ctx_helper_;
};
// Context handed to user kernels for both state creation (KernelInitContext)
// and cache (re)initialization (KernelCacheContext). It is a thin per-call
// facade: all lookups forward to the shared helper plus this call's
// CallContext / DeviceCtx.
class UserKernelInitAndCacheContext final : public user_op::KernelInitContext,
                                            public user_op::KernelCacheContext {
 public:
  UserKernelInitAndCacheContext(const UserKernelInitAndCacheContextHelper* helper,
                                eager::CallContext* call_ctx, DeviceCtx* device_ctx)
      : helper_(helper), call_ctx_(call_ctx), device_ctx_(device_ctx) {}
  ~UserKernelInitAndCacheContext() override = default;

  ep::Stream* stream() override { return helper_->stream(device_ctx_); }
  DeviceType device_type() const override { return helper_->device_type(); }
  const ParallelContext& parallel_ctx() const override { return helper_->parallel_ctx(call_ctx_); }
  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,
                                                        int32_t index) const override {
    return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& arg_name,
                                                               int32_t index) const override {
    return helper_->LogicalTensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string& arg_name,
                                                 int32_t index) const override {
    return helper_->SbpParallel4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  const NdSbp& NdSbp4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {
    return helper_->NdSbp4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  const ArgVec& inputs() const override { return helper_->inputs(); }
  const ArgVec& outputs() const override { return helper_->outputs(); }
  const ParallelDesc& parallel_desc() const override { return helper_->parallel_desc(call_ctx_); }

 private:
  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(
      const std::string& attr_name) const override {
    return helper_->Attr4Name(call_ctx_, attr_name);
  }
  const user_op::UserOpConfWrapper& user_op_conf() const override {
    return helper_->user_op_conf();
  }

  const UserKernelInitAndCacheContextHelper* helper_;  // not owned
  eager::CallContext* call_ctx_;                       // not owned
  DeviceCtx* device_ctx_;                              // not owned
};
namespace {

// Classifies every input/output blob of a user op by mutability and fills the
// four index vectors used by StatefulOpKernel:
//   - input_tuple_indexes4const_ibns:  read-only inputs
//   - input_tuple_indexes4mut_ibns:    inputs the kernel may mutate
//   - output_tuple_indexes4mut_obns:   outputs whose header is inferred before compute
//   - output_tuple_indexes4mut2_obns:  outputs whose header is only known after compute
// The classification is obtained by running the op's registered
// input/output arg-modify functions over a scratch ArgModifierSignature.
// Returns an error if the op type is not found in the registry.
Maybe<void> InitTensorTupleIndexes4Bns(const std::shared_ptr<const OperatorConf>& op_conf,
                                       const ArgVec& indexed_input_pairs,
                                       const ArgVec& indexed_output_pairs,
                                       std::vector<int64_t>* input_tuple_indexes4const_ibns,
                                       std::vector<int64_t>* input_tuple_indexes4mut_ibns,
                                       std::vector<int64_t>* output_tuple_indexes4mut_obns,
                                       std::vector<int64_t>* output_tuple_indexes4mut2_obns) {
  const auto* op_reg_val =
      user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_conf->user_conf().op_type_name());
  CHECK_NOTNULL_OR_RETURN(op_reg_val);

  // Seed the signature with a default (immutable) modifier per blob name.
  ArgModifierSignature arg_modifier_signature;
  for (const auto& pair : indexed_input_pairs) {
    const std::string ibn = GenRepeatedBn(pair.first, pair.second);
    arg_modifier_signature.mutable_ibn2input_blob_modifier()->insert(
        {ibn, user_op::InputArgModifier()});
  }
  for (const auto& pair : indexed_output_pairs) {
    const std::string obn = GenRepeatedBn(pair.first, pair.second);
    arg_modifier_signature.mutable_obn2output_blob_modifier()->insert(
        {obn, user_op::OutputArgModifier()});
  }

  // Let the op's registered modify-fns adjust the modifiers in place.
  user_op::UserOpConfWrapper op_conf_wrapper(op_conf);
  if (op_reg_val->input_arg_modify_fn) {
    user_op::GetInputArgModifier GetInputArgModifierFn =
        [&arg_modifier_signature](const std::string& in_arg_name,
                                  int32_t in_arg_index) -> user_op::InputArgModifier* {
      const std::string ibn = GenRepeatedBn(in_arg_name, in_arg_index);
      auto* map = arg_modifier_signature.mutable_ibn2input_blob_modifier();
      return &map->at(ibn);
    };
    JUST(op_reg_val->input_arg_modify_fn(GetInputArgModifierFn, op_conf_wrapper));
  }
  if (op_reg_val->output_arg_modify_fn) {
    user_op::GetOutputArgModifier GetOutputArgModifierFn =
        [&arg_modifier_signature](const std::string& in_arg_name,
                                  int32_t in_arg_index) -> user_op::OutputArgModifier* {
      const std::string obn = GenRepeatedBn(in_arg_name, in_arg_index);
      auto* map = arg_modifier_signature.mutable_obn2output_blob_modifier();
      return &map->at(obn);
    };
    JUST(op_reg_val->output_arg_modify_fn(GetOutputArgModifierFn, op_conf_wrapper));
  }

  // Partition inputs by mutability. (size_t avoids the signed/unsigned
  // comparison the original `int i` loop produced.)
  for (size_t i = 0; i < indexed_input_pairs.size(); ++i) {
    const auto& pair = indexed_input_pairs.at(i);
    const std::string ibn = GenRepeatedBn(pair.first, pair.second);
    if (arg_modifier_signature.ibn2input_blob_modifier().at(ibn).is_mutable()) {
      input_tuple_indexes4mut_ibns->emplace_back(i);
    } else {
      input_tuple_indexes4const_ibns->emplace_back(i);
    }
  }
  // Partition outputs by whether their header is inferred before compute.
  for (size_t i = 0; i < indexed_output_pairs.size(); ++i) {
    const auto& pair = indexed_output_pairs.at(i);
    const std::string obn = GenRepeatedBn(pair.first, pair.second);
    if (arg_modifier_signature.obn2output_blob_modifier().at(obn).header_infered_before_compute()) {
      output_tuple_indexes4mut_obns->emplace_back(i);
    } else {
      output_tuple_indexes4mut2_obns->emplace_back(i);
    }
  }
  return Maybe<void>::Ok();
}

}  // namespace
/* static */ Maybe<StatefulOpKernel> StatefulOpKernel::New(
    const std::shared_ptr<OperatorConf>& op_conf, const Symbol<Stream>& stream,
    const AttrMap& base_attrs, const std::shared_ptr<const ParallelDesc>& parallel_desc,
    const std::shared_ptr<const ArgTuple>& input_arg_tuple,
    const std::shared_ptr<const ArgTuple>& output_arg_tuple) {
  // Factory: builds a StatefulOpKernel, wires up all per-op context helpers,
  // copies the op's registered infer-fns, and precomputes the const/mut blob
  // index partitions. Fails if the op type is unregistered or has no logical
  // tensor-desc infer fn.
  auto opkernel = std::shared_ptr<StatefulOpKernel>(new StatefulOpKernel());
  opkernel->base_attrs_ = base_attrs;
  opkernel->op_conf_ = op_conf;
  opkernel->user_op_conf_.reset(new user_op::UserOpConfWrapper(op_conf));
  opkernel->stream_ = stream;
  opkernel->input_arg_tuple_ = input_arg_tuple;
  opkernel->output_arg_tuple_ = output_arg_tuple;
  opkernel->need_check_mem_case_ = true;

  const DeviceType device_type = CHECK_JUST(DeviceType4DeviceTag(op_conf->device_tag()));
  const user_op::UserOpConfWrapper* user_op_conf = opkernel->user_op_conf_.get();
  // One helper per context flavor; each is call-independent and shared by
  // every invocation of this op.
  opkernel->op_infer_ctx_helper_.reset(
      new UserOpInferContextHelper(user_op_conf, input_arg_tuple, output_arg_tuple));
  opkernel->init_and_cache_ctx_helper_.reset(new UserKernelInitAndCacheContextHelper(
      device_type, opkernel->user_op_conf_.get(), opkernel->input_arg_tuple_,
      opkernel->output_arg_tuple_));
  opkernel->compute_ctx_helper_.reset(new UserKernelComputeContextHelper(
      device_type, user_op_conf, input_arg_tuple, output_arg_tuple));
  opkernel->reg_ctx_helper_.reset(new UserKernelRegContextHelper(device_type, user_op_conf,
                                                                 input_arg_tuple,
                                                                 output_arg_tuple));

  const auto* op_reg_val =
      user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(user_op_conf->op_type_name());
  CHECK_NOTNULL_OR_RETURN(op_reg_val);
  if (op_reg_val->logical_tensor_desc_infer_fn) {
    opkernel->tensor_desc_infer_fn_ = op_reg_val->logical_tensor_desc_infer_fn;
  } else {
    // Ops without a logical shape-infer fn cannot run through this path.
    return Error::UnimplementedError();
  }
  opkernel->data_type_infer_fn_ = op_reg_val->data_type_infer_fn;

  JUST(InitTensorTupleIndexes4Bns(
      op_conf, input_arg_tuple->indexed_arg_name_and_index(),
      output_arg_tuple->indexed_arg_name_and_index(), &opkernel->input_tuple_indexes4const_ibns_,
      &opkernel->input_tuple_indexes4mut_ibns_, &opkernel->output_tuple_indexes4mut_obns_,
      &opkernel->output_tuple_indexes4mut2_obns_));
  return opkernel;
}
// All members are smart pointers / value types, so the default dtor suffices.
StatefulOpKernel::~StatefulOpKernel() = default;
// Returns the number of scratch-buffer bytes the chosen kernel needs for this
// call, by building a shape-inference context and invoking the kernel's
// registered InferTmpSizeFn on it.
size_t StatefulOpKernel::InferTmpSize(eager::CallContext* call_ctx,
                                      const user_op::OpKernel* user_opkernel) const {
  UserOpInferContext infer_ctx(op_infer_ctx_helper_.get(), call_ctx);
  const auto& infer_fn = GetInferTmpSizeFn(user_opkernel);
  return infer_fn(&infer_ctx);
}
// Selects the concrete user kernel for this call. Kernels are cached per
// "primary" dtype (first input's dtype, else first output's); the cached list
// is scanned first and only on a miss do we consult the registry, create a
// fresh kernel, and cache it (together with its infer-tmp-size fn).
Maybe<void> StatefulOpKernel::ChooseOpKernel(eager::CallContext* call_ctx,
                                             const user_op::OpKernel** user_opkernel,
                                             bool* need_temp_storage) {
  OF_PROFILER_RANGE_GUARD("ChooseOpKernel");
  DataType primary_dtype = kInvalidDataType;
  const auto& inputs = call_ctx->inputs();
  const auto& outputs = call_ctx->outputs();
  if (likely(!inputs->empty())) {
    primary_dtype = (*inputs)[0]->data_type();
  } else if (likely(!outputs->empty())) {
    primary_dtype = (*outputs)[0]->data_type();
  } else {
    // do nothing
  }

  // Fast path: reuse a cached kernel whose match-predicate accepts this call.
  UserKernelRegContext reg_ctx(reg_ctx_helper_.get(), call_ctx);
  for (const auto& pair : dtype2cached_kernels_[primary_dtype]) {
    if (likely(pair.first->is_matched_hob->get(reg_ctx))) {
      *need_temp_storage = pair.first->need_temp_storage;
      *user_opkernel = pair.second.get();
      return Maybe<void>::Ok();
    }
  }

  // Slow path: registry lookup + kernel creation, then cache for next time.
  OF_PROFILER_RANGE_GUARD("fallback");
  const auto& op_type_name = user_op_conf_->op_type_name();
  const auto* kernel_reg_val =
      JUST(user_op::UserOpRegistryMgr::Get().GetOpKernelRegistryResult(op_type_name, reg_ctx));
  CHECK_NOTNULL(kernel_reg_val);
  auto* kernel = kernel_reg_val->create_fn();
  dtype2cached_kernels_[primary_dtype].push_back(
      {kernel_reg_val, std::shared_ptr<const user_op::OpKernel>(kernel)});
  infer_tmp_size_fn_map_.emplace(kernel, &kernel_reg_val->infer_tmp_size_fn);
  *need_temp_storage = kernel_reg_val->need_temp_storage;
  *user_opkernel = kernel;
  return Maybe<void>::Ok();
}
// Provides the per-kernel state and cache for this call. State is created at
// most once per kernel and memoized in op_kernel_state_map_; the cache entry
// is (re)initialized every call via InitOpKernelCacheWithFlags with
// kAllMayChanged. Passing state == nullptr skips state handling entirely.
void StatefulOpKernel::TryInitOpKernelStateAndCache(eager::CallContext* call_ctx,
                                                    DeviceCtx* device_ctx,
                                                    const user_op::OpKernel* op_kernel,
                                                    user_op::OpKernelState** state,
                                                    user_op::OpKernelCache** cache) {
  UserKernelInitAndCacheContext init_and_cache_ctx(init_and_cache_ctx_helper_.get(), call_ctx,
                                                   device_ctx);
  if (state != nullptr) {
    auto it = op_kernel_state_map_.find(op_kernel);
    if (it != op_kernel_state_map_.end()) {
      *state = it->second.get();
    } else {
      auto created_state = op_kernel->CreateOpKernelState(&init_and_cache_ctx);
      op_kernel_state_map_.emplace(op_kernel, created_state);
      *state = created_state.get();
    }
  }
  {
    auto& cache_in_map = op_kernel_cache_map_[op_kernel];
    op_kernel->InitOpKernelCacheWithFlags(&init_and_cache_ctx,
                                          user_op::OpKernelCache::kAllMayChanged, &cache_in_map);
    *cache = cache_in_map.get();
  }
}
// Returns the scratch-size functor registered for the given kernel.
// at() deliberately keeps the original fail-fast behavior for kernels that
// were never cached via ChooseOpKernel.
const user_op::InferTmpSizeFn& StatefulOpKernel::GetInferTmpSizeFn(
    const user_op::OpKernel* op_kernel) const {
  const auto* fn_ptr = infer_tmp_size_fn_map_.at(op_kernel);
  return *fn_ptr;
}
// Accessor for the op's registered logical tensor-desc infer fn (set in New()).
user_op::TensorDescInferFn StatefulOpKernel::TensorDescInferFn() const {
  return tensor_desc_infer_fn_;
}
// Accessor for the op's registered data-type infer fn (set in New()).
user_op::DataTypeInferFn StatefulOpKernel::DataTypeInferFn() const { return data_type_infer_fn_; }
// Runs the chosen kernel. When the global ProfileManager is active, the call
// is wrapped in a kernel EventRecorder that captures input shapes and — on
// CUDA/ROCm builds — the total input+output byte count (used for bandwidth
// reporting); otherwise the kernel is invoked directly.
void StatefulOpKernel::Compute(eager::CallContext* call_ctx, DeviceCtx* device_ctx,
                               const user_op::OpKernel* user_opkernel,
                               user_op::OpKernelState* state,
                               const user_op::OpKernelCache* cache) const {
  UserKernelComputeContext compute_context(compute_ctx_helper_.get(), call_ctx, device_ctx);
  auto* compute_ctx = &compute_context;
  OF_PROFILER_RANGE_GUARD("Compute");
  if (Singleton<profiler::ProfileManager>::Get()) {
#if defined(WITH_CUDA) || defined(WITH_ROCM)
    // Sums elem_cnt * sizeof(dtype) over a list of (arg_name, index) pairs.
    const auto CalMemorySize = [compute_ctx](const one::ArgVec& args) -> int64_t {
      const auto Func = [compute_ctx](int64_t mem_size, const auto& pair) {
        const auto tensor = compute_ctx->Tensor4ArgNameAndIndex(pair.first, pair.second);
        return mem_size
               + tensor->shape_view().elem_cnt() * GetSizeOfDataType(tensor->data_type());
      };
      return std::accumulate(args.begin(), args.end(), static_cast<int64_t>(0), Func);
    };
#endif
    // RAII guard: the recorder finalizes the profiling event when it is
    // destroyed at the end of this scope, after Compute returns.
    auto er_guard = CHECK_JUST(profiler::EventRecorder::CreateKernelEventRecorder(
        op_type_name(),
#if defined(WITH_CUDA) || defined(WITH_ROCM)
        [compute_ctx, CalMemorySize]() -> int64_t {
          return CalMemorySize(compute_ctx->inputs()) + CalMemorySize(compute_ctx->outputs());
        },
#endif
        [compute_ctx]() -> std::vector<Shape> {
          std::vector<Shape> shapes;
          for (const auto& pair : compute_ctx->inputs()) {
            shapes.emplace_back(
                compute_ctx->TensorDesc4ArgNameAndIndex(pair.first, pair.second)->shape());
          }
          return shapes;
        }));
    user_opkernel->Compute(compute_ctx, state, cache);
  } else {
    user_opkernel->Compute(compute_ctx, state, cache);
  }
}
}
// namespace one
}
// namespace oneflow
python/oneflow/test/modules/fused_dot_feature_interaction.py
0 → 100644
View file @
f262efc9
import
numpy
as
np
import
oneflow
as
flow
def fused_dot_feature_interaction(
    x,
    y,
    self_interaction=False,
    output_padding=0,
    output_concat=None,
    dtype=flow.float32,
):
    """Reference implementation of dot feature interaction.

    Concatenates the dense feature ``x`` (bs, es) with the sparse features
    ``y`` (bs, dims, es), computes the pairwise dot products T @ T^T, and
    flattens the (lower-)triangular part. ``self_interaction`` includes the
    diagonal; ``output_concat`` is prepended to the result; ``output_padding``
    appends that many zero columns.

    Args:
        x: dense features, shape (bs, es).
        y: sparse features, shape (bs, dims, es).
        self_interaction: include i==j interactions when True.
        output_padding: number of zero columns to append.
        output_concat: optional tensor concatenated before the interactions.
        dtype: dtype of the flattened interaction output.

    Returns:
        Tensor of shape (bs, concat_width + num_pairs + output_padding).
    """
    # (bs, es) = x.shape
    (bs, dims, es) = y.shape
    if self_interaction:
        offset = 1
    else:
        offset = 0
    # Triangular index pairs (i, j) of the (dims+1) x (dims+1) Gram matrix.
    li = flow.tensor([i for i in range(dims + 1) for j in range(i + offset)])
    lj = flow.tensor([j for i in range(dims + 1) for j in range(i + offset)])
    T = flow.cat(
        [flow.reshape(x, (bs, 1, es)), y],
        dim=1,
    )
    Z = flow.matmul(T, T, transpose_b=True)
    # gather_nd not support half, so cast to float32
    Z = flow.cast(Z, flow.float32)
    Zflat = Z[:, li, lj]
    Zflat = flow.cast(Zflat, dtype)
    if output_concat is not None:
        R = flow.cat([output_concat, Zflat], dim=1)
    else:
        R = Zflat
    if output_padding != 0:
        padding_tensor = flow.tensor(
            np.zeros((bs, output_padding)).astype(np.float32),
            # Fix: follow the result's device instead of hard-coding "cuda",
            # which broke CPU callers (flow.cat rejects mixed-device inputs).
            device=R.device,
            requires_grad=False,
        )
        R = flow.cat([R, padding_tensor], dim=1)
    return R
python/oneflow/test/profiler/test_profile_lenet.py
View file @
f262efc9
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import
os
import
unittest
import
oneflow.unittest
import
oneflow
as
flow
import
oneflow.nn
as
nn
import
oneflow.nn.functional
as
F
import
oneflow.profiler
from
oneflow.profiler.events
import
CustomEvent
,
KernelEvent
class LeNet(nn.Module):
    """LeNet-5-style CNN: two conv+pool stages followed by three FC layers.

    Expects (N, 3, 32, 32) inputs and produces (N, 10) logits.
    """

    def __init__(self):
        super(LeNet, self).__init__()
        # Feature extractor.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Classifier head; 16 * 5 * 5 is the flattened feature size.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        h = F.max_pool2d(F.relu(self.conv1(x)), 2)
        h = F.max_pool2d(F.relu(self.conv2(h)), 2)
        h = h.view(h.size(0), -1)
        h = F.relu(self.fc1(h))
        h = F.relu(self.fc2(h))
        return self.fc3(h)
def get_event(events, name: str, input_shapes: str = "-"):
    """Return the first profiler event matching `name`.

    CustomEvents match on name alone; KernelEvents must also match
    `input_shapes`. Returns None when nothing matches.
    """
    for ev in events:
        name_hit = isinstance(ev, CustomEvent) and ev.name == name
        kernel_hit = (
            isinstance(ev, KernelEvent)
            and ev.name == name
            and ev.input_shapes == input_shapes
        )
        if name_hit or kernel_hit:
            return ev
    return None
def _test_lenet(
    test_case,
    on_cuda: bool,
    record_shapes: bool,
    record_bandwidth_for_cuda: bool = False,
):
    # Profiles two LeNet forward passes plus one backward pass, then checks
    # that the aggregated events carry the expected timings, counts and
    # (optionally) bandwidth. NOTE: the asserted counts depend on the exact
    # structure of the profiled region below — do not restructure it casually.
    x = flow.randn(2, 3, 32, 32)
    lenet = LeNet()
    if on_cuda:
        x = x.to("cuda")
        lenet.to("cuda")
    activities = [oneflow.profiler.ProfilerActivity.CPU]
    if on_cuda:
        activities.append(oneflow.profiler.ProfilerActivity.CUDA)
    with oneflow.profiler.profile(
        activities=activities,
        record_shapes=record_shapes,
        record_bandwidth_for_cuda=record_bandwidth_for_cuda,
    ) as prof:
        with oneflow.profiler.record_function("lenet_forward_total_time") as f:
            # Two forward passes so shape-grouped conv2d events aggregate to 2.
            for _ in range(2):
                eager_res = lenet(x)
        with oneflow.profiler.record_function("lenet_backward_total_time") as f:
            # Backward only through the last forward's result.
            eager_res.sum().backward()
    events = prof.key_averages(group_by_input_shape=True)
    # First-layer conv: shape string only matches when shapes were recorded.
    conv_event = get_event(
        events, "conv2d", "[(2,3,32,32), (6,3,5,5)]" if record_shapes else "-"
    )
    test_case.assertIsNotNone(conv_event)
    if on_cuda:
        test_case.assertGreater(conv_event.cpu_time, 0.0)
        test_case.assertGreater(conv_event.cpu_time_total, 0.0)
        test_case.assertGreater(conv_event.cuda_time, 0.0)
        test_case.assertGreater(conv_event.cuda_time_total, 0.0)
    else:
        test_case.assertGreater(conv_event.cpu_time, 0.0)
        test_case.assertGreater(conv_event.cpu_time_total, 0.0)
    # 2 conv1 calls when grouped by shape; 4 total conv2d calls otherwise.
    test_case.assertEqual(conv_event.count, 2 if record_shapes else 4)
    if record_bandwidth_for_cuda and on_cuda:
        # -1 is the "not recorded" sentinel for bandwidth.
        test_case.assertNotEqual(conv_event.bandwidth, -1)
    relu_grad_event = get_event(
        events, "relu_grad", "[(2,6,28,28), (2,6,28,28)]" if record_shapes else "-"
    )
    test_case.assertIsNotNone(relu_grad_event)
    if on_cuda:
        test_case.assertGreater(relu_grad_event.cpu_time, 0.0)
        test_case.assertGreater(relu_grad_event.cpu_time_total, 0.0)
        test_case.assertGreater(relu_grad_event.cuda_time, 0.0)
        test_case.assertGreater(relu_grad_event.cuda_time_total, 0.0)
    else:
        test_case.assertGreater(relu_grad_event.cpu_time, 0.0)
        test_case.assertGreater(relu_grad_event.cpu_time_total, 0.0)
    # Single backward pass: 1 shape-grouped relu_grad, 4 ungrouped.
    test_case.assertEqual(relu_grad_event.count, 1 if record_shapes else 4)
    if record_bandwidth_for_cuda and on_cuda:
        test_case.assertNotEqual(relu_grad_event.bandwidth, -1)
    # The two custom record_function ranges must also be present.
    test_case.assertIsNotNone(get_event(events, "lenet_forward_total_time"))
    test_case.assertIsNotNone(get_event(events, "lenet_backward_total_time"))
class TestProfileLenet(flow.unittest.TestCase):
    """Exercises the profiler on LeNet across device / option combinations."""

    def test_lenet_cpu(test_case):
        for shapes in (True, False):
            _test_lenet(test_case, on_cuda=False, record_shapes=shapes)

    @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
    def test_lenet_cuda(test_case):
        # Same four (record_shapes, record_bandwidth_for_cuda) combinations,
        # in the same order, as the original explicit call sequence.
        for shapes, bandwidth in (
            (True, False),
            (False, False),
            (True, True),
            (False, True),
        ):
            _test_lenet(
                test_case,
                on_cuda=True,
                record_shapes=shapes,
                record_bandwidth_for_cuda=bandwidth,
            )
# Allow running this test module directly.
if __name__ == "__main__":
    unittest.main()
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import
os
import
unittest
import
oneflow.unittest
import
oneflow
as
flow
import
oneflow.nn
as
nn
import
oneflow.nn.functional
as
F
import
oneflow.profiler
from
oneflow.profiler.events
import
CustomEvent
,
KernelEvent
class LeNet(nn.Module):
    """LeNet-5-style CNN: two conv+pool stages followed by three FC layers.

    Expects (N, 3, 32, 32) inputs and produces (N, 10) logits.
    """

    def __init__(self):
        super(LeNet, self).__init__()
        # Feature extractor.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Classifier head; 16 * 5 * 5 is the flattened feature size.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        h = F.max_pool2d(F.relu(self.conv1(x)), 2)
        h = F.max_pool2d(F.relu(self.conv2(h)), 2)
        h = h.view(h.size(0), -1)
        h = F.relu(self.fc1(h))
        h = F.relu(self.fc2(h))
        return self.fc3(h)
def get_event(events, name: str, input_shapes: str = "-"):
    """Return the first profiler event matching `name`.

    CustomEvents match on name alone; KernelEvents must also match
    `input_shapes`. Returns None when nothing matches.
    """
    for ev in events:
        name_hit = isinstance(ev, CustomEvent) and ev.name == name
        kernel_hit = (
            isinstance(ev, KernelEvent)
            and ev.name == name
            and ev.input_shapes == input_shapes
        )
        if name_hit or kernel_hit:
            return ev
    return None
def _test_lenet(
    test_case,
    on_cuda: bool,
    record_shapes: bool,
    record_bandwidth_for_cuda: bool = False,
):
    # Profiles two LeNet forward passes plus one backward pass, then checks
    # that the aggregated events carry the expected timings, counts and
    # (optionally) bandwidth. NOTE: the asserted counts depend on the exact
    # structure of the profiled region below — do not restructure it casually.
    x = flow.randn(2, 3, 32, 32)
    lenet = LeNet()
    if on_cuda:
        x = x.to("cuda")
        lenet.to("cuda")
    activities = [oneflow.profiler.ProfilerActivity.CPU]
    if on_cuda:
        activities.append(oneflow.profiler.ProfilerActivity.CUDA)
    with oneflow.profiler.profile(
        activities=activities,
        record_shapes=record_shapes,
        record_bandwidth_for_cuda=record_bandwidth_for_cuda,
    ) as prof:
        with oneflow.profiler.record_function("lenet_forward_total_time") as f:
            # Two forward passes so shape-grouped conv2d events aggregate to 2.
            for _ in range(2):
                eager_res = lenet(x)
        with oneflow.profiler.record_function("lenet_backward_total_time") as f:
            # Backward only through the last forward's result.
            eager_res.sum().backward()
    events = prof.key_averages(group_by_input_shape=True)
    # Debug aid added in this revision: dump the aggregated event table.
    print(events)
    # First-layer conv: shape string only matches when shapes were recorded.
    conv_event = get_event(
        events, "conv2d", "[(2,3,32,32), (6,3,5,5)]" if record_shapes else "-"
    )
    test_case.assertIsNotNone(conv_event)
    if on_cuda:
        test_case.assertGreater(conv_event.cpu_time, 0.0)
        test_case.assertGreater(conv_event.cpu_time_total, 0.0)
        test_case.assertGreater(conv_event.cuda_time, 0.0)
        test_case.assertGreater(conv_event.cuda_time_total, 0.0)
    else:
        test_case.assertGreater(conv_event.cpu_time, 0.0)
        test_case.assertGreater(conv_event.cpu_time_total, 0.0)
    # 2 conv1 calls when grouped by shape; 4 total conv2d calls otherwise.
    test_case.assertEqual(conv_event.count, 2 if record_shapes else 4)
    if record_bandwidth_for_cuda and on_cuda:
        # -1 is the "not recorded" sentinel for bandwidth.
        test_case.assertNotEqual(conv_event.bandwidth, -1)
    relu_grad_event = get_event(
        events, "relu_grad", "[(2,6,28,28), (2,6,28,28)]" if record_shapes else "-"
    )
    test_case.assertIsNotNone(relu_grad_event)
    if on_cuda:
        test_case.assertGreater(relu_grad_event.cpu_time, 0.0)
        test_case.assertGreater(relu_grad_event.cpu_time_total, 0.0)
        test_case.assertGreater(relu_grad_event.cuda_time, 0.0)
        test_case.assertGreater(relu_grad_event.cuda_time_total, 0.0)
    else:
        test_case.assertGreater(relu_grad_event.cpu_time, 0.0)
        test_case.assertGreater(relu_grad_event.cpu_time_total, 0.0)
    # Single backward pass: 1 shape-grouped relu_grad, 4 ungrouped.
    test_case.assertEqual(relu_grad_event.count, 1 if record_shapes else 4)
    if record_bandwidth_for_cuda and on_cuda:
        test_case.assertNotEqual(relu_grad_event.bandwidth, -1)
    # The two custom record_function ranges must also be present.
    test_case.assertIsNotNone(get_event(events, "lenet_forward_total_time"))
    test_case.assertIsNotNone(get_event(events, "lenet_backward_total_time"))
class TestProfileLenet(flow.unittest.TestCase):
    """Exercises the profiler on LeNet across device / option combinations."""

    def test_lenet_cpu(test_case):
        for shapes in (True, False):
            _test_lenet(test_case, on_cuda=False, record_shapes=shapes)

    @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
    def test_lenet_cuda(test_case):
        # Same four (record_shapes, record_bandwidth_for_cuda) combinations,
        # in the same order, as the original explicit call sequence.
        for shapes, bandwidth in (
            (True, False),
            (False, False),
            (True, True),
            (False, True),
        ):
            _test_lenet(
                test_case,
                on_cuda=True,
                record_shapes=shapes,
                record_bandwidth_for_cuda=bandwidth,
            )
# Allow running this test module directly.
if __name__ == "__main__":
    unittest.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment