OpenDAS / Oneflow, commit f262efc9
Authored Nov 21, 2022 by yuguo
Support profiler for DCU, support debug compiler
Parent: 3f56062c
Showing 17 changed files with 2666 additions and 2387 deletions.
Changed files:
  CMakeLists.txt                                                  +2    -2
  cmake/third_party.cmake                                         +2    -0
  oneflow/core/profiler/event.cpp                                 +90   -94
  oneflow/core/profiler/event.h                                   +186  -188
  oneflow/core/profiler/event_recorder.cpp                        +2    -2
  oneflow/core/profiler/event_recorder.h                          +60   -60
  oneflow/core/profiler/kernel.cpp                                +64   -0
  oneflow/core/profiler/kineto_shim.cpp                           +1    -1
  oneflow/core/profiler/kineto_shim.h                             +1    -1
  oneflow/core/profiler/profile_manager.cpp                       +5    -6
  oneflow/core/profiler/profile_manager.h                         +1    -1
  oneflow/core/profiler/profiler.cpp                              +39   -0
  oneflow/user/kernels/math_unary_elementwise_func.h              +983  -983
  oneflow/user/kernels/nvtx_range_kernel.hip.cpp                  +138  -0
  oneflow/user/kernels/stateful_opkernel.cpp                      +901  -901
  python/oneflow/test/modules/fused_dot_feature_interaction.py    +43   -0
  python/oneflow/test/profiler/test_profile_lenet.py              +148  -148
CMakeLists.txt

...
@@ -265,9 +265,9 @@ set(ROBIN_HOOD_HASHING_URL
 use_mirror(VARIABLE ROBIN_HOOD_HASHING_URL URL ${ROBIN_HOOD_HASHING_URL})
 set(ROBIN_HOOD_HASHING_MD5 a78bd30a7582f25984f8592652836467)
-set(FMT_URL https://github.com/fmtlib/fmt/archive/48b7e3dafb27ece02cd6addc8bd1041c79d59c2c.zip)
+set(FMT_URL https://github.com/fmtlib/fmt/archive/fc07217d85e6dcec52878807d6bbd89a9d9156a5.zip)
 use_mirror(VARIABLE FMT_URL URL ${FMT_URL})
-set(FMT_MD5 45925a979ed7195e0c88a70be691de09)
+set(FMT_MD5 7d9bb2ececc9ede29cd35bdc42a7e22c)
 set(KINETO_URL https://github.com/pytorch/kineto/archive/ff8dba20499a660650632952be76450bd70a52a6.zip)
...
cmake/third_party.cmake

...
@@ -175,6 +175,8 @@ if (BUILD_ROCM)
   add_definitions(-D__HIP_PLATFORM_HCC__)
   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__HIP_PLATFORM_HCC__ --gpu-max-threads-per-block=1024")
   set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -D__HIP_PLATFORM_HCC__ --gpu-max-threads-per-block=1024")
+  set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -mcmodel=large")
+  set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -mcmodel=large")
   list(APPEND oneflow_third_party_libs hip::device)
   list(APPEND oneflow_third_party_libs roc::hipblas)
   list(APPEND oneflow_third_party_libs hip::hipcub)
...
oneflow/core/profiler/event.cpp

The entire file is rendered as changed (mostly whitespace). Apart from that, the two versions
differ at a few spots: whether the fmt includes are active, whether KernelEvent::Key() and
GetFormatedInputShapes() return the fmt-formatted string or the placeholder "yuguo", whether the
memory/children block is guarded by WITH_CUDA alone or by WITH_CUDA || WITH_ROCM, and whether the
shape getter uses std::vector<Shape> or std::vector<ShapeView> (the ShapeView variant also
defines KernelEvent::RecordShape). The variant spots are marked in the content below.

/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// One version keeps these two includes, the other comments them out:
#include "fmt/core.h"
#include "fmt/format.h"
#include "oneflow/core/profiler/event.h"
#include "oneflow/core/profiler/util.h"

using json = nlohmann::json;

namespace oneflow {
namespace profiler {

nlohmann::json IEvent::ToJson() {
  return json{{"name", name_}, {"time", GetDuration<double>()}, {"input_shapes", "-"}};
}

void IEvent::SetStartedAt(double t) { started_at_ = t; }
void IEvent::SetFinishedAt(double t) { finished_at_ = t; }
void IEvent::Start() { SetStartedAt(GetTimeNow()); }
void IEvent::Finish() { SetFinishedAt(GetTimeNow()); }

bool IEvent::IsChildOf(const IEvent* e) {
  if (!e) { return false; }
  if (this == e) { return false; }
  return GetStartedAt<double>() >= e->GetStartedAt<double>()
         && GetFinishedAt<double>() <= e->GetFinishedAt<double>();
}

const std::string& IEvent::GetName() const { return name_; }

std::string CustomEvent::Key() { return name_; }

nlohmann::json CustomEvent::ToJson() {
  auto j = IEvent::ToJson();
  j["type"] = EventType::kCustom;
  j["custom_type"] = type_;
  return j;
}

std::shared_ptr<CustomEvent> CustomEvent::Create(const std::string& name, CustomEventType type) {
  return std::shared_ptr<CustomEvent>(new CustomEvent(name, type));
}

// One version:
//   std::string KernelEvent::Key() { return fmt::format("{}.{}", name_, GetFormatedInputShapes()); }
// the other keeps that line as a comment and returns a placeholder:
//   std::string KernelEvent::Key() { return "yuguo"; }

nlohmann::json KernelEvent::ToJson() {
  auto j = IEvent::ToJson();
  j["type"] = EventType::kOneflowKernel;
  j["input_shapes"] = GetFormatedInputShapes();
#if defined(WITH_CUDA)  // the other version: #if defined(WITH_CUDA) || defined(WITH_ROCM)
  j["memory_size"] = memory_size_;
  if (!children_.empty()) { j["children"] = children_; }
#endif  // WITH_CUDA
  return j;
}

std::shared_ptr<KernelEvent> KernelEvent::Create(
    const std::string& name,
    const std::function<std::vector<Shape>(void)>& shape_getter) {  // std::vector<ShapeView> in the other version
  return std::shared_ptr<KernelEvent>(new KernelEvent(name, shape_getter));
}

// Present in the ShapeView-based version only:
//   void KernelEvent::RecordShape(const ShapeView& shape) { input_shapes_.emplace_back(shape); }

std::string KernelEvent::GetFormatedInputShapes(size_t max_num_to_format) {
  if (input_shapes_.size() == 0) { return "-"; }
  std::vector<std::string> shapes_formated(std::min(input_shapes_.size(), max_num_to_format));
  for (auto i = 0; i < shapes_formated.size(); ++i) {
    const std::string current_shape = input_shapes_[i].ToString();
    shapes_formated[i] = current_shape == "()" ? "scalar" : current_shape;
  }
  if (input_shapes_.size() > max_num_to_format) { shapes_formated.emplace_back("..."); }
  return fmt::format("[{}]", fmt::join(shapes_formated, ", "));
  // The other version comments out the fmt call above and returns the placeholder "yuguo".
}

}  // namespace profiler
}  // namespace oneflow
(No newline at end of file)
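The event implementation above follows a simple pattern: raw timestamps are stored in one unit,
converted on read (ConvertTime), and serialized with nlohmann::json. A minimal self-contained
sketch of that pattern, assuming a steady_clock time source (ToyEvent and NowNs are illustrative
names, not OneFlow APIs):

```cpp
#include <chrono>
#include <string>
#include "nlohmann/json.hpp"

// Toy stand-in for IEvent: timestamps are kept in nanoseconds and converted
// to microseconds on read, mirroring ConvertTime()/GetDuration<double>().
struct ToyEvent {
  std::string name;
  double started_at_ns = 0;
  double finished_at_ns = 0;

  void Start() { started_at_ns = NowNs(); }
  void Finish() { finished_at_ns = NowNs(); }

  double DurationUs() const { return (finished_at_ns - started_at_ns) / 1000.0; }

  nlohmann::json ToJson() const {
    // Same JSON shape as IEvent::ToJson(): name, elapsed time, placeholder shapes.
    return nlohmann::json{{"name", name}, {"time", DurationUs()}, {"input_shapes", "-"}};
  }

 private:
  static double NowNs() {
    return std::chrono::duration<double, std::nano>(
               std::chrono::steady_clock::now().time_since_epoch())
        .count();
  }
};
```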
oneflow/core/profiler/event.h

The entire header is rendered as changed (mostly whitespace). Apart from that, the two versions
differ only inside KernelEvent: the WITH_CUDA versus WITH_CUDA || WITH_ROCM guard, Shape versus
ShapeView in the shape getter and in input_shapes_, and the RecordShape declaration, all marked
below.

/*
Copyright 2020 The OneFlow Authors. All rights reserved.
(Apache License, Version 2.0 header, identical to the one in event.cpp)
*/
#ifndef ONEFLOW_CORE_PROFILER_EVENT_H_
#define ONEFLOW_CORE_PROFILER_EVENT_H_

#include <functional>
#include <memory>
#include <vector>
#include "nlohmann/json.hpp"
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/shape_view.h"

namespace oneflow {
namespace profiler {

class ProfileManager;

enum class EventType {
  kCustom,        // has three kinds
  kOneflowKernel  // OneFlow cpu/cuda kernel
};

enum class CustomEventType {
  kDefault,     // for record_function
  kCudaKernel,  // cuda kernel
  kCudaRuntime  // something like cudaLaunchKernel
};

enum class EventTimeUnit { kNS, kUS };

class IEvent {
 public:
  OF_DISALLOW_COPY_AND_MOVE(IEvent);
  IEvent() = delete;
  IEvent(const std::string& name, EventTimeUnit time_unit) : name_(name), time_unit_(time_unit) {}
  virtual std::string Key() = 0;
  virtual nlohmann::json ToJson();
  virtual ~IEvent() = default;
  virtual void Start();
  virtual void Finish();

  bool IsChildOf(const IEvent* e);

  const std::string& GetName() const;

  template<typename T>
  const T GetDuration(EventTimeUnit time_unit = EventTimeUnit::kUS) const;

  template<typename T>
  const T GetStartedAt(EventTimeUnit time_unit = EventTimeUnit::kUS) const;

  template<typename T>
  const T GetFinishedAt(EventTimeUnit time_unit = EventTimeUnit::kUS) const;

 protected:
  virtual void SetStartedAt(double t);
  virtual void SetFinishedAt(double t);

  std::string name_;
  EventTimeUnit time_unit_;
  double started_at_ = 0;
  double finished_at_ = 0;
};

inline double ConvertTime(double time_, EventTimeUnit src_time_unit, EventTimeUnit dst_time_unit) {
  if (src_time_unit == EventTimeUnit::kNS && dst_time_unit == EventTimeUnit::kUS) {
    return time_ / 1000;
  }
  if (src_time_unit == EventTimeUnit::kUS && dst_time_unit == EventTimeUnit::kNS) {
    return time_ * 1000;
  }
  return time_;
}

template<>
const inline double IEvent::GetStartedAt<double>(EventTimeUnit time_unit) const {
  return ConvertTime(started_at_, time_unit_, time_unit);
}

template<>
const inline time_t IEvent::GetStartedAt<time_t>(EventTimeUnit time_unit) const {
  return static_cast<time_t>(GetStartedAt<double>(time_unit));
}

template<>
const inline double IEvent::GetFinishedAt<double>(EventTimeUnit time_unit) const {
  return ConvertTime(finished_at_, time_unit_, time_unit);
}

template<>
const inline time_t IEvent::GetFinishedAt<time_t>(EventTimeUnit time_unit) const {
  return static_cast<time_t>(GetFinishedAt<double>(time_unit));
}

template<>
const inline double IEvent::GetDuration<double>(EventTimeUnit time_unit) const {
  return GetFinishedAt<double>(time_unit) - GetStartedAt<double>(time_unit);
}

template<>
const inline time_t IEvent::GetDuration<time_t>(EventTimeUnit time_unit) const {
  return static_cast<time_t>(GetDuration<double>(time_unit));
}

class CustomEvent final : public IEvent {
 public:
  friend class ProfileManager;
  std::string Key() override;
  nlohmann::json ToJson() override;
  static std::shared_ptr<CustomEvent> Create(const std::string& name,
                                             CustomEventType type = CustomEventType::kDefault);

 private:
  CustomEventType type_;
  CustomEvent(const std::string& custom_name, CustomEventType type)
      : IEvent(custom_name,
               type == CustomEventType::kDefault ? EventTimeUnit::kNS : EventTimeUnit::kUS),
        type_(type) {}
};

class KernelEvent final : public IEvent {
 public:
  std::string Key() override;
  nlohmann::json ToJson() override;
  // One version takes std::function<std::vector<Shape>(void)>, the other
  // std::function<std::vector<ShapeView>(void)>:
  static std::shared_ptr<KernelEvent> Create(
      const std::string& name, const std::function<std::vector<Shape>(void)>& shape_getter);
  // Declared in the ShapeView-based version only:
  //   void RecordShape(const ShapeView& shape);

#if defined(WITH_CUDA)  // the other version: #if defined(WITH_CUDA) || defined(WITH_ROCM)
  void SetMemorySize(int64_t memory_size) { memory_size_ = memory_size; }
  void AddChildEvent(const std::shared_ptr<IEvent>& e) { children_.emplace(e); }
  bool AddChildEventIfSo(const std::shared_ptr<IEvent>& e) {
    if (e->IsChildOf(dynamic_cast<IEvent*>(this))) {
      children_.emplace(e);
      return true;
    }
    return false;
  }
  bool HasChildEvent(const std::shared_ptr<IEvent>& e) { return children_.count(e); }
  void WalkAmongChildren(const std::function<void(const std::shared_ptr<IEvent>& e)>& f) const {
    for (const auto& x : children_) { f(x); }
  }
#endif  // WITH_CUDA

 private:
  KernelEvent(const std::string& kernel_name,
              const std::function<std::vector<Shape>(void)>& shape_getter)
      : IEvent(kernel_name, EventTimeUnit::kNS) {
    if (shape_getter) { input_shapes_ = shape_getter(); }
  }

#if defined(WITH_CUDA)  // same guard difference as above
  int64_t memory_size_ = -1;
  std::set<std::shared_ptr<IEvent>> children_;
#endif  // WITH_CUDA

  std::vector<Shape> input_shapes_;  // std::vector<ShapeView> in the other version
  std::string GetFormatedInputShapes(size_t max_num_to_format = 4);
};

}  // namespace profiler
}  // namespace oneflow

namespace nlohmann {
inline void to_json(json& j, const std::shared_ptr<::oneflow::profiler::IEvent>& event) {
  j = event->ToJson();
}
}  // namespace nlohmann

#endif  // ONEFLOW_CORE_PROFILER_EVENT_H_
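The GetStartedAt/GetFinishedAt/GetDuration accessors in the header above are member function
templates with full specializations for double and time_t. A minimal sketch of that
specialization pattern (Timer and Elapsed are illustrative names, not OneFlow APIs):

```cpp
#include <ctime>

class Timer {
 public:
  // Primary template: only the specializations defined below are usable.
  template<typename T>
  T Elapsed() const;

 private:
  double elapsed_us_ = 42.5;
};

// Full specialization returning the raw double value.
template<>
inline double Timer::Elapsed<double>() const {
  return elapsed_us_;
}

// Full specialization that narrows to an integral type, the same way
// IEvent::GetDuration<time_t>() delegates to the double version.
template<>
inline time_t Timer::Elapsed<time_t>() const {
  return static_cast<time_t>(Elapsed<double>());
}
```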
oneflow/core/profiler/event_recorder.cpp

...
@@ -32,13 +32,13 @@ std::shared_ptr<EventRecorder> EventRecorder::CreateCustomEventRecorder(const st
 Maybe<EventRecorder> EventRecorder::CreateKernelEventRecorder(
     const std::string& name,
-#if defined(WITH_CUDA)
+#if defined(WITH_CUDA) || defined(WITH_ROCM)
     const std::function<int64_t()>& memory_size_getter,
 #endif
     const ShapeGetterFuncType& shape_getter) {
   auto pmgr = Singleton<ProfileManager>::Get();
   if (pmgr) {
-#if defined(WITH_CUDA)
+#if defined(WITH_CUDA) || defined(WITH_ROCM)
     if (pmgr->use_cpu_ || pmgr->use_cuda_) {
       auto event = KernelEvent::Create(name, pmgr->record_shapes_ ? shape_getter : nullptr);
       if (pmgr->use_cuda_) {
...
oneflow/core/profiler/event_recorder.h

The entire header is rendered as changed (mostly whitespace). Apart from that, the two versions
differ only in the ShapeGetterFuncType alias (Shape versus ShapeView) and in the guard around the
memory_size_getter parameter, both marked below.

/*
Copyright 2020 The OneFlow Authors. All rights reserved.
(Apache License, Version 2.0 header, identical to the one in event.cpp)
*/
#ifndef ONEFLOW_CORE_PROFILER_EVENT_RECORDER_H_
#define ONEFLOW_CORE_PROFILER_EVENT_RECORDER_H_

#include "oneflow/core/common/util.h"
#include "oneflow/core/profiler/event.h"

namespace oneflow {
namespace profiler {

class EventRecorder {
 public:
  // One version uses std::vector<Shape>, the other std::vector<ShapeView>:
  using ShapeGetterFuncType = std::function<std::vector<Shape>(void)>;

  OF_DISALLOW_COPY_AND_MOVE(EventRecorder);

  explicit EventRecorder(const std::shared_ptr<IEvent>& event) : event_(event) {
    CHECK_JUST(RegisterEventToProfileManager(event));
    event_->Start();
  }

  Maybe<void> RegisterEventToProfileManager(const std::shared_ptr<IEvent>& event);

  ~EventRecorder() {
    if (event_) {
      event_->Finish();
      event_.reset();
    }
  }

  static std::shared_ptr<EventRecorder> CreateCustomEventRecorder(const std::string& name);

  static Maybe<EventRecorder> CreateKernelEventRecorder(
      const std::string& name,
#if defined(WITH_CUDA)  // the other version: #if defined(WITH_CUDA) || defined(WITH_ROCM)
      const std::function<int64_t()>& memory_size_getter,
#endif
      const ShapeGetterFuncType& shape_getter);

 private:
  std::shared_ptr<IEvent> event_;
};

}  // namespace profiler
}  // namespace oneflow

#endif  // ONEFLOW_CORE_PROFILER_EVENT_RECORDER_H_
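EventRecorder is an RAII guard: its constructor registers and starts an event, and its destructor
finishes it when the scope ends, so a kernel invocation is timed just by keeping a recorder alive
across the call. A minimal sketch of the same idea (DemoEvent, ScopedRecorder and TimedWork are
illustrative, not OneFlow APIs):

```cpp
#include <cstdio>
#include <memory>
#include <string>

struct DemoEvent {
  std::string name;
  void Start() { std::printf("start %s\n", name.c_str()); }
  void Finish() { std::printf("finish %s\n", name.c_str()); }
};

class ScopedRecorder {
 public:
  explicit ScopedRecorder(std::shared_ptr<DemoEvent> e) : event_(std::move(e)) { event_->Start(); }
  ~ScopedRecorder() {
    if (event_) { event_->Finish(); }
  }
  ScopedRecorder(const ScopedRecorder&) = delete;
  ScopedRecorder& operator=(const ScopedRecorder&) = delete;

 private:
  std::shared_ptr<DemoEvent> event_;
};

void TimedWork() {
  ScopedRecorder rec(std::make_shared<DemoEvent>(DemoEvent{"my_kernel"}));
  // ... the work being measured runs here; Finish() fires on scope exit ...
}
```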
oneflow/core/profiler/kernel.cpp

...
@@ -17,7 +17,11 @@ limitations under the License.
 #include "oneflow/core/profiler/kernel.h"
 #include "oneflow/core/profiler/profiler.h"
 #include "oneflow/core/kernel/kernel.h"
+#ifdef WITH_ROCM
+#include "oneflow/core/ep/rocm/cuda_stream.h"
+#else
 #include "oneflow/core/ep/cuda/cuda_stream.h"
+#endif
 #include "oneflow/core/lazy/actor/actor_context.h"

 namespace oneflow {
...
@@ -43,6 +47,11 @@ thread_local cudaEvent_t cuda_memory_bandwidth_profile_start_event = nullptr;
 thread_local cudaEvent_t cuda_memory_bandwidth_profile_end_event = nullptr;
 #endif  // WITH_CUDA
+#if defined(WITH_ROCM)
+thread_local hipEvent_t cuda_memory_bandwidth_profile_start_event = nullptr;
+thread_local hipEvent_t cuda_memory_bandwidth_profile_end_event = nullptr;
+#endif  // WITH_ROCM

 }  // namespace

 void TraceKernelForwardDataContentStart(KernelContext* kernel_ctx, const Kernel* kernel) {
...
@@ -61,6 +70,22 @@ void TraceKernelForwardDataContentStart(KernelContext* kernel_ctx, const Kernel*
   }
   if (profile_kernel_forward_range) { OF_PROFILER_RANGE_PUSH(kernel->op_conf().name()); }
 #endif  // WITH_CUDA
+#if defined(WITH_ROCM)
+  if (profile_cuda_memory_bandwidth) {
+    auto* actor_context_provider = dynamic_cast<ActorContextProvider*>(kernel_ctx);
+    auto* cuda_stream = dynamic_cast<ep::CudaStream*>(kernel_ctx->stream());
+    if (cuda_stream != nullptr && actor_context_provider != nullptr) {
+      CHECK(cuda_memory_bandwidth_profile_start_event == nullptr);
+      CHECK(cuda_memory_bandwidth_profile_end_event == nullptr);
+      OF_CUDA_CHECK(hipEventCreate(&cuda_memory_bandwidth_profile_start_event));
+      OF_CUDA_CHECK(hipEventCreate(&cuda_memory_bandwidth_profile_end_event));
+      OF_CUDA_CHECK(hipEventRecord(cuda_memory_bandwidth_profile_start_event,
+                                   cuda_stream->cuda_stream()));
+    }
+  }
+  if (profile_kernel_forward_range) { OF_PROFILER_RANGE_PUSH(kernel->op_conf().name()); }
+#endif  // WITH_ROCM
 }

 void TraceKernelForwardDataContentEnd(KernelContext* kernel_ctx, const Kernel* kernel) {
...
@@ -103,6 +128,45 @@ void TraceKernelForwardDataContentEnd(KernelContext* kernel_ctx, const Kernel* k
     }
   }
 #endif  // WITH_CUDA
+#if defined(WITH_ROCM)
+  if (profile_kernel_forward_range) { OF_PROFILER_RANGE_POP(); }
+  // The memory bandwidth profiler only works in lazy mode.
+  if (profile_cuda_memory_bandwidth) {
+    auto* cuda_stream = dynamic_cast<ep::CudaStream*>(kernel_ctx->stream());
+    auto* actor_context_provider = dynamic_cast<ActorContextProvider*>(kernel_ctx);
+    if (cuda_stream != nullptr && actor_context_provider != nullptr) {
+      hipEvent_t start_event = cuda_memory_bandwidth_profile_start_event;
+      hipEvent_t end_event = cuda_memory_bandwidth_profile_end_event;
+      cuda_memory_bandwidth_profile_start_event = nullptr;
+      cuda_memory_bandwidth_profile_end_event = nullptr;
+      CHECK_NOTNULL(start_event);
+      CHECK_NOTNULL(end_event);
+      OF_CUDA_CHECK(hipEventRecord(end_event, cuda_stream->cuda_stream()));
+      int64_t memory_size = 0;
+      for (const auto& bn : kernel->op_attribute().input_bns()) {
+        const Blob* blob = kernel_ctx->BnInOp2Blob(bn);
+        if (blob) { memory_size += blob->ByteSizeOfBlobBody(); }
+      }
+      for (const auto& bn : kernel->op_attribute().output_bns()) {
+        const Blob* blob = kernel_ctx->BnInOp2Blob(bn);
+        if (blob) { memory_size += blob->ByteSizeOfBlobBody(); }
+      }
+      const std::string op_name = kernel->op_conf().name();
+      actor_context_provider->GetActorContext()->AddCallback(
+          [start_event, end_event, memory_size, op_name]() {
+            float elapsed_ms = 0;
+            OF_CUDA_CHECK(hipEventElapsedTime(&elapsed_ms, start_event, end_event));
+            OF_CUDA_CHECK(hipEventDestroy(start_event));
+            OF_CUDA_CHECK(hipEventDestroy(end_event));
+            double bandwidth = static_cast<double>(memory_size) / (1024.0 * 1024.0 * 1024.0)
                               / (elapsed_ms / 1000);
+            LOG(INFO) << "PROFILER::KERNEL::CUDA_MEMORY_BANDWIDTH op_name: " << op_name
                      << " elapsed(ms): " << elapsed_ms << " memory_size(Byte): " << memory_size
                      << " bandwidth(GB/s): " << bandwidth;
+          });
+    }
+  }
+#endif  // WITH_ROCM
 }

 }  // namespace profiler
...
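The WITH_ROCM additions above time each kernel with a pair of hipEvent_t markers and convert bytes
moved plus elapsed milliseconds into GB/s. A stripped-down sketch of that HIP timing pattern
(TimeAndLogBandwidth and HIP_CHECK are illustrative helpers, not OneFlow code; the bandwidth
formula mirrors the one in the diff):

```cpp
#include <hip/hip_runtime.h>
#include <cstdint>
#include <cstdio>
#include <cstdlib>

#define HIP_CHECK(expr)                                      \
  do {                                                       \
    hipError_t err = (expr);                                 \
    if (err != hipSuccess) {                                 \
      std::fprintf(stderr, "%s\n", hipGetErrorString(err));  \
      std::abort();                                          \
    }                                                        \
  } while (0)

// Times an arbitrary device-side workload on `stream` and logs the effective
// bandwidth for `memory_size` bytes moved.
template<typename LaunchFn>
void TimeAndLogBandwidth(hipStream_t stream, int64_t memory_size, LaunchFn launch) {
  hipEvent_t start = nullptr, stop = nullptr;
  HIP_CHECK(hipEventCreate(&start));
  HIP_CHECK(hipEventCreate(&stop));
  HIP_CHECK(hipEventRecord(start, stream));
  launch(stream);  // enqueue the kernel(s) being measured
  HIP_CHECK(hipEventRecord(stop, stream));
  HIP_CHECK(hipEventSynchronize(stop));
  float elapsed_ms = 0.0f;
  HIP_CHECK(hipEventElapsedTime(&elapsed_ms, start, stop));
  double bandwidth =
      static_cast<double>(memory_size) / (1024.0 * 1024.0 * 1024.0) / (elapsed_ms / 1000.0);
  std::printf("elapsed(ms): %f memory_size(Byte): %lld bandwidth(GB/s): %f\n", elapsed_ms,
              static_cast<long long>(memory_size), bandwidth);
  HIP_CHECK(hipEventDestroy(start));
  HIP_CHECK(hipEventDestroy(stop));
}
```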
oneflow/core/profiler/kineto_shim.cpp

...
@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
-#if defined(WITH_CUDA)
+#if defined(WITH_CUDA) || defined(WITH_ROCM)
 #include "oneflow/core/profiler/kineto_shim.h"
 #include "libkineto.h"
...
oneflow/core/profiler/kineto_shim.h

...
@@ -16,7 +16,7 @@ limitations under the License.
 #ifndef ONEFLOW_CORE_PROFILER_KINETO_SHIM_H_
 #define ONEFLOW_CORE_PROFILER_KINETO_SHIM_H_
-#if defined(WITH_CUDA)
+#if defined(WITH_CUDA) || defined(WITH_ROCM)
 #include <string>
 #include <memory>
...
oneflow/core/profiler/profile_manager.cpp

...
@@ -15,12 +15,12 @@ limitations under the License.
 */
 #include <memory>
 #include <unordered_map>
-// #include "fmt/core.h"
+#include "fmt/core.h"
 #include "nlohmann/json.hpp"
 #include "oneflow/core/profiler/kineto_shim.h"
 #include "oneflow/core/profiler/profile_manager.h"
 #include "oneflow/core/profiler/event.h"
-#if defined(WITH_CUDA)
+#if defined(WITH_CUDA) || defined(WITH_ROCM)
 #include <libkineto.h>
 #endif  // WITH_CUDA
...
@@ -48,7 +48,7 @@ std::string ProfileManager::DumpResultsJson() {
 }

 std::vector<std::shared_ptr<IEvent>> ProfileManager::ExportEvents() {
-#if defined(WITH_CUDA)
+#if defined(WITH_CUDA) || defined(WITH_ROCM)
   auto trace = StopTrace();
   const auto& kineto_events = *(trace.get()->activities());
   std::set<std::shared_ptr<IEvent>> custom_events;
...
@@ -77,7 +77,7 @@ std::vector<std::shared_ptr<IEvent>> ProfileManager::ExportEvents() {
   while (!events_.empty()) {
     auto evt = events_.front();
     events_.pop();
-#if defined(WITH_CUDA)
+#if defined(WITH_CUDA) || defined(WITH_ROCM)
     auto evt_kernel = std::dynamic_pointer_cast<KernelEvent>(evt);
     if (evt_kernel) {
       std::set<int64_t> current_corr_ids;
...
@@ -106,8 +106,7 @@ std::string ProfileManager::GetNextEventRecorderKey(const std::string& name) {
   } else {
     event_recorders_last_id_[name]++;
   }
-  // return fmt::format("{}.{}", name, event_recorders_last_id_[name]);
-  return "yuguo";
+  return fmt::format("{}.{}", name, event_recorders_last_id_[name]);
 }

 }  // namespace profiler
...
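GetNextEventRecorderKey builds a per-name key of the form "<name>.<id>" from a counter map, which
is what the fmt::format call produces. A minimal sketch of that naming scheme (NextKey and the
local map are illustrative, not the OneFlow implementation):

```cpp
#include <cstdint>
#include <string>
#include <unordered_map>
#include "fmt/core.h"

// Returns "name.N", where N counts how many recorders with that name have been created so far.
std::string NextKey(std::unordered_map<std::string, int64_t>& last_id, const std::string& name) {
  auto it = last_id.find(name);
  if (it == last_id.end()) {
    last_id[name] = 0;
  } else {
    ++it->second;
  }
  return fmt::format("{}.{}", name, last_id[name]);
}
```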
oneflow/core/profiler/profile_manager.h

...
@@ -37,7 +37,7 @@ class ProfileManager {
         use_cuda_(use_cuda),
         record_shapes_(record_shapes),
         record_bandwidth_(record_bandwidth) {
-#if defined(WITH_CUDA)
+#if defined(WITH_CUDA) || defined(WITH_ROCM)
     std::set<ActivityType> activities{};
     if (use_cpu) { activities.insert(ActivityType::CPU); }
     if (use_cuda) { activities.insert(ActivityType::CUDA); }
...
oneflow/core/profiler/profiler.cpp

...
@@ -20,11 +20,20 @@ limitations under the License.
 #include "oneflow/core/profiler/event_recorder.h"
 #include "oneflow/core/vm/vm_util.h"
 #ifdef OF_ENABLE_PROFILER
+#ifdef WITH_ROCM
+#include <hip/hip_runtime.h>
+#include <hip/hip_profile.h>
+#include <roctracer_roctx.h>
+#include <sys/syscall.h>
+#include <iostream>
+#include "oneflow/core/device/cuda_util.h"
+#else
 #include <nvtx3/nvToolsExt.h>
 #include <sys/syscall.h>
 #include <iostream>
 #include <cuda_profiler_api.h>
 #include "oneflow/core/device/cuda_util.h"
+#endif
 #endif  // OF_ENABLE_PROFILER

 namespace oneflow {
...
@@ -33,6 +42,16 @@ namespace profiler {
 void NameThisHostThread(const std::string& name) {
 #ifdef OF_ENABLE_PROFILER
+#ifdef WITH_ROCM
+  static thread_local std::unique_ptr<std::string> thread_name_prefix;
+  if (!thread_name_prefix) {
+    thread_name_prefix.reset(
+        new std::string(GetStringFromEnv("ONEFLOW_PROFILER_HOST_THREAD_NAME_PREFIX", "")));
+  }
+  const std::string name_with_prefix = *thread_name_prefix + name;
+  // nvtxNameOsThreadA(syscall(SYS_gettid), name_with_prefix.c_str());
+  roctxMarkA(name_with_prefix.c_str());
+#else
   static thread_local std::unique_ptr<std::string> thread_name_prefix;
   if (!thread_name_prefix) {
     thread_name_prefix.reset(
...
@@ -40,18 +59,27 @@ void NameThisHostThread(const std::string& name) {
   }
   const std::string name_with_prefix = *thread_name_prefix + name;
   nvtxNameOsThreadA(syscall(SYS_gettid), name_with_prefix.c_str());
+#endif
 #endif  // OF_ENABLE_PROFILER
 }

 void RangePush(const std::string& name) {
 #ifdef OF_ENABLE_PROFILER
+#ifdef WITH_ROCM
+  roctxRangePushA(name.c_str());
+#else
   nvtxRangePushA(name.c_str());
+#endif
 #endif  // OF_ENABLE_PROFILER
 }

 void RangePop() {
 #ifdef OF_ENABLE_PROFILER
+#ifdef WITH_ROCM
+  roctxRangePop();
+#else
   nvtxRangePop();
+#endif
 #endif  // OF_ENABLE_PROFILER
 }
...
@@ -82,13 +110,21 @@ void LogHostMemoryUsage(const std::string& name) {
 void ProfilerStart() {
 #ifdef OF_ENABLE_PROFILER
+#ifdef WITH_ROCM
+  OF_CUDA_CHECK(hipProfilerStart());
+#else
   OF_CUDA_CHECK(cudaProfilerStart());
+#endif
 #endif  // OF_ENABLE_PROFILER
 }

 void ProfilerStop() {
 #ifdef OF_ENABLE_PROFILER
+#ifdef WITH_ROCM
+  OF_CUDA_CHECK(hipProfilerStop());
+#else
   OF_CUDA_CHECK(cudaProfilerStop());
+#endif
 #endif  // OF_ENABLE_PROFILER
 }
...
@@ -105,6 +141,9 @@ Maybe<std::string> DisableProfilerAndReturnResult() {
 #if defined(WITH_CUDA)
   OF_CUDA_CHECK(cudaDeviceSynchronize());
 #endif  // WITH_CUDA
+#if defined(WITH_ROCM)
+  OF_CUDA_CHECK(hipDeviceSynchronize());
+#endif  // WITH_ROCM
   auto* pmgr = JUST(SingletonMaybe<ProfileManager>());
   std::string results = pmgr->DumpResultsJson();
   Singleton<ProfileManager>::Delete();
...
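The profiler.cpp changes select ROCm's roctx markers when WITH_ROCM is defined and keep the NVTX
calls otherwise. A condensed sketch of that dispatch (the include paths follow the ones in the
diff; WITH_ROCM is assumed to be defined by the build system):

```cpp
#include <string>
#ifdef WITH_ROCM
#include <roctracer_roctx.h>   // roctxRangePushA / roctxRangePop
#else
#include <nvtx3/nvToolsExt.h>  // nvtxRangePushA / nvtxRangePop
#endif

// Push a named range onto the profiler's range stack.
inline void ProfilerRangePush(const std::string& name) {
#ifdef WITH_ROCM
  roctxRangePushA(name.c_str());
#else
  nvtxRangePushA(name.c_str());
#endif
}

// Pop the most recently pushed range.
inline void ProfilerRangePop() {
#ifdef WITH_ROCM
  roctxRangePop();
#else
  nvtxRangePop();
#endif
}
```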
oneflow/user/kernels/math_unary_elementwise_func.h

The entire header is rendered as changed; the two versions are identical except for whitespace
and one spot in LgammaFunctor<float>::Backward, where one side keeps assert(false) and the other
comments it out (marked below). The diff view is cut off partway through the float
specializations.

/*
Copyright 2020 The OneFlow Authors. All rights reserved.
(Apache License, Version 2.0 header, identical to the one in event.cpp)
*/
#ifndef ONEFLOW_USER_KERNELS_MATH_UNARY_ELEMENTWISE_FUNC_H_
#define ONEFLOW_USER_KERNELS_MATH_UNARY_ELEMENTWISE_FUNC_H_

#include "oneflow/core/common/util.h"
#include "oneflow/core/common/data_type.h"
#include "oneflow/user/ops/math_unary_elementwise_seq.h"
#include "oneflow/core/device/cuda_pseudo_half.h"

#if defined(__CUDACC__)

#include <cuda_fp16.h>
#define MATH_FUNC_F(name, x) name##f(x)
#define MATH_FUNC_D(name, x) name(x)

#elif defined(__HIPCC__)

#include <cmath>
#include <hip/hip_fp16.h>
#if defined(__HIP_DEVICE_COMPILE__)
#define MATH_FUNC_F(name, x) name##f(x)
#define MATH_FUNC_D(name, x) name(x)
#else
#define MATH_FUNC_F(name, x) std::name(x)
#define MATH_FUNC_D(name, x) std::name(x)
#endif

#else

#include <cmath>
#define MATH_FUNC_F(name, x) std::name(x)
#define MATH_FUNC_D(name, x) std::name(x)

#endif

namespace oneflow {

#define DECLARE_UNARY_FUNCTOR(math_unary_elementwise_type, func_prefix) \
  template<typename T>                                                  \
  struct func_prefix##Functor;

OF_PP_FOR_EACH_TUPLE(DECLARE_UNARY_FUNCTOR, MATH_UNARY_ELEMENTWISE_FUNC_SEQ)

template<typename T>
struct AbsFunctor {
  static OF_DEVICE_FUNC T Forward(const T x) {
    if (x == T(0))
      return T(0);
    else
      return x < T(0) ? -x : x;
  }
  static OF_DEVICE_FUNC T Backward(const T x, const T dy) {
    if (x == T(0))
      return T(0);
    else
      return x < T(0) ? -dy : dy;
  }
};

template<typename T>
struct SignFunctor {
  static OF_DEVICE_FUNC T Forward(const T x) { return (T(0) < x) - (x < T(0)); }
  static OF_DEVICE_FUNC T Backward(const T x, const T dy) { return T(0); }
};

template<>
struct RsqrtFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) {
#if defined(__CUDACC__)
    return rsqrtf(x);
#elif defined(__HIP_DEVICE_COMPILE__)
    return rsqrtf(x);
#else
    return 1.0f / std::sqrt(x);
#endif
  }
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * (-1.0f / (2.0f * MATH_FUNC_F(sqrt, x * x * x)));
  }
};

template<>
struct RsqrtFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) {
#if defined(__CUDACC__)
    return rsqrt(x);
#elif defined(__HIP_DEVICE_COMPILE__)
    return rsqrt(x);
#else
    return 1.0 / std::sqrt(x);
#endif
  }
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * (-1.0 / (2.0 * MATH_FUNC_D(sqrt, x * x * x)));
  }
};

// float version

template<>
struct AcosFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(acos, x); }
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * -RsqrtFunctor<float>::Forward(1.0f - x * x);
  }
};

template<>
struct AcoshFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(acosh, x); }
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * RsqrtFunctor<float>::Forward(x * x - 1.0f);
  }
};

template<>
struct AsinFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(asin, x); }
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * RsqrtFunctor<float>::Forward(1.0f - x * x);
  }
};

template<>
struct AsinhFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(asinh, x); }
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * RsqrtFunctor<float>::Forward(1.0f + x * x);
  }
};

template<>
struct AtanFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(atan, x); }
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * (1.0f / (1.0f + x * x));
  }
};

template<>
struct AtanhFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(atanh, x); }
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * (1.0f / (1.0f - x * x));
  }
};

template<>
struct NotEqualZeroFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return x != 0; }
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; }
};

template<>
struct CeilFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(ceil, x); }
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; }
};

template<>
struct CosFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(cos, x); }
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * (-MATH_FUNC_F(sin, x));
  }
};

template<>
struct CoshFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(cosh, x); }
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * MATH_FUNC_F(sinh, x);
  }
};

template<>
struct ErfFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(erf, x); }
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * 2.0f * RsqrtFunctor<float>::Forward(M_PI) * expf(-x * x);
  }
};

template<>
struct ErfcFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(erfc, x); }
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * -2.0f * RsqrtFunctor<float>::Forward(M_PI) * expf(-x * x);
  }
};

template<>
struct ExpFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(exp, x); }
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * MATH_FUNC_F(exp, x);
  }
};

template<>
struct Expm1Functor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(expm1, x); }
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * MATH_FUNC_F(exp, x);
  }
};

template<>
struct FloorFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(floor, x); }
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; }
};

template<>
struct LgammaFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(lgamma, x); }
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    // TODO(chengcheng): return: dy * digamma(x)
    assert(false);  // one version keeps this assert, the other comments it out
    return 0.0f;
  }
};

template<>
struct LogFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(log, x); }
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return dy * (1.0f / x); }
};

template<>
struct Log2Functor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(log2, x); }
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * (1.0f / (x * MATH_FUNC_F(log, 2.0f)));
  }
};

template<>
struct Log1pFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(log1p, x); }
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * (1.0f / (x + 1.0f));
  }
};

template<>
struct LogSigmoidFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) {
    return -MATH_FUNC_F(log, (1.0f + MATH_FUNC_F(exp, -x)));
  }
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * (1.0f / (MATH_FUNC_F(exp, x) + 1.0f));
  }
};

template<>
struct NegativeFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return -x; }
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return -dy; }
};

template<>
struct ReciprocalFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return 1.0f / x; }
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    return dy * (-1.0f / (x * x));
  }
};

template<>
struct ReciprocalNoNanFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) {
    if (fabsf(x) <= 0.0f) { return 0.0f; }
    return 1.0f / x;
  }
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
    if (fabsf(x) <= 0.0f) { return 0.0f; }
    return dy * (-1.0f / (x * x));
  }
};

template<>
struct RintFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(rint, x); }
  static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; }
};

template<>
struct RoundFunctor<float> {
  static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(nearbyint, x); }
  static OF_DEVICE_FUNC
(truncated)
float
Forward
(
const
float
x
)
{
return
MATH_FUNC_F
(
nearbyint
,
x
);
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
return
0.0
f
;
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
return
0.0
f
;
}
};
};
template
<
>
template
<
>
struct
SigmoidFunctor
<
float
>
{
struct
SigmoidFunctor
<
float
>
{
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
1.0
f
/
(
1.0
f
+
MATH_FUNC_F
(
exp
,
-
x
));
return
1.0
f
/
(
1.0
f
+
MATH_FUNC_F
(
exp
,
-
x
));
}
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
float
y
=
1.0
f
/
(
1.0
f
+
MATH_FUNC_F
(
exp
,
-
x
));
float
y
=
1.0
f
/
(
1.0
f
+
MATH_FUNC_F
(
exp
,
-
x
));
return
dy
*
(
y
*
(
1.0
f
-
y
));
return
dy
*
(
y
*
(
1.0
f
-
y
));
}
}
};
};
template
<
>
template
<
>
struct
SinFunctor
<
float
>
{
struct
SinFunctor
<
float
>
{
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
MATH_FUNC_F
(
sin
,
x
);
}
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
MATH_FUNC_F
(
sin
,
x
);
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
return
dy
*
MATH_FUNC_F
(
cos
,
x
);
return
dy
*
MATH_FUNC_F
(
cos
,
x
);
}
}
};
};
template
<
>
template
<
>
struct
SinhFunctor
<
float
>
{
struct
SinhFunctor
<
float
>
{
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
MATH_FUNC_F
(
sinh
,
x
);
}
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
MATH_FUNC_F
(
sinh
,
x
);
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
return
dy
*
MATH_FUNC_F
(
cosh
,
x
);
return
dy
*
MATH_FUNC_F
(
cosh
,
x
);
}
}
};
};
template
<
>
template
<
>
struct
SqrtFunctor
<
float
>
{
struct
SqrtFunctor
<
float
>
{
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
MATH_FUNC_F
(
sqrt
,
x
);
}
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
MATH_FUNC_F
(
sqrt
,
x
);
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
return
dy
*
0.5
f
/
MATH_FUNC_F
(
sqrt
,
x
);
return
dy
*
0.5
f
/
MATH_FUNC_F
(
sqrt
,
x
);
}
}
};
};
template
<
>
template
<
>
struct
SquareFunctor
<
float
>
{
struct
SquareFunctor
<
float
>
{
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
x
*
x
;
}
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
x
*
x
;
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
return
dy
*
2.0
f
*
x
;
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
return
dy
*
2.0
f
*
x
;
}
};
};
template
<
>
template
<
>
struct
TanFunctor
<
float
>
{
struct
TanFunctor
<
float
>
{
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
MATH_FUNC_F
(
tan
,
x
);
}
static
OF_DEVICE_FUNC
float
Forward
(
const
float
x
)
{
return
MATH_FUNC_F
(
tan
,
x
);
}
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
static
OF_DEVICE_FUNC
float
Backward
(
const
float
x
,
const
float
dy
)
{
return
dy
*
(
1.0
f
/
(
MATH_FUNC_F
(
cos
,
x
)
*
MATH_FUNC_F
(
cos
,
x
)));
return
dy
*
(
1.0
f
/
(
MATH_FUNC_F
(
cos
,
x
)
*
MATH_FUNC_F
(
cos
,
x
)));
}
}
};
};
// double version

template<>
struct AcosFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(acos, x); }

  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * -RsqrtFunctor<double>::Forward(1.0 - x * x);
  }
};

template<>
struct AcoshFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(acosh, x); }

  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * -RsqrtFunctor<double>::Forward(x * x - 1.0);
  }
};

template<>
struct AsinFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(asin, x); }

  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * RsqrtFunctor<double>::Forward(1.0 - x * x);
  }
};

template<>
struct AsinhFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(asinh, x); }

  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * RsqrtFunctor<double>::Forward(1.0 + x * x);
  }
};

template<>
struct AtanFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(atan, x); }

  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * (1.0 / (1.0 + x * x));
  }
};

template<>
struct AtanhFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(atanh, x); }

  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * (1.0 / (1.0 - x * x));
  }
};

template<>
struct NotEqualZeroFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return x != 0; }

  static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return 0.0f; }
};

template<>
struct CeilFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(ceil, x); }

  static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return 0.0; }
};

template<>
struct CosFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(cos, x); }

  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * (-MATH_FUNC_D(sin, x));
  }
};

template<>
struct CoshFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(cosh, x); }

  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * MATH_FUNC_D(sinh, x);
  }
};

template<>
struct ErfFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(erf, x); }

  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * 2.0 * RsqrtFunctor<double>::Forward(M_PI) * expf(-x * x);
  }
};

template<>
struct ErfcFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(erfc, x); }

  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * -2.0 * RsqrtFunctor<double>::Forward(M_PI) * expf(-x * x);
  }
};

template<>
struct ExpFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(exp, x); }

  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * MATH_FUNC_D(exp, x);
  }
};

template<>
struct Expm1Functor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(expm1, x); }

  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * MATH_FUNC_D(exp, x);
  }
};

template<>
struct FloorFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(floor, x); }

  static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return 0.0; }
};

template<>
struct LgammaFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(lgamma, x); }

  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    // TODO(chengcheng): return: dy * digamma(x)
    assert(false);
    return 0.0;
  }
};

template<>
struct LogFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(log, x); }

  static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return dy * (1.0 / x); }
};

template<>
struct Log2Functor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(log2, x); }

  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * (1.0 / (x * MATH_FUNC_D(log, 2.0)));
  }
};

template<>
struct Log1pFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(log1p, x); }

  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * (1.0 / (x + 1.0));
  }
};

template<>
struct LogSigmoidFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) {
    return -MATH_FUNC_D(log, (1.0 + MATH_FUNC_D(exp, -x)));
  }

  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * (1.0 / (MATH_FUNC_D(exp, x) + 1.0));
  }
};

template<>
struct NegativeFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return -x; }

  static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return -dy; }
};

template<>
struct ReciprocalFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return 1.0 / x; }

  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * (-1.0 / (x * x));
  }
};

template<>
struct ReciprocalNoNanFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) {
    if (fabs(x) <= 0.0) { return 0.0; }
    return 1.0 / x;
  }

  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    if (fabs(x) <= 0.0) { return 0.0; }
    return dy * (-1.0 / (x * x));
  }
};

template<>
struct RintFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(rint, x); }

  static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return 0.0; }
};

template<>
struct RoundFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(nearbyint, x); }

  static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return 0.0; }
};

template<>
struct SigmoidFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) {
    return 1.0 / (1.0 + MATH_FUNC_D(exp, -x));
  }

  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    double y = 1.0 / (1.0 + MATH_FUNC_D(exp, -x));
    return dy * (y * (1.0 - y));
  }
};

template<>
struct SinFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(sin, x); }

  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * MATH_FUNC_D(cos, x);
  }
};

template<>
struct SinhFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(sinh, x); }

  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * MATH_FUNC_D(cosh, x);
  }
};

template<>
struct SqrtFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(sqrt, x); }

  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * (double)0.5 / MATH_FUNC_D(sqrt, x);
  }
};

template<>
struct SquareFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return x * x; }

  static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return dy * 2.0 * x; }
};

template<>
struct TanFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(tan, x); }

  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * (1.0 / (MATH_FUNC_D(cos, x) * MATH_FUNC_D(cos, x)));
  }
};
#if defined(__CUDACC__) || defined(__HIPCC__)
// half version

#define OF_HALF_FUNC __device__ __forceinline__

#define MATH_FUNC_H(name, x) __float2half(name##f(__half2float(x)))

#define HALF_VAL_HALF __float2half(0.5f)
#define HALF_VAL_TWO __float2half(2.0f)
#define HALF_VAL_2RSQRT_PI __float2half(1.1283791671f)

template<>
struct AbsFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) {
    return __hlt(x, GetZeroVal<half>()) ? __hneg(x) : x;
  }

  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hlt(x, GetZeroVal<half>()) ? __hneg(dy) : dy;
  }
};

template<>
struct AcosFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(acos, x); }

  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, __hneg(hrsqrt(__hsub(GetOneVal<half>(), __hmul(x, x)))));
  }
};

template<>
struct AcoshFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(acosh, x); }

  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, hrsqrt(__hsub(__hmul(x, x), GetOneVal<half>())));
  }
};

template<>
struct AsinFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(asin, x); }

  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, hrsqrt(__hsub(GetOneVal<half>(), __hmul(x, x))));
  }
};

template<>
struct AsinhFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(asinh, x); }

  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, hrsqrt(__hadd(GetOneVal<half>(), __hmul(x, x))));
  }
};

template<>
struct AtanFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(atan, x); }

  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, __hdiv(GetOneVal<half>(), __hadd(GetOneVal<half>(), __hmul(x, x))));
  }
};

template<>
struct AtanhFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(atanh, x); }

  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, __hdiv(GetOneVal<half>(), __hsub(GetOneVal<half>(), __hmul(x, x))));
  }
};

template<>
struct CeilFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return hceil(x); }

  static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal<half>(); }
};

template<>
struct NotEqualZeroFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return x != static_cast<half>(0.0); }

  static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal<half>(); }
};

template<>
struct CosFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return hcos(x); }

  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, __hneg(hsin(x)));
  }
};

template<>
struct CoshFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(cosh, x); }

  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, MATH_FUNC_H(sinh, x));
  }
};

template<>
struct ErfFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(erf, x); }

  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, __hmul(HALF_VAL_2RSQRT_PI, hexp(__hmul(__hneg(x), x))));
  }
};

template<>
struct ErfcFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(erfc, x); }

  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, __hneg(__hmul(HALF_VAL_2RSQRT_PI, hexp(__hmul(__hneg(x), x)))));
  }
};

template<>
struct ExpFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return hexp(x); }

  static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, hexp(x)); }
};

template<>
struct Expm1Functor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(expm1, x); }

  static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, hexp(x)); }
};

template<>
struct FloorFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return hfloor(x); }

  static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal<half>(); }
};

template<>
struct LgammaFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(lgamma, x); }

  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    // TODO(chengcheng): return: dy * digamma(x)
    assert(false);
    return GetZeroVal<half>();
  }
};

template<>
struct LogFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return hlog(x); }

  static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, hrcp(x)); }
};

template<>
struct Log2Functor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return hlog2(x); }

  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, hrcp(__hmul(x, hlog(HALF_VAL_TWO))));
  }
};

template<>
struct Log1pFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(log1p, x); }

  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, hrcp(__hadd(x, GetOneVal<half>())));
  }
};

template<>
struct LogSigmoidFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) {
    return __hneg(hlog(__hadd(GetOneVal<half>(), hexp(__hneg(x)))));
  }

  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, hrcp(__hadd(hexp(x), GetOneVal<half>())));
  }
};

template<>
struct NegativeFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return __hneg(x); }

  static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hneg(dy); }
};

template<>
struct ReciprocalFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return hrcp(x); }

  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, __hneg(hrcp(__hmul(x, x))));
  }
};

template<>
struct ReciprocalNoNanFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) {
    if (__heq(GetZeroVal<half>(), x)) { return GetZeroVal<half>(); }
    return hrcp(x);
  }

  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    if (__heq(GetZeroVal<half>(), x)) { return GetZeroVal<half>(); }
    return __hmul(dy, __hneg(hrcp(__hmul(x, x))));
  }
};

template<>
struct RintFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return hrint(x); }

  static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal<half>(); }
};

template<>
struct RoundFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(nearbyint, x); }

  static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal<half>(); }
};

template<>
struct RsqrtFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return hrsqrt(x); }

  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, __hneg(hrcp(__hmul(HALF_VAL_TWO, hsqrt(__hmul(x, __hmul(x, x)))))));
  }
};

template<>
struct SigmoidFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) {
    return hrcp(__hadd(GetOneVal<half>(), hexp(__hneg(x))));
  }

  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    half y = hrcp(__hadd(GetOneVal<half>(), hexp(__hneg(x))));
    return __hmul(dy, __hmul(y, __hsub(GetOneVal<half>(), y)));
  }
};

template<>
struct SignFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) {
    if (__hgt(x, GetZeroVal<half>())) { return GetOneVal<half>(); }
    if (__hlt(x, GetZeroVal<half>())) { return __hneg(GetOneVal<half>()); }
    return GetZeroVal<half>();
  }

  static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal<half>(); }
};

template<>
struct SinFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return hsin(x); }

  static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, hcos(x)); }
};

template<>
struct SinhFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(sinh, x); }

  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, MATH_FUNC_H(cosh, x));
  }
};

template<>
struct SqrtFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return hsqrt(x); }

  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, __hdiv(HALF_VAL_HALF, hsqrt(x)));
  }
};

template<>
struct SquareFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return __hmul(x, x); }

  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, __hmul(HALF_VAL_TWO, x));
  }
};

template<>
struct TanFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return __hdiv(hsin(x), hcos(x)); }

  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, hrcp(__hmul(hcos(x), hcos(x))));
  }
};

#endif
}  // namespace oneflow

#endif  // ONEFLOW_USER_KERNELS_MATH_UNARY_ELEMENTWISE_FUNC_H_
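Note on the pattern above: every specialization pairs a Forward evaluation with an analytic Backward that scales dy by the derivative at x. The launch code that consumes these functors is not part of this diff, so the following grid-stride kernel is only a minimal sketch under that assumption (the kernel name UnaryForwardSketch and the launch configuration are illustrative, not OneFlow's actual dispatch):

// Illustrative sketch: apply Functor::Forward over a flat buffer of n elements.
// Assumes the functor specializations above are in scope; not the real launcher.
template<typename Functor, typename T>
__global__ void UnaryForwardSketch(int64_t n, const T* x, T* y) {
  for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += static_cast<int64_t>(blockDim.x) * gridDim.x) {
    y[i] = Functor::Forward(x[i]);  // e.g. Functor = SqrtFunctor<float>
  }
}
// Usage sketch: UnaryForwardSketch<SqrtFunctor<float>, float><<<grid, block>>>(n, x, y);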
oneflow/user/kernels/nvtx_range_kernel.hip.cpp
0 → 100644
View file @
f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/kernel/new_kernel_util.h"
#ifdef OF_ENABLE_PROFILER
#include <roctracer_roctx.h>
#endif // OF_ENABLE_PROFILER
namespace oneflow {

namespace {

#ifdef OF_ENABLE_PROFILER
static thread_local HashMap<std::string, roctx_range_id_t> mark2range_id;
#endif

}  // namespace

class NvtxOpKernelState final : public user_op::OpKernelState {
 public:
  NvtxOpKernelState() : counter_(0) {
#ifndef OF_ENABLE_PROFILER
    LOG(WARNING) << "To use NVTX, run cmake with -DBUILD_PROFILER=ON";
#endif
  }
  ~NvtxOpKernelState() override = default;

  int64_t counter() const { return counter_; }
  void IncreaseCount() { counter_ += 1; }

 private:
  int64_t counter_;
};

class NvtxStartKernel final : public user_op::OpKernel {
 public:
  NvtxStartKernel() = default;
  ~NvtxStartKernel() override = default;

  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(
      user_op::KernelInitContext* ctx) const override {
    return std::make_shared<NvtxOpKernelState>();
  }

 private:
  using user_op::OpKernel::Compute;
  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,
               const user_op::OpKernelCache*) const override {
    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0);
    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0);
    const ShapeView& in_shape = in->shape_view();
    CHECK_EQ(out->shape_view(), in_shape);
    const DataType in_data_type = in->data_type();
    CHECK_EQ(out->data_type(), in_data_type);
    Memcpy<DeviceType::kCUDA>(ctx->stream(), out->mut_dptr<void>(), in->dptr<void>(),
                              in_shape.elem_cnt() * GetSizeOfDataType(in_data_type));
#ifdef OF_ENABLE_PROFILER
    auto* kernel_state = dynamic_cast<NvtxOpKernelState*>(state);
    const std::string mark_prefix = ctx->Attr<std::string>("mark_prefix");
    const std::string mark = mark_prefix + "-" + std::to_string(kernel_state->counter());
    roctx_range_id_t range_id = roctxRangeStartA(mark.c_str());
    CHECK(mark2range_id.emplace(mark, range_id).second);
    kernel_state->IncreaseCount();
#endif  // OF_ENABLE_PROFILER
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

REGISTER_USER_KERNEL("nvtx_start")
    .SetCreateFn<NvtxStartKernel>()
    .SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCUDA)
    .SetInplaceProposalFn([](const user_op::InferContext&,
                             user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> {
      OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "in", 0, false));
      return Maybe<void>::Ok();
    });

class NvtxEndKernel final : public user_op::OpKernel {
 public:
  NvtxEndKernel() = default;
  ~NvtxEndKernel() override = default;

  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(
      user_op::KernelInitContext* ctx) const override {
    return std::make_shared<NvtxOpKernelState>();
  }

 private:
  using user_op::OpKernel::Compute;
  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,
               const user_op::OpKernelCache*) const override {
    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0);
    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0);
    const ShapeView& in_shape = in->shape_view();
    CHECK_EQ(out->shape_view(), in_shape);
    const DataType in_data_type = in->data_type();
    CHECK_EQ(out->data_type(), in_data_type);
#ifdef OF_ENABLE_PROFILER
    auto* kernel_state = dynamic_cast<NvtxOpKernelState*>(state);
    const std::string mark_prefix = ctx->Attr<std::string>("mark_prefix");
    const std::string mark = mark_prefix + "-" + std::to_string(kernel_state->counter());
    auto it = mark2range_id.find(mark.c_str());
    CHECK(it != mark2range_id.end());
    roctx_range_id_t range_id = it->second;
    mark2range_id.erase(it);
    roctxRangeStop(range_id);
    Memcpy<DeviceType::kCUDA>(ctx->stream(), out->mut_dptr<void>(), in->dptr<void>(),
                              in_shape.elem_cnt() * GetSizeOfDataType(in_data_type));
    kernel_state->IncreaseCount();
#endif
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

REGISTER_USER_KERNEL("nvtx_end")
    .SetCreateFn<NvtxEndKernel>()
    .SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCUDA)
    .SetInplaceProposalFn([](const user_op::InferContext&,
                             user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> {
      OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "in", 0, false));
      return Maybe<void>::Ok();
    });

}  // namespace oneflow
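For the DCU/ROCm build, this new file keeps the nvtx_start / nvtx_end op names but backs them with roctracer's roctx ranges: the start kernel opens a named range and records its id in the thread-local map, and the end kernel looks the id up by the same mark string and closes it. A minimal sketch of just that pairing is below; it assumes only the roctxRangeStartA/roctxRangeStop calls already used above, with the kernel and tensor plumbing omitted:

// Sketch of the roctx range pairing (assumes <roctracer_roctx.h> as included above).
#include <roctracer_roctx.h>

void ProfileRegionSketch() {
  roctx_range_id_t id = roctxRangeStartA("my-region-0");  // what nvtx_start does per mark
  // ... enqueue the work that should be attributed to this range ...
  roctxRangeStop(id);  // what nvtx_end does after finding the id in mark2range_id
}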
oneflow/user/kernels/stateful_opkernel.cpp
View file @
f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/user/kernels/stateful_opkernel.h"
#include "oneflow/core/framework/attr_value_accessor.h"
#include "oneflow/core/framework/user_op_conf.h"
#include "oneflow/core/framework/user_op_registry_manager.h"
#include "oneflow/core/eager/eager_blob_object.h"
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/rpc/include/global_process_ctx.h"
#include "oneflow/core/framework/consistent_tensor_infer_cache.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/profiler/profiler.h"
#include "oneflow/core/profiler/profile_manager.h"
#include "oneflow/core/profiler/event_recorder.h"
#include "oneflow/core/eager/call_context.h"
namespace
oneflow
{
namespace
oneflow
{
namespace
one
{
namespace
one
{
class
ConsistentTensorInferResult
;
class
ConsistentTensorInferResult
;
using
ArgVec
=
std
::
vector
<
std
::
pair
<
std
::
string
,
int32_t
>>
;
using
ArgVec
=
std
::
vector
<
std
::
pair
<
std
::
string
,
int32_t
>>
;
using
EagerBlobObjectListRawPtr
=
const
std
::
vector
<
std
::
shared_ptr
<
vm
::
EagerBlobObject
>>*
;
using
EagerBlobObjectListRawPtr
=
const
std
::
vector
<
std
::
shared_ptr
<
vm
::
EagerBlobObject
>>*
;
using
ConsistentTensorInferResultRawPtr
=
const
ConsistentTensorInferResult
*
;
using
ConsistentTensorInferResultRawPtr
=
const
ConsistentTensorInferResult
*
;
class
ZeroCopyBaseContextHelper
{
class
ZeroCopyBaseContextHelper
{
public:
public:
ZeroCopyBaseContextHelper
(
const
std
::
shared_ptr
<
const
ArgTuple
>&
input_arg_tuple
,
ZeroCopyBaseContextHelper
(
const
std
::
shared_ptr
<
const
ArgTuple
>&
input_arg_tuple
,
const
std
::
shared_ptr
<
const
ArgTuple
>&
output_arg_tuple
)
const
std
::
shared_ptr
<
const
ArgTuple
>&
output_arg_tuple
)
:
input_arg_tuple_
(
input_arg_tuple
),
output_arg_tuple_
(
output_arg_tuple
)
{}
:
input_arg_tuple_
(
input_arg_tuple
),
output_arg_tuple_
(
output_arg_tuple
)
{}
#define RETURN_IF_FOUND(inputs, outputs, post_action) \
#define RETURN_IF_FOUND(inputs, outputs, post_action) \
int32_t i = TryGetTensorTupleIndex(input_arg_tuple_->arg_name2bn_index2tensor_tuple_index(), \
int32_t
i
=
TryGetTensorTupleIndex
(
input_arg_tuple_
->
arg_name2bn_index2tensor_tuple_index
(),
\
arg_name, index); \
arg_name
,
index
);
\
if (i >= 0) { return (inputs).at(i) post_action; } \
if
(
i
>=
0
)
{
return
(
inputs
).
at
(
i
)
post_action
;
}
\
i = TryGetTensorTupleIndex(output_arg_tuple_->arg_name2bn_index2tensor_tuple_index(), arg_name, \
i
=
TryGetTensorTupleIndex
(
output_arg_tuple_
->
arg_name2bn_index2tensor_tuple_index
(),
arg_name
,
\
index); \
index
);
\
if (i >= 0) { return (outputs).at(i) post_action; }
if
(
i
>=
0
)
{
return
(
outputs
).
at
(
i
)
post_action
;
}
user_op
::
TensorDesc
*
TensorDesc4ArgNameAndIndex
(
eager
::
CallContext
*
call_ctx
,
user_op
::
TensorDesc
*
TensorDesc4ArgNameAndIndex
(
eager
::
CallContext
*
call_ctx
,
const
std
::
string
&
arg_name
,
const
std
::
string& arg_name,
                                                  const int32_t index) const {
    RETURN_IF_FOUND(*call_ctx->inputs(), *call_ctx->outputs(), .get());
    return nullptr;
  }

  user_op::Tensor* Tensor4ArgNameAndIndex(eager::CallContext* call_ctx,
                                          const std::string& arg_name,
                                          const int32_t index) const {
    RETURN_IF_FOUND(*call_ctx->inputs(), *call_ctx->outputs(), .get());
    if (arg_name == "tmp_buffer" && index == 0) { return call_ctx->mut_tmp_tensor(); }
    return nullptr;
  }

  const ConsistentTensorMeta* ConsistentTensorMeta4ArgNameAndIndex(eager::CallContext* call_ctx,
                                                                   const std::string& arg_name,
                                                                   const int32_t index) const {
    const auto& consistent_tensor_infer_result = call_ctx->consistent_tensor_infer_result();
    RETURN_IF_FOUND(consistent_tensor_infer_result->input_tensor_metas(),
                    consistent_tensor_infer_result->output_tensor_metas(),
                    .shared_from_symbol().get());
    return nullptr;
  }

  Optional<Symbol<ParallelDesc>> parallel_desc(eager::CallContext* call_ctx) const {
    const auto& consistent_tensor_infer_result = call_ctx->consistent_tensor_infer_result();
    if (!consistent_tensor_infer_result) { return Optional<Symbol<ParallelDesc>>(); }
    if (!consistent_tensor_infer_result->input_tensor_metas().empty()) {
      return consistent_tensor_infer_result->input_tensor_metas().at(0)->parallel_desc();
    } else if (!consistent_tensor_infer_result->output_tensor_metas().empty()) {
      return consistent_tensor_infer_result->output_tensor_metas().at(0)->parallel_desc();
    } else {
      UNIMPLEMENTED();
      return Optional<Symbol<ParallelDesc>>();
    }
  }

  const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const {
    const auto& parallel_desc = this->parallel_desc(call_ctx);
    if (parallel_desc.has_value()) {
      const auto& parallel_desc_symbol = CHECK_JUST(parallel_desc);
      return *CHECK_JUST(GetParallelContext4CurrentProcessCtx(parallel_desc_symbol));
    } else {
      static ParallelContext single_device_parallel_ctx(MakeSingleDeviceParallelCtx());
      return single_device_parallel_ctx;
    }
  }

  const ArgVec& inputs() const { return input_arg_tuple_->indexed_arg_name_and_index(); }
  const ArgVec& outputs() const { return output_arg_tuple_->indexed_arg_name_and_index(); }

 private:
  static int32_t TryGetTensorTupleIndex(const std::unordered_map<std::string, std::vector<int32_t>>&
                                            arg_name2bn_index2tensor_tuple_index,
                                        const std::string& arg_name, const int32_t arg_index) {
    auto it = arg_name2bn_index2tensor_tuple_index.find(arg_name);
    if (it != arg_name2bn_index2tensor_tuple_index.end()) { return it->second.at(arg_index); }
    return -1;
  }

  static ParallelContext MakeSingleDeviceParallelCtx() {
    ParallelContext single_device_parallel_ctx;
    single_device_parallel_ctx.set_parallel_id(0);
    single_device_parallel_ctx.set_parallel_num(1);
    return single_device_parallel_ctx;
  }

  std::shared_ptr<const ArgTuple> input_arg_tuple_;
  std::shared_ptr<const ArgTuple> output_arg_tuple_;
};
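// UserKernelBaseContextHelper extends the zero-copy helper above with the kernel's device type;
// job_desc() is intentionally unimplemented on this eager path.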
class UserKernelBaseContextHelper final : public ZeroCopyBaseContextHelper {
 public:
  UserKernelBaseContextHelper(DeviceType device_type,
                              const std::shared_ptr<const ArgTuple>& input_arg_tuple,
                              const std::shared_ptr<const ArgTuple>& output_arg_tuple)
      : ZeroCopyBaseContextHelper(input_arg_tuple, output_arg_tuple), device_type_(device_type) {}
  ~UserKernelBaseContextHelper() = default;

  DeviceType device_type() const { return device_type_; }
  const JobDesc& job_desc() const {
    UNIMPLEMENTED();
    return *(const JobDesc*)nullptr;
  }

 private:
  const DeviceType device_type_;
};
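// Stateless helper behind UserOpInferContext: every query (tensor desc, shape, stride, dtype,
// dynamic flag, SBP/NdSbp, parallel info, op conf and attrs) is answered from an explicitly
// passed eager::CallContext, so a single helper instance can serve many calls.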
class UserOpInferContextHelper final {
 public:
  UserOpInferContextHelper(const user_op::UserOpConfWrapper* user_op_conf,
                           const std::shared_ptr<const ArgTuple>& input_arg_tuple,
                           const std::shared_ptr<const ArgTuple>& output_arg_tuple)
      : user_op_conf_(user_op_conf),
        zero_copy_base_ctx_helper_(input_arg_tuple, output_arg_tuple) {}
  ~UserOpInferContextHelper() = default;

  const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
                                                               const std::string& arg_name,
                                                               int32_t index) const {
    UNIMPLEMENTED();
    return nullptr;
  }
  const user_op::TensorDesc& InputTensorDesc(eager::CallContext* call_ctx,
                                             const std::string& arg_name, int32_t index) const {
    return *CHECK_NOTNULL(TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index));
  }
  user_op::TensorDesc* OutputTensorDesc(eager::CallContext* call_ctx,
                                        const std::string& arg_name, int32_t index) const {
    return TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
                                                  const std::string& arg_name,
                                                  int32_t index) const {
    return zero_copy_base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  const Shape& InputShape(eager::CallContext* call_ctx, const std::string& arg_name,
                          int32_t index) const {
    return *Shape4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  Shape* OutputShape(eager::CallContext* call_ctx, const std::string& arg_name,
                     int32_t index) const {
    return Shape4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  Shape* Shape4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
                               int32_t index) const {
    return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_shape();
  }
  const Stride& InputStride(eager::CallContext* call_ctx, const std::string& arg_name,
                            int32_t index) const {
    return *Stride4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  Stride* OutputStride(eager::CallContext* call_ctx, const std::string& arg_name,
                       int32_t index) const {
    return Stride4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  Stride* Stride4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
                                 int32_t index) const {
    return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_stride();
  }
  const DataType& InputDType(eager::CallContext* call_ctx, const std::string& arg_name,
                             int32_t index) const {
    return *Dtype4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  DataType* OutputDType(eager::CallContext* call_ctx, const std::string& arg_name,
                        int32_t index) const {
    return Dtype4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  DataType* Dtype4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
                                  int32_t index) const {
    return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_data_type();
  }
  bool InputIsDynamic(eager::CallContext* call_ctx, const std::string& arg_name,
                      int32_t index) const {
    return *IsDynamic4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  bool* OutputIsDynamic(eager::CallContext* call_ctx, const std::string& arg_name,
                        int32_t index) const {
    return IsDynamic4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  bool* IsDynamic4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
                                  int32_t index) const {
    return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_is_dynamic();
  }

  const ArgVec& inputs() const { return zero_copy_base_ctx_helper_.inputs(); }
  const ArgVec& outputs() const { return zero_copy_base_ctx_helper_.outputs(); }
  const JobDesc* job_desc() const {
    UNIMPLEMENTED();
    return nullptr;
  }
  const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const {
    return zero_copy_base_ctx_helper_.parallel_ctx(call_ctx);
  }
  const ParallelDesc& parallel_desc(eager::CallContext* call_ctx) const {
    return *CHECK_JUST(zero_copy_base_ctx_helper_.parallel_desc(call_ctx));
  }
  const SbpParallel& SbpParallel4ArgNameAndIndex(eager::CallContext* call_ctx,
                                                 const std::string& arg_name,
                                                 int32_t index) const {
    const auto& nd_sbp = NdSbp4ArgNameAndIndex(call_ctx, arg_name, index);
    CHECK_EQ(nd_sbp.sbp_parallel_size(), 1);
    return nd_sbp.sbp_parallel(0);
  }
  const NdSbp& NdSbp4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
                                     int32_t index) const {
    return *CHECK_NOTNULL(zero_copy_base_ctx_helper_.ConsistentTensorMeta4ArgNameAndIndex(
                              call_ctx, arg_name, index))
                ->nd_sbp();
  }

  int64_t parallel_num(eager::CallContext* call_ctx) const {
    return parallel_ctx(call_ctx).parallel_num();
  }

  const std::string& input(const std::string& arg_name, int32_t index) const {
    return user_op_conf().input(arg_name, index);
  }
  const std::string& output(const std::string& arg_name, int32_t index) const {
    return user_op_conf().output(arg_name, index);
  }
  bool has_input(const std::string& arg_name, int32_t index) const {
    return user_op_conf().has_input(arg_name, index);
  }
  bool has_output(const std::string& arg_name, int32_t index) const {
    return user_op_conf().has_output(arg_name, index);
  }
  int32_t input_size(const std::string& arg_name) const {
    return user_op_conf().input_size(arg_name);
  }
  int32_t output_size(const std::string& arg_name) const {
    return user_op_conf().output_size(arg_name);
  }
  const std::string& op_name() const { return user_op_conf().op_name(); }
  const std::string& op_type_name() const { return user_op_conf().op_type_name(); }
  const std::string& op_loc() const { return user_op_conf_->op_conf().loc(); }

  const user_op::UserOpConfWrapper& user_op_conf() const { return *user_op_conf_; }
  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(eager::CallContext* call_ctx,
                                                           const std::string& attr_name) const {
    return call_ctx->composed_attrs().Attr4Name(attr_name);
  }

 private:
  user_op::TensorDesc* NonNullTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
                                                         const std::string& arg_name,
                                                         int32_t index) const {
    user_op::TensorDesc* tensor_desc = TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);
    if (!tensor_desc) { LOG(FATAL) << "Arg (" << arg_name << "," << index << ") is not found"; }
    return tensor_desc;
  }

  const user_op::UserOpConfWrapper* user_op_conf_;
  ZeroCopyBaseContextHelper zero_copy_base_ctx_helper_;
};
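// Thin user_op::InferContext implementation that binds a shared UserOpInferContextHelper to one
// concrete eager::CallContext and forwards every virtual call to it.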
class UserOpInferContext : public user_op::InferContext {
 public:
  UserOpInferContext(const UserOpInferContextHelper* helper, eager::CallContext* call_ctx)
      : helper_(helper), call_ctx_(call_ctx) {}
  ~UserOpInferContext() override = default;

  const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& arg_name,
                                                               int32_t index) const override {
    return helper_->LogicalTensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  const user_op::TensorDesc& InputTensorDesc(const std::string& arg_name,
                                             int32_t index) const override {
    return helper_->InputTensorDesc(call_ctx_, arg_name, index);
  }
  user_op::TensorDesc* OutputTensorDesc(const std::string& arg_name, int32_t index) override {
    return helper_->OutputTensorDesc(call_ctx_, arg_name, index);
  }
  user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) {
    return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  const Shape& InputShape(const std::string& arg_name, int32_t index) const override {
    return helper_->InputShape(call_ctx_, arg_name, index);
  }
  Shape* OutputShape(const std::string& arg_name, int32_t index) override {
    return helper_->OutputShape(call_ctx_, arg_name, index);
  }
  Shape* Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) override {
    return helper_->Shape4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  const Stride& InputStride(const std::string& arg_name, int32_t index) const override {
    return helper_->InputStride(call_ctx_, arg_name, index);
  }
  Stride* OutputStride(const std::string& arg_name, int32_t index) override {
    return helper_->OutputStride(call_ctx_, arg_name, index);
  }
  Stride* Stride4ArgNameAndIndex(const std::string& arg_name, int32_t index) override {
    return helper_->Stride4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  const DataType& InputDType(const std::string& arg_name, int32_t index) const override {
    return helper_->InputDType(call_ctx_, arg_name, index);
  }
  DataType* OutputDType(const std::string& arg_name, int32_t index) override {
    return helper_->OutputDType(call_ctx_, arg_name, index);
  }
  DataType* Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) override {
    return helper_->Dtype4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  bool InputIsDynamic(const std::string& arg_name, int32_t index) const override {
    return helper_->InputIsDynamic(call_ctx_, arg_name, index);
  }
  bool* OutputIsDynamic(const std::string& arg_name, int32_t index) override {
    return helper_->OutputIsDynamic(call_ctx_, arg_name, index);
  }
  bool* IsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) override {
    return helper_->IsDynamic4ArgNameAndIndex(call_ctx_, arg_name, index);
  }

  const ArgVec& inputs() const override { return helper_->inputs(); }
  const ArgVec& outputs() const override { return helper_->outputs(); }
  const JobDesc* job_desc() const override { return helper_->job_desc(); }
  const ParallelContext& parallel_ctx() const override { return helper_->parallel_ctx(call_ctx_); }
  const ParallelDesc& parallel_desc() const override { return helper_->parallel_desc(call_ctx_); }
  const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string& arg_name,
                                                 int32_t index) const override {
    return helper_->SbpParallel4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  const NdSbp& NdSbp4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {
    return helper_->NdSbp4ArgNameAndIndex(call_ctx_, arg_name, index);
  }

  int64_t parallel_num() const override { return helper_->parallel_num(call_ctx_); }

  const std::string& input(const std::string& arg_name, int32_t index) const override {
    return helper_->input(arg_name, index);
  }
  const std::string& output(const std::string& arg_name, int32_t index) const override {
    return helper_->output(arg_name, index);
  }
  bool has_input(const std::string& arg_name, int32_t index) const override {
    return helper_->has_input(arg_name, index);
  }
  bool has_output(const std::string& arg_name, int32_t index) const override {
    return helper_->has_output(arg_name, index);
  }
  int32_t input_size(const std::string& arg_name) const override {
    return helper_->input_size(arg_name);
  }
  int32_t output_size(const std::string& arg_name) const override {
    return helper_->output_size(arg_name);
  }
  const std::string& op_name() const override { return helper_->op_name(); }
  const std::string& op_type_name() const override { return helper_->op_type_name(); }
  const std::string& op_loc() const override { return helper_->op_loc(); }

 private:
  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(
      const std::string& attr_name) const override {
    return helper_->Attr4Name(call_ctx_, attr_name);
  }

  const UserOpInferContextHelper* helper_;
  eager::CallContext* call_ctx_;
};
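// Helper behind UserKernelComputeContext: resolves tensors and tensor descs through
// UserKernelBaseContextHelper and exposes the ep::Stream of the DeviceCtx supplied at compute
// time.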
class UserKernelComputeContextHelper final {
 public:
  UserKernelComputeContextHelper(DeviceType device_type,
                                 const user_op::UserOpConfWrapper* user_op_conf,
                                 const std::shared_ptr<const ArgTuple>& input_arg_tuple,
                                 const std::shared_ptr<const ArgTuple>& output_arg_tuple)
      : user_op_conf_(user_op_conf),
        base_ctx_helper_(device_type, input_arg_tuple, output_arg_tuple) {}
  ~UserKernelComputeContextHelper() = default;

  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
                                                        const std::string& arg_name,
                                                        int32_t index) const {
    return base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  user_op::Tensor* Tensor4ArgNameAndIndex(eager::CallContext* call_ctx,
                                          const std::string& arg_name, int32_t index) const {
    return base_ctx_helper_.Tensor4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  ep::Stream* stream(DeviceCtx* device_ctx) const {
    CHECK(device_ctx);
    return device_ctx->stream();
  }

  DeviceType device_type() const { return base_ctx_helper_.device_type(); }
  const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const {
    return base_ctx_helper_.parallel_ctx(call_ctx);
  }

  const ArgVec& inputs() const { return base_ctx_helper_.inputs(); }
  const ArgVec& outputs() const { return base_ctx_helper_.outputs(); }

  const user_op::UserOpConfWrapper& user_op_conf() const { return *user_op_conf_; }
  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(eager::CallContext* call_ctx,
                                                           const std::string& attr_name) const {
    return call_ctx->composed_attrs().Attr4Name(attr_name);
  }

 private:
  const user_op::UserOpConfWrapper* user_op_conf_;
  UserKernelBaseContextHelper base_ctx_helper_;
};
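// user_op::KernelComputeContext handed to kernels at Compute() time; it carries the per-call
// eager::CallContext plus the DeviceCtx whose stream the kernel runs on.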
class UserKernelComputeContext final : public user_op::KernelComputeContext {
 public:
  UserKernelComputeContext(const UserKernelComputeContextHelper* helper,
                           eager::CallContext* call_ctx, DeviceCtx* device_ctx)
      : helper_(helper), call_ctx_(call_ctx), device_ctx_(device_ctx) {}
  ~UserKernelComputeContext() = default;

  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,
                                                        int32_t index) const override {
    return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  user_op::Tensor* Tensor4ArgNameAndIndex(const std::string& arg_name, int32_t index) override {
    return helper_->Tensor4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  ep::Stream* stream() override { return helper_->stream(device_ctx_); }

  DeviceType device_type() const override { return helper_->device_type(); }
  const ParallelContext& parallel_ctx() const override { return helper_->parallel_ctx(call_ctx_); }

  const ArgVec& inputs() const override { return helper_->inputs(); }
  const ArgVec& outputs() const override { return helper_->outputs(); }

 private:
  const user_op::UserOpConfWrapper& user_op_conf() const override {
    return helper_->user_op_conf();
  }
  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(
      const std::string& attr_name) const override {
    return helper_->Attr4Name(call_ctx_, attr_name);
  }

  const UserKernelComputeContextHelper* helper_;
  eager::CallContext* call_ctx_;
  DeviceCtx* device_ctx_;
};
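// Helper behind UserKernelRegContext; answers the same device/parallel/tensor-desc queries from a
// passed-in call context, backed by UserKernelBaseContextHelper.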
class UserKernelRegContextHelper final {
 public:
  UserKernelRegContextHelper(DeviceType device_type,
                             const user_op::UserOpConfWrapper* user_op_conf,
                             const std::shared_ptr<const ArgTuple>& input_arg_tuple,
                             const std::shared_ptr<const ArgTuple>& output_arg_tuple)
      : user_op_conf_(user_op_conf),
        base_ctx_helper_(device_type, input_arg_tuple, output_arg_tuple) {}
  ~UserKernelRegContextHelper() = default;

  DeviceType device_type() const { return base_ctx_helper_.device_type(); }
  const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const {
    return base_ctx_helper_.parallel_ctx(call_ctx);
  }
  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
                                                        const std::string& arg_name,
                                                        int32_t index) const {
    return base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  const ArgVec& inputs() const { return base_ctx_helper_.inputs(); }
  const ArgVec& outputs() const { return base_ctx_helper_.outputs(); }

  const user_op::UserOpConfWrapper& user_op_conf() const { return *user_op_conf_; }
  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(eager::CallContext* call_ctx,
                                                           const std::string& attr_name) const {
    return call_ctx->composed_attrs().Attr4Name(attr_name);
  }

 private:
  const user_op::UserOpConfWrapper* user_op_conf_;
  UserKernelBaseContextHelper base_ctx_helper_;
};
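// user_op::KernelRegContext bound to a per-call eager::CallContext, used when matching a
// registered kernel against the current device and tensor descriptions.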
class UserKernelRegContext final : public user_op::KernelRegContext {
 public:
  UserKernelRegContext(const UserKernelRegContextHelper* helper, eager::CallContext* call_ctx)
      : helper_(helper), call_ctx_(call_ctx) {}
  ~UserKernelRegContext() = default;

  DeviceType device_type() const override { return helper_->device_type(); }
  const ParallelContext& parallel_ctx() const override { return helper_->parallel_ctx(call_ctx_); }
  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,
                                                        int32_t index) const override {
    return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  const ArgVec& inputs() const override { return helper_->inputs(); }
  const ArgVec& outputs() const override { return helper_->outputs(); }
  const user_op::UserOpConfWrapper& user_op_conf() const override {
    return helper_->user_op_conf();
  }

 private:
  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(
      const std::string& attr_name) const override {
    return helper_->Attr4Name(call_ctx_, attr_name);
  }

  const UserKernelRegContextHelper* helper_;
  eager::CallContext* call_ctx_;
};
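// Helper shared by kernel state/cache creation: on top of the compute-time queries it exposes
// logical tensor metas, SBP/NdSbp and the ParallelDesc taken from the consistent tensor infer
// result of the call context.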
class UserKernelInitAndCacheContextHelper final {
 public:
  UserKernelInitAndCacheContextHelper(DeviceType device_type,
                                      const user_op::UserOpConfWrapper* user_op_conf,
                                      const std::shared_ptr<const ArgTuple>& input_arg_tuple,
                                      const std::shared_ptr<const ArgTuple>& output_arg_tuple)
      : user_op_conf_(user_op_conf),
        base_ctx_helper_(device_type, input_arg_tuple, output_arg_tuple) {}
  ~UserKernelInitAndCacheContextHelper() = default;

  ep::Stream* stream(DeviceCtx* device_ctx) const {
    CHECK(device_ctx);
    return device_ctx->stream();
  }

  DeviceType device_type() const { return base_ctx_helper_.device_type(); }
  const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const {
    return base_ctx_helper_.parallel_ctx(call_ctx);
  }
  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
                                                        const std::string& arg_name,
                                                        int32_t index) const {
    return base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
                                                               const std::string& arg_name,
                                                               int32_t index) const {
    return base_ctx_helper_.ConsistentTensorMeta4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  const SbpParallel& SbpParallel4ArgNameAndIndex(eager::CallContext* call_ctx,
                                                 const std::string& arg_name, int32_t index) const {
    const auto& nd_sbp = NdSbp4ArgNameAndIndex(call_ctx, arg_name, index);
    CHECK_EQ(nd_sbp.sbp_parallel_size(), 1);
    return nd_sbp.sbp_parallel(0);
  }
  const NdSbp& NdSbp4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
                                     int32_t index) const {
    return *CHECK_NOTNULL(
                base_ctx_helper_.ConsistentTensorMeta4ArgNameAndIndex(call_ctx, arg_name, index))
                ->nd_sbp();
  }
  const ArgVec& inputs() const { return base_ctx_helper_.inputs(); }
  const ArgVec& outputs() const { return base_ctx_helper_.outputs(); }
  const ParallelDesc& parallel_desc(eager::CallContext* call_ctx) const {
    return *CHECK_JUST(base_ctx_helper_.parallel_desc(call_ctx));
  }
  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(eager::CallContext* call_ctx,
                                                           const std::string& attr_name) const {
    return call_ctx->composed_attrs().Attr4Name(attr_name);
  }
  const user_op::UserOpConfWrapper& user_op_conf() const { return *user_op_conf_; }

 private:
  const user_op::UserOpConfWrapper* user_op_conf_;
  UserKernelBaseContextHelper base_ctx_helper_;
};
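// Implements both user_op::KernelInitContext and user_op::KernelCacheContext, so a single object
// serves kernel-state creation and kernel-cache creation for one call.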
class UserKernelInitAndCacheContext final : public user_op::KernelInitContext,
                                            public user_op::KernelCacheContext {
 public:
  UserKernelInitAndCacheContext(const UserKernelInitAndCacheContextHelper* helper,
                                eager::CallContext* call_ctx, DeviceCtx* device_ctx)
      : helper_(helper), call_ctx_(call_ctx), device_ctx_(device_ctx) {}
  ~UserKernelInitAndCacheContext() override = default;

  ep::Stream* stream() override { return helper_->stream(device_ctx_); }

  DeviceType device_type() const override { return helper_->device_type(); }
  const ParallelContext& parallel_ctx() const override { return helper_->parallel_ctx(call_ctx_); }
  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,
                                                        int32_t index) const override {
    return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& arg_name,
                                                               int32_t index) const override {
    return helper_->LogicalTensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string& arg_name,
                                                 int32_t index) const override {
    return helper_->SbpParallel4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  const NdSbp& NdSbp4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {
    return helper_->NdSbp4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  const ArgVec& inputs() const override { return helper_->inputs(); }
  const ArgVec& outputs() const override { return helper_->outputs(); }
  const ParallelDesc& parallel_desc() const override { return helper_->parallel_desc(call_ctx_); }

 private:
  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(
      const std::string& attr_name) const override {
    return helper_->Attr4Name(call_ctx_, attr_name);
  }
  const user_op::UserOpConfWrapper& user_op_conf() const override {
    return helper_->user_op_conf();
  }

  const UserKernelInitAndCacheContextHelper* helper_;
  eager::CallContext* call_ctx_;
  DeviceCtx* device_ctx_;
};
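// InitTensorTupleIndexes4Bns classifies the op's blob names once at construction time: input
// indexes are split into const vs. mutable ibns, and output indexes into header-inferred (mut)
// vs. mut2 obns, according to the op's registered arg-modify functions.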
namespace {

Maybe<void> InitTensorTupleIndexes4Bns(const std::shared_ptr<const OperatorConf>& op_conf,
                                       const ArgVec& indexed_input_pairs,
                                       const ArgVec& indexed_output_pairs,
                                       std::vector<int64_t>* input_tuple_indexes4const_ibns,
                                       std::vector<int64_t>* input_tuple_indexes4mut_ibns,
                                       std::vector<int64_t>* output_tuple_indexes4mut_obns,
                                       std::vector<int64_t>* output_tuple_indexes4mut2_obns) {
  const auto* op_reg_val =
      user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_conf->user_conf().op_type_name());
  CHECK_NOTNULL_OR_RETURN(op_reg_val);

  ArgModifierSignature arg_modifier_signature;
  for (const auto& pair : indexed_input_pairs) {
    const std::string ibn = GenRepeatedBn(pair.first, pair.second);
    arg_modifier_signature.mutable_ibn2input_blob_modifier()->insert(
        {ibn, user_op::InputArgModifier()});
  }
  for (const auto& pair : indexed_output_pairs) {
    const std::string obn = GenRepeatedBn(pair.first, pair.second);
    arg_modifier_signature.mutable_obn2output_blob_modifier()->insert(
        {obn, user_op::OutputArgModifier()});
  }
  user_op::UserOpConfWrapper op_conf_wrapper(op_conf);

  if (op_reg_val->input_arg_modify_fn) {
    user_op::GetInputArgModifier GetInputArgModifierFn =
        [&arg_modifier_signature](const std::string& in_arg_name,
                                  int32_t in_arg_index) -> user_op::InputArgModifier* {
      const std::string ibn = GenRepeatedBn(in_arg_name, in_arg_index);
      auto* map = arg_modifier_signature.mutable_ibn2input_blob_modifier();
      return &map->at(ibn);
    };
    JUST(op_reg_val->input_arg_modify_fn(GetInputArgModifierFn, op_conf_wrapper));
  }

  if (op_reg_val->output_arg_modify_fn) {
    user_op::GetOutputArgModifier GetOutputArgModifierFn =
        [&arg_modifier_signature](const std::string& in_arg_name,
                                  int32_t in_arg_index) -> user_op::OutputArgModifier* {
      const std::string obn = GenRepeatedBn(in_arg_name, in_arg_index);
      auto* map = arg_modifier_signature.mutable_obn2output_blob_modifier();
      return &map->at(obn);
    };
    JUST(op_reg_val->output_arg_modify_fn(GetOutputArgModifierFn, op_conf_wrapper));
  }

  for (int i = 0; i < indexed_input_pairs.size(); i++) {
    const auto& pair = indexed_input_pairs.at(i);
    const std::string ibn = GenRepeatedBn(pair.first, pair.second);
    if (arg_modifier_signature.ibn2input_blob_modifier().at(ibn).is_mutable()) {
      input_tuple_indexes4mut_ibns->emplace_back(i);
    } else {
      input_tuple_indexes4const_ibns->emplace_back(i);
    }
  }

  for (int i = 0; i < indexed_output_pairs.size(); i++) {
    const auto& pair = indexed_output_pairs.at(i);
    const std::string obn = GenRepeatedBn(pair.first, pair.second);
    if (arg_modifier_signature.obn2output_blob_modifier().at(obn).header_infered_before_compute()) {
      output_tuple_indexes4mut_obns->emplace_back(i);
    } else {
      output_tuple_indexes4mut2_obns->emplace_back(i);
    }
  }
  return Maybe<void>::Ok();
}

}  // namespace
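// StatefulOpKernel::New builds one kernel object per eager op: it caches the user op conf, the
// stream, the arg tuples and the reusable context helpers, then records which inputs/outputs are
// const/mutable via InitTensorTupleIndexes4Bns.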
/* static */ Maybe<StatefulOpKernel> StatefulOpKernel::New(
    const std::shared_ptr<OperatorConf>& op_conf, const Symbol<Stream>& stream,
    const AttrMap& base_attrs, const std::shared_ptr<const ParallelDesc>& parallel_desc,
    const std::shared_ptr<const ArgTuple>& input_arg_tuple,
    const std::shared_ptr<const ArgTuple>& output_arg_tuple) {
  auto opkernel = std::shared_ptr<StatefulOpKernel>(new StatefulOpKernel());
  opkernel->base_attrs_ = base_attrs;
  opkernel->op_conf_ = op_conf;
  opkernel->user_op_conf_.reset(new user_op::UserOpConfWrapper(op_conf));
  opkernel->stream_ = stream;
  opkernel->input_arg_tuple_ = input_arg_tuple;
  opkernel->output_arg_tuple_ = output_arg_tuple;
  opkernel->need_check_mem_case_ = true;

  const DeviceType device_type = CHECK_JUST(DeviceType4DeviceTag(op_conf->device_tag()));
  const user_op::UserOpConfWrapper* user_op_conf = opkernel->user_op_conf_.get();
  opkernel->op_infer_ctx_helper_.reset(
      new UserOpInferContextHelper(user_op_conf, input_arg_tuple, output_arg_tuple));

  opkernel->init_and_cache_ctx_helper_.reset(new UserKernelInitAndCacheContextHelper(
      device_type, opkernel->user_op_conf_.get(), opkernel->input_arg_tuple_,
      opkernel->output_arg_tuple_));
  opkernel->compute_ctx_helper_.reset(new UserKernelComputeContextHelper(
      device_type, user_op_conf, input_arg_tuple, output_arg_tuple));
  opkernel->reg_ctx_helper_.reset(
      new UserKernelRegContextHelper(device_type, user_op_conf, input_arg_tuple, output_arg_tuple));
  const auto* op_reg_val =
      user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(user_op_conf->op_type_name());
  CHECK_NOTNULL_OR_RETURN(op_reg_val);
  if (op_reg_val->logical_tensor_desc_infer_fn) {
    opkernel->tensor_desc_infer_fn_ = op_reg_val->logical_tensor_desc_infer_fn;
  } else {
    return Error::UnimplementedError();
  }
  opkernel->data_type_infer_fn_ = op_reg_val->data_type_infer_fn;

  JUST(InitTensorTupleIndexes4Bns(
      op_conf, input_arg_tuple->indexed_arg_name_and_index(),
      output_arg_tuple->indexed_arg_name_and_index(), &opkernel->input_tuple_indexes4const_ibns_,
      &opkernel->input_tuple_indexes4mut_ibns_, &opkernel->output_tuple_indexes4mut_obns_,
      &opkernel->output_tuple_indexes4mut2_obns_));

  return opkernel;
}
StatefulOpKernel
::~
StatefulOpKernel
()
=
default
;
StatefulOpKernel
::~
StatefulOpKernel
()
=
default
;
size_t
StatefulOpKernel
::
InferTmpSize
(
eager
::
CallContext
*
call_ctx
,
size_t
StatefulOpKernel
::
InferTmpSize
(
eager
::
CallContext
*
call_ctx
,
const
user_op
::
OpKernel
*
user_opkernel
)
const
{
const
user_op
::
OpKernel
*
user_opkernel
)
const
{
UserOpInferContext
op_infer_ctx
(
op_infer_ctx_helper_
.
get
(),
call_ctx
);
UserOpInferContext
op_infer_ctx
(
op_infer_ctx_helper_
.
get
(),
call_ctx
);
const
auto
&
InferTmpSizeFn
=
GetInferTmpSizeFn
(
user_opkernel
);
const
auto
&
InferTmpSizeFn
=
GetInferTmpSizeFn
(
user_opkernel
);
return
InferTmpSizeFn
(
&
op_infer_ctx
);
return
InferTmpSizeFn
(
&
op_infer_ctx
);
}
}
Maybe
<
void
>
StatefulOpKernel
::
ChooseOpKernel
(
eager
::
CallContext
*
call_ctx
,
Maybe
<
void
>
StatefulOpKernel
::
ChooseOpKernel
(
eager
::
CallContext
*
call_ctx
,
const
user_op
::
OpKernel
**
user_opkernel
,
const
user_op
::
OpKernel
**
user_opkernel
,
bool
*
need_temp_storage
)
{
bool
*
need_temp_storage
)
{
OF_PROFILER_RANGE_GUARD
(
"ChooseOpKernel"
);
OF_PROFILER_RANGE_GUARD
(
"ChooseOpKernel"
);
DataType
primary_dtype
=
kInvalidDataType
;
DataType
primary_dtype
=
kInvalidDataType
;
const
auto
&
inputs
=
call_ctx
->
inputs
();
const
auto
&
inputs
=
call_ctx
->
inputs
();
const
auto
&
outputs
=
call_ctx
->
outputs
();
const
auto
&
outputs
=
call_ctx
->
outputs
();
if
(
likely
(
!
inputs
->
empty
()))
{
if
(
likely
(
!
inputs
->
empty
()))
{
primary_dtype
=
(
*
inputs
)[
0
]
->
data_type
();
primary_dtype
=
(
*
inputs
)[
0
]
->
data_type
();
}
else
if
(
likely
(
!
outputs
->
empty
()))
{
}
else
if
(
likely
(
!
outputs
->
empty
()))
{
primary_dtype
=
(
*
outputs
)[
0
]
->
data_type
();
primary_dtype
=
(
*
outputs
)[
0
]
->
data_type
();
}
else
{
}
else
{
// do nothing
// do nothing
}
}
UserKernelRegContext
reg_ctx
(
reg_ctx_helper_
.
get
(),
call_ctx
);
UserKernelRegContext
reg_ctx
(
reg_ctx_helper_
.
get
(),
call_ctx
);
for
(
const
auto
&
pair
:
dtype2cached_kernels_
[
primary_dtype
])
{
for
(
const
auto
&
pair
:
dtype2cached_kernels_
[
primary_dtype
])
{
if
(
likely
(
pair
.
first
->
is_matched_hob
->
get
(
reg_ctx
)))
{
if
(
likely
(
pair
.
first
->
is_matched_hob
->
get
(
reg_ctx
)))
{
*
need_temp_storage
=
pair
.
first
->
need_temp_storage
;
*
need_temp_storage
=
pair
.
first
->
need_temp_storage
;
*
user_opkernel
=
pair
.
second
.
get
();
*
user_opkernel
=
pair
.
second
.
get
();
return
Maybe
<
void
>::
Ok
();
return
Maybe
<
void
>::
Ok
();
}
}
}
}
OF_PROFILER_RANGE_GUARD
(
"fallback"
);
OF_PROFILER_RANGE_GUARD
(
"fallback"
);
const
auto
&
op_type_name
=
user_op_conf_
->
op_type_name
();
const
auto
&
op_type_name
=
user_op_conf_
->
op_type_name
();
const
auto
*
kernel_reg_val
=
const
auto
*
kernel_reg_val
=
JUST
(
user_op
::
UserOpRegistryMgr
::
Get
().
GetOpKernelRegistryResult
(
op_type_name
,
reg_ctx
));
JUST
(
user_op
::
UserOpRegistryMgr
::
Get
().
GetOpKernelRegistryResult
(
op_type_name
,
reg_ctx
));
CHECK_NOTNULL
(
kernel_reg_val
);
CHECK_NOTNULL
(
kernel_reg_val
);
auto
*
kernel
=
kernel_reg_val
->
create_fn
();
auto
*
kernel
=
kernel_reg_val
->
create_fn
();
dtype2cached_kernels_
[
primary_dtype
].
push_back
(
dtype2cached_kernels_
[
primary_dtype
].
push_back
(
{
kernel_reg_val
,
std
::
shared_ptr
<
const
user_op
::
OpKernel
>
(
kernel
)});
{
kernel_reg_val
,
std
::
shared_ptr
<
const
user_op
::
OpKernel
>
(
kernel
)});
infer_tmp_size_fn_map_
.
emplace
(
kernel
,
&
kernel_reg_val
->
infer_tmp_size_fn
);
infer_tmp_size_fn_map_
.
emplace
(
kernel
,
&
kernel_reg_val
->
infer_tmp_size_fn
);
*
need_temp_storage
=
kernel_reg_val
->
need_temp_storage
;
*
need_temp_storage
=
kernel_reg_val
->
need_temp_storage
;
*
user_opkernel
=
kernel
;
*
user_opkernel
=
kernel
;
return
Maybe
<
void
>::
Ok
();
return
Maybe
<
void
>::
Ok
();
}
}
void
StatefulOpKernel
::
TryInitOpKernelStateAndCache
(
eager
::
CallContext
*
call_ctx
,
void
StatefulOpKernel
::
TryInitOpKernelStateAndCache
(
eager
::
CallContext
*
call_ctx
,
DeviceCtx
*
device_ctx
,
DeviceCtx
*
device_ctx
,
const
user_op
::
OpKernel
*
op_kernel
,
const
user_op
::
OpKernel
*
op_kernel
,
user_op
::
OpKernelState
**
state
,
user_op
::
OpKernelState
**
state
,
user_op
::
OpKernelCache
**
cache
)
{
user_op
::
OpKernelCache
**
cache
)
{
UserKernelInitAndCacheContext
init_and_cache_ctx
(
init_and_cache_ctx_helper_
.
get
(),
call_ctx
,
UserKernelInitAndCacheContext
init_and_cache_ctx
(
init_and_cache_ctx_helper_
.
get
(),
call_ctx
,
device_ctx
);
device_ctx
);
if
(
state
!=
nullptr
)
{
if
(
state
!=
nullptr
)
{
auto
it
=
op_kernel_state_map_
.
find
(
op_kernel
);
auto
it
=
op_kernel_state_map_
.
find
(
op_kernel
);
if
(
it
!=
op_kernel_state_map_
.
end
())
{
if
(
it
!=
op_kernel_state_map_
.
end
())
{
*
state
=
it
->
second
.
get
();
*
state
=
it
->
second
.
get
();
}
else
{
}
else
{
auto
created_state
=
op_kernel
->
CreateOpKernelState
(
&
init_and_cache_ctx
);
auto
created_state
=
op_kernel
->
CreateOpKernelState
(
&
init_and_cache_ctx
);
op_kernel_state_map_
.
emplace
(
op_kernel
,
created_state
);
op_kernel_state_map_
.
emplace
(
op_kernel
,
created_state
);
*
state
=
created_state
.
get
();
*
state
=
created_state
.
get
();
}
}
}
}
{
{
auto
&
cache_in_map
=
op_kernel_cache_map_
[
op_kernel
];
auto
&
cache_in_map
=
op_kernel_cache_map_
[
op_kernel
];
op_kernel
->
InitOpKernelCacheWithFlags
(
&
init_and_cache_ctx
,
op_kernel
->
InitOpKernelCacheWithFlags
(
&
init_and_cache_ctx
,
user_op
::
OpKernelCache
::
kAllMayChanged
,
&
cache_in_map
);
user_op
::
OpKernelCache
::
kAllMayChanged
,
&
cache_in_map
);
*
cache
=
cache_in_map
.
get
();
*
cache
=
cache_in_map
.
get
();
}
}
}
}
const
user_op
::
InferTmpSizeFn
&
StatefulOpKernel
::
GetInferTmpSizeFn
(
const
user_op
::
InferTmpSizeFn
&
StatefulOpKernel
::
GetInferTmpSizeFn
(
const
user_op
::
OpKernel
*
op_kernel
)
const
{
const
user_op
::
OpKernel
*
op_kernel
)
const
{
return
*
infer_tmp_size_fn_map_
.
at
(
op_kernel
);
return
*
infer_tmp_size_fn_map_
.
at
(
op_kernel
);
}
}
user_op
::
TensorDescInferFn
StatefulOpKernel
::
TensorDescInferFn
()
const
{
user_op
::
TensorDescInferFn
StatefulOpKernel
::
TensorDescInferFn
()
const
{
return
tensor_desc_infer_fn_
;
return
tensor_desc_infer_fn_
;
}
}
user_op
::
DataTypeInferFn
StatefulOpKernel
::
DataTypeInferFn
()
const
{
return
data_type_infer_fn_
;
}
user_op
::
DataTypeInferFn
StatefulOpKernel
::
DataTypeInferFn
()
const
{
return
data_type_infer_fn_
;
}
void StatefulOpKernel::Compute(eager::CallContext* call_ctx, DeviceCtx* device_ctx,
                               const user_op::OpKernel* user_opkernel,
                               user_op::OpKernelState* state,
                               const user_op::OpKernelCache* cache) const {
  UserKernelComputeContext compute_context(compute_ctx_helper_.get(), call_ctx, device_ctx);
  auto* compute_ctx = &compute_context;
  OF_PROFILER_RANGE_GUARD("Compute");
  if (Singleton<profiler::ProfileManager>::Get()) {
#if defined(WITH_CUDA) || defined(WITH_ROCM)
    const auto CalMemorySize = [compute_ctx](const one::ArgVec& args) -> int64_t {
      const auto Func = [compute_ctx](int64_t mem_size, const auto& pair) {
        const auto tensor = compute_ctx->Tensor4ArgNameAndIndex(pair.first, pair.second);
        return mem_size + tensor->shape_view().elem_cnt() * GetSizeOfDataType(tensor->data_type());
      };
      return std::accumulate(args.begin(), args.end(), static_cast<int64_t>(0), Func);
    };
#endif
    auto er_guard = CHECK_JUST(profiler::EventRecorder::CreateKernelEventRecorder(
        op_type_name(),
#if defined(WITH_CUDA) || defined(WITH_ROCM)
        [compute_ctx, CalMemorySize]() -> int64_t {
          return CalMemorySize(compute_ctx->inputs()) + CalMemorySize(compute_ctx->outputs());
        },
#endif
        [compute_ctx]() -> std::vector<Shape> {
          std::vector<Shape> shapes;
          for (const auto& pair : compute_ctx->inputs()) {
            shapes.emplace_back(
                compute_ctx->TensorDesc4ArgNameAndIndex(pair.first, pair.second)->shape());
          }
          return shapes;
        }));
    user_opkernel->Compute(compute_ctx, state, cache);
  } else {
    user_opkernel->Compute(compute_ctx, state, cache);
  }
}

}  // namespace one
}  // namespace oneflow
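Note on the profiling path above: when a profiler::ProfileManager singleton is active and the build has CUDA or ROCm enabled, every kernel launch is wrapped in a kernel event recorder that also records the total number of bytes touched, computed as the sum of elem_cnt * sizeof(dtype) over all input and output tensors (the CalMemorySize lambda). A minimal Python sketch of that arithmetic, using the conv2d input shapes exercised by the LeNet profiler test below; the output shape (2, 6, 28, 28) is an assumption for illustration (5x5 valid convolution over a 32x32 input), not taken from the commit:

import numpy as np

# Input shapes expected by the profiler test for conv2d: (2,3,32,32) and (6,3,5,5).
# The output shape (2,6,28,28) is assumed here for illustration.
shapes = [(2, 3, 32, 32), (6, 3, 5, 5), (2, 6, 28, 28)]
dtype_size = 4  # float32, mirroring GetSizeOfDataType

# Same reduction as the CalMemorySize lambda: accumulate elem_cnt * dtype size
# over every input and output argument of the kernel.
total_bytes = sum(int(np.prod(s)) * dtype_size for s in shapes)
print(total_bytes)  # 64008 bytes touched by a single conv2d launch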
python/oneflow/test/modules/fused_dot_feature_interaction.py (new file, mode 100644)

import numpy as np
import oneflow as flow


def fused_dot_feature_interaction(
    x,
    y,
    self_interaction=False,
    output_padding=0,
    output_concat=None,
    dtype=flow.float32,
):
    # (bs, es) = x.shape
    (bs, dims, es) = y.shape
    if self_interaction:
        offset = 1
    else:
        offset = 0
    li = flow.tensor([i for i in range(dims + 1) for j in range(i + offset)])
    lj = flow.tensor([j for i in range(dims + 1) for j in range(i + offset)])
    T = flow.cat(
        [
            flow.reshape(x, (bs, 1, es)),
            y,
        ],
        dim=1,
    )
    Z = flow.matmul(T, T, transpose_b=True)
    # gather_nd does not support half, so cast to float32
    Z = flow.cast(Z, flow.float32)
    Zflat = Z[:, li, lj]
    Zflat = flow.cast(Zflat, dtype)
    if output_concat is not None:
        R = flow.cat([output_concat, Zflat], dim=1)
    else:
        R = Zflat
    if output_padding != 0:
        padding_tensor = flow.tensor(
            np.zeros((bs, output_padding)).astype(np.float32),
            device="cuda",
            requires_grad=False,
        )
        R = flow.cat([R, padding_tensor], dim=1)
    return R
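The helper above is a plain-oneflow reference implementation of the fused dot feature interaction. A minimal usage sketch; the batch size, embedding size, and number of embedding rows below are illustrative values, not taken from the commit:

import numpy as np
import oneflow as flow

# Illustrative shapes: bs samples, es features per row, dims extra embedding rows.
bs, dims, es = 2, 3, 16
x = flow.tensor(np.random.randn(bs, es).astype(np.float32))        # dense features, (bs, es)
y = flow.tensor(np.random.randn(bs, dims, es).astype(np.float32))  # embeddings, (bs, dims, es)

# Pairwise dot products of the dims + 1 feature rows, flattened per sample and
# concatenated with the dense features.
out = fused_dot_feature_interaction(x, y, self_interaction=False, output_concat=x)
print(out.shape)  # (bs, es + dims * (dims + 1) // 2) == (2, 22)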
python/oneflow/test/profiler/test_profile_lenet.py

"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import unittest

import oneflow.unittest
import oneflow as flow
import oneflow.nn as nn
import oneflow.nn.functional as F
import oneflow.profiler
from oneflow.profiler.events import CustomEvent, KernelEvent


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out


def get_event(events, name: str, input_shapes: str = "-"):
    for item in events:
        if isinstance(item, CustomEvent):
            if item.name == name:
                return item
        if isinstance(item, KernelEvent):
            if item.name == name and item.input_shapes == input_shapes:
                return item
    return None


def _test_lenet(
    test_case,
    on_cuda: bool,
    record_shapes: bool,
    record_bandwidth_for_cuda: bool = False,
):
    x = flow.randn(2, 3, 32, 32)
    lenet = LeNet()
    if on_cuda:
        x = x.to("cuda")
        lenet.to("cuda")
    activities = [oneflow.profiler.ProfilerActivity.CPU]
    if on_cuda:
        activities.append(oneflow.profiler.ProfilerActivity.CUDA)
    with oneflow.profiler.profile(
        activities=activities,
        record_shapes=record_shapes,
        record_bandwidth_for_cuda=record_bandwidth_for_cuda,
    ) as prof:
        with oneflow.profiler.record_function("lenet_forward_total_time") as f:
            for _ in range(2):
                eager_res = lenet(x)
        with oneflow.profiler.record_function("lenet_backward_total_time") as f:
            eager_res.sum().backward()
    events = prof.key_averages(group_by_input_shape=True)
    print(events)

    conv_event = get_event(
        events, "conv2d", "[(2,3,32,32), (6,3,5,5)]" if record_shapes else "-"
    )
    test_case.assertIsNotNone(conv_event)
    if on_cuda:
        test_case.assertGreater(conv_event.cpu_time, 0.0)
        test_case.assertGreater(conv_event.cpu_time_total, 0.0)
        test_case.assertGreater(conv_event.cuda_time, 0.0)
        test_case.assertGreater(conv_event.cuda_time_total, 0.0)
    else:
        test_case.assertGreater(conv_event.cpu_time, 0.0)
        test_case.assertGreater(conv_event.cpu_time_total, 0.0)
    test_case.assertEqual(conv_event.count, 2 if record_shapes else 4)
    if record_bandwidth_for_cuda and on_cuda:
        test_case.assertNotEqual(conv_event.bandwidth, -1)

    relu_grad_event = get_event(
        events, "relu_grad", "[(2,6,28,28), (2,6,28,28)]" if record_shapes else "-"
    )
    test_case.assertIsNotNone(relu_grad_event)
    if on_cuda:
        test_case.assertGreater(relu_grad_event.cpu_time, 0.0)
        test_case.assertGreater(relu_grad_event.cpu_time_total, 0.0)
        test_case.assertGreater(relu_grad_event.cuda_time, 0.0)
        test_case.assertGreater(relu_grad_event.cuda_time_total, 0.0)
    else:
        test_case.assertGreater(relu_grad_event.cpu_time, 0.0)
        test_case.assertGreater(relu_grad_event.cpu_time_total, 0.0)
    test_case.assertEqual(relu_grad_event.count, 1 if record_shapes else 4)
    if record_bandwidth_for_cuda and on_cuda:
        test_case.assertNotEqual(relu_grad_event.bandwidth, -1)

    test_case.assertIsNotNone(get_event(events, "lenet_forward_total_time"))
    test_case.assertIsNotNone(get_event(events, "lenet_backward_total_time"))


class TestProfileLenet(flow.unittest.TestCase):
    def test_lenet_cpu(test_case):
        _test_lenet(test_case, on_cuda=False, record_shapes=True)
        _test_lenet(test_case, on_cuda=False, record_shapes=False)

    @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
    def test_lenet_cuda(test_case):
        _test_lenet(
            test_case, on_cuda=True, record_shapes=True, record_bandwidth_for_cuda=False
        )
        _test_lenet(
            test_case,
            on_cuda=True,
            record_shapes=False,
            record_bandwidth_for_cuda=False,
        )
        _test_lenet(
            test_case, on_cuda=True, record_shapes=True, record_bandwidth_for_cuda=True
        )
        _test_lenet(
            test_case, on_cuda=True, record_shapes=False, record_bandwidth_for_cuda=True
        )


if __name__ == "__main__":
    unittest.main()