Commit f262efc9 authored by yuguo's avatar yuguo
Browse files

Support profiler for DCU, support debug compilation

parent 3f56062c
...@@ -265,9 +265,9 @@ set(ROBIN_HOOD_HASHING_URL ...@@ -265,9 +265,9 @@ set(ROBIN_HOOD_HASHING_URL
use_mirror(VARIABLE ROBIN_HOOD_HASHING_URL URL ${ROBIN_HOOD_HASHING_URL}) use_mirror(VARIABLE ROBIN_HOOD_HASHING_URL URL ${ROBIN_HOOD_HASHING_URL})
set(ROBIN_HOOD_HASHING_MD5 a78bd30a7582f25984f8592652836467) set(ROBIN_HOOD_HASHING_MD5 a78bd30a7582f25984f8592652836467)
set(FMT_URL https://github.com/fmtlib/fmt/archive/48b7e3dafb27ece02cd6addc8bd1041c79d59c2c.zip) set(FMT_URL https://github.com/fmtlib/fmt/archive/fc07217d85e6dcec52878807d6bbd89a9d9156a5.zip)
use_mirror(VARIABLE FMT_URL URL ${FMT_URL}) use_mirror(VARIABLE FMT_URL URL ${FMT_URL})
set(FMT_MD5 45925a979ed7195e0c88a70be691de09) set(FMT_MD5 7d9bb2ececc9ede29cd35bdc42a7e22c)
set(KINETO_URL set(KINETO_URL
https://github.com/pytorch/kineto/archive/ff8dba20499a660650632952be76450bd70a52a6.zip) https://github.com/pytorch/kineto/archive/ff8dba20499a660650632952be76450bd70a52a6.zip)
......
...@@ -175,6 +175,8 @@ if (BUILD_ROCM) ...@@ -175,6 +175,8 @@ if (BUILD_ROCM)
add_definitions(-D__HIP_PLATFORM_HCC__) add_definitions(-D__HIP_PLATFORM_HCC__)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__HIP_PLATFORM_HCC__ --gpu-max-threads-per-block=1024") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__HIP_PLATFORM_HCC__ --gpu-max-threads-per-block=1024")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -D__HIP_PLATFORM_HCC__ --gpu-max-threads-per-block=1024") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -D__HIP_PLATFORM_HCC__ --gpu-max-threads-per-block=1024")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -mcmodel=large")
set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -mcmodel=large")
list(APPEND oneflow_third_party_libs hip::device) list(APPEND oneflow_third_party_libs hip::device)
list(APPEND oneflow_third_party_libs roc::hipblas) list(APPEND oneflow_third_party_libs roc::hipblas)
list(APPEND oneflow_third_party_libs hip::hipcub) list(APPEND oneflow_third_party_libs hip::hipcub)
......
/* /*
Copyright 2020 The OneFlow Authors. All rights reserved. Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
// #include "fmt/core.h" #include "fmt/core.h"
// #include "fmt/format.h" #include "fmt/format.h"
#include "oneflow/core/profiler/event.h" #include "oneflow/core/profiler/event.h"
#include "oneflow/core/profiler/util.h" #include "oneflow/core/profiler/util.h"
using json = nlohmann::json; using json = nlohmann::json;
namespace oneflow { namespace oneflow {
namespace profiler { namespace profiler {
nlohmann::json IEvent::ToJson() { nlohmann::json IEvent::ToJson() {
return json{{"name", name_}, {"time", GetDuration<double>()}, {"input_shapes", "-"}}; return json{{"name", name_}, {"time", GetDuration<double>()}, {"input_shapes", "-"}};
} }
void IEvent::SetStartedAt(double t) { started_at_ = t; } void IEvent::SetStartedAt(double t) { started_at_ = t; }
void IEvent::SetFinishedAt(double t) { finished_at_ = t; } void IEvent::SetFinishedAt(double t) { finished_at_ = t; }
void IEvent::Start() { SetStartedAt(GetTimeNow()); } void IEvent::Start() { SetStartedAt(GetTimeNow()); }
void IEvent::Finish() { SetFinishedAt(GetTimeNow()); } void IEvent::Finish() { SetFinishedAt(GetTimeNow()); }
bool IEvent::IsChildOf(const IEvent* e) { bool IEvent::IsChildOf(const IEvent* e) {
if (!e) { return false; } if (!e) { return false; }
if (this == e) { return false; } if (this == e) { return false; }
return GetStartedAt<double>() >= e->GetStartedAt<double>() return GetStartedAt<double>() >= e->GetStartedAt<double>()
&& GetFinishedAt<double>() <= e->GetFinishedAt<double>(); && GetFinishedAt<double>() <= e->GetFinishedAt<double>();
} }
const std::string& IEvent::GetName() const { return name_; } const std::string& IEvent::GetName() const { return name_; }
std::string CustomEvent::Key() { return name_; } std::string CustomEvent::Key() { return name_; }
nlohmann::json CustomEvent::ToJson() { nlohmann::json CustomEvent::ToJson() {
auto j = IEvent::ToJson(); auto j = IEvent::ToJson();
j["type"] = EventType::kCustom; j["type"] = EventType::kCustom;
j["custom_type"] = type_; j["custom_type"] = type_;
return j; return j;
} }
std::shared_ptr<CustomEvent> CustomEvent::Create(const std::string& name, CustomEventType type) { std::shared_ptr<CustomEvent> CustomEvent::Create(const std::string& name, CustomEventType type) {
return std::shared_ptr<CustomEvent>(new CustomEvent(name, type)); return std::shared_ptr<CustomEvent>(new CustomEvent(name, type));
} }
// std::string KernelEvent::Key() { return fmt::format("{}.{}", name_, GetFormatedInputShapes()); } std::string KernelEvent::Key() { return fmt::format("{}.{}", name_, GetFormatedInputShapes()); }
std::string KernelEvent::Key() { return "yuguo"; }
nlohmann::json KernelEvent::ToJson() {
nlohmann::json KernelEvent::ToJson() { auto j = IEvent::ToJson();
auto j = IEvent::ToJson(); j["type"] = EventType::kOneflowKernel;
j["type"] = EventType::kOneflowKernel; j["input_shapes"] = GetFormatedInputShapes();
j["input_shapes"] = GetFormatedInputShapes(); #if defined(WITH_CUDA) || defined(WITH_ROCM)
#if defined(WITH_CUDA) j["memory_size"] = memory_size_;
j["memory_size"] = memory_size_; if (!children_.empty()) { j["children"] = children_; }
if (!children_.empty()) { j["children"] = children_; } #endif // WITH_CUDA
#endif // WITH_CUDA return j;
return j; }
}
std::shared_ptr<KernelEvent> KernelEvent::Create(
std::shared_ptr<KernelEvent> KernelEvent::Create( const std::string& name, const std::function<std::vector<Shape>(void)>& shape_getter) {
const std::string& name, const std::function<std::vector<ShapeView>(void)>& shape_getter) { return std::shared_ptr<KernelEvent>(new KernelEvent(name, shape_getter));
return std::shared_ptr<KernelEvent>(new KernelEvent(name, shape_getter)); }
}
std::string KernelEvent::GetFormatedInputShapes(size_t max_num_to_format) {
void KernelEvent::RecordShape(const ShapeView& shape) { input_shapes_.emplace_back(shape); } if (input_shapes_.size() == 0) { return "-"; }
std::vector<std::string> shapes_formated(std::min(input_shapes_.size(), max_num_to_format));
std::string KernelEvent::GetFormatedInputShapes(size_t max_num_to_format) { for (auto i = 0; i < shapes_formated.size(); ++i) {
if (input_shapes_.size() == 0) { return "-"; } const std::string current_shape = input_shapes_[i].ToString();
std::vector<std::string> shapes_formated(std::min(input_shapes_.size(), max_num_to_format)); shapes_formated[i] = current_shape == "()" ? "scalar" : current_shape;
for (auto i = 0; i < shapes_formated.size(); ++i) { }
const std::string current_shape = input_shapes_[i].ToString(); if (input_shapes_.size() > max_num_to_format) { shapes_formated.emplace_back("..."); }
shapes_formated[i] = current_shape == "()" ? "scalar" : current_shape; return fmt::format("[{}]", fmt::join(shapes_formated, ", "));
} }
if (input_shapes_.size() > max_num_to_format) { shapes_formated.emplace_back("..."); }
// return fmt::format("[{}]", fmt::join(shapes_formated, ", ")); } // namespace profiler
return "yuguo";
}
} // namespace profiler
} // namespace oneflow } // namespace oneflow
\ No newline at end of file
/* /*
Copyright 2020 The OneFlow Authors. All rights reserved. Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
#ifndef ONEFLOW_CORE_PROFILER_EVENT_H_ #ifndef ONEFLOW_CORE_PROFILER_EVENT_H_
#define ONEFLOW_CORE_PROFILER_EVENT_H_ #define ONEFLOW_CORE_PROFILER_EVENT_H_
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "nlohmann/json.hpp" #include "nlohmann/json.hpp"
#include "oneflow/core/common/util.h" #include "oneflow/core/common/util.h"
#include "oneflow/core/common/shape_view.h" #include "oneflow/core/common/shape_view.h"
namespace oneflow { namespace oneflow {
namespace profiler { namespace profiler {
class ProfileManager; class ProfileManager;
enum class EventType { enum class EventType {
kCustom, // has three kinds kCustom, // has three kinds
kOneflowKernel // OneFlow cpu/cuda kernel kOneflowKernel // OneFlow cpu/cuda kernel
}; };
enum class CustomEventType { enum class CustomEventType {
kDefault, // for record_function kDefault, // for record_function
kCudaKernel, // cuda kernel kCudaKernel, // cuda kernel
kCudaRuntime // something like cudaLaunchKernel kCudaRuntime // something like cudaLaunchKernel
}; };
enum class EventTimeUnit { kNS, kUS }; enum class EventTimeUnit { kNS, kUS };
class IEvent { class IEvent {
public: public:
OF_DISALLOW_COPY_AND_MOVE(IEvent); OF_DISALLOW_COPY_AND_MOVE(IEvent);
IEvent() = delete; IEvent() = delete;
IEvent(const std::string& name, EventTimeUnit time_unit) : name_(name), time_unit_(time_unit) {} IEvent(const std::string& name, EventTimeUnit time_unit) : name_(name), time_unit_(time_unit) {}
virtual std::string Key() = 0; virtual std::string Key() = 0;
virtual nlohmann::json ToJson(); virtual nlohmann::json ToJson();
virtual ~IEvent() = default; virtual ~IEvent() = default;
virtual void Start(); virtual void Start();
virtual void Finish(); virtual void Finish();
bool IsChildOf(const IEvent* e); bool IsChildOf(const IEvent* e);
const std::string& GetName() const; const std::string& GetName() const;
template<typename T> template<typename T>
const T GetDuration(EventTimeUnit time_unit = EventTimeUnit::kUS) const; const T GetDuration(EventTimeUnit time_unit = EventTimeUnit::kUS) const;
template<typename T> template<typename T>
const T GetStartedAt(EventTimeUnit time_unit = EventTimeUnit::kUS) const; const T GetStartedAt(EventTimeUnit time_unit = EventTimeUnit::kUS) const;
template<typename T> template<typename T>
const T GetFinishedAt(EventTimeUnit time_unit = EventTimeUnit::kUS) const; const T GetFinishedAt(EventTimeUnit time_unit = EventTimeUnit::kUS) const;
protected: protected:
virtual void SetStartedAt(double t); virtual void SetStartedAt(double t);
virtual void SetFinishedAt(double t); virtual void SetFinishedAt(double t);
std::string name_; std::string name_;
EventTimeUnit time_unit_; EventTimeUnit time_unit_;
double started_at_ = 0; double started_at_ = 0;
double finished_at_ = 0; double finished_at_ = 0;
}; };
inline double ConvertTime(double time_, EventTimeUnit src_time_unit, EventTimeUnit dst_time_unit) { inline double ConvertTime(double time_, EventTimeUnit src_time_unit, EventTimeUnit dst_time_unit) {
if (src_time_unit == EventTimeUnit::kNS && dst_time_unit == EventTimeUnit::kUS) { if (src_time_unit == EventTimeUnit::kNS && dst_time_unit == EventTimeUnit::kUS) {
return time_ / 1000; return time_ / 1000;
} }
if (src_time_unit == EventTimeUnit::kUS && dst_time_unit == EventTimeUnit::kNS) { if (src_time_unit == EventTimeUnit::kUS && dst_time_unit == EventTimeUnit::kNS) {
return time_ * 1000; return time_ * 1000;
} }
return time_; return time_;
} }
template<> template<>
const inline double IEvent::GetStartedAt<double>(EventTimeUnit time_unit) const { const inline double IEvent::GetStartedAt<double>(EventTimeUnit time_unit) const {
return ConvertTime(started_at_, time_unit_, time_unit); return ConvertTime(started_at_, time_unit_, time_unit);
} }
template<> template<>
const inline time_t IEvent::GetStartedAt<time_t>(EventTimeUnit time_unit) const { const inline time_t IEvent::GetStartedAt<time_t>(EventTimeUnit time_unit) const {
return static_cast<time_t>(GetStartedAt<double>(time_unit)); return static_cast<time_t>(GetStartedAt<double>(time_unit));
} }
template<> template<>
const inline double IEvent::GetFinishedAt<double>(EventTimeUnit time_unit) const { const inline double IEvent::GetFinishedAt<double>(EventTimeUnit time_unit) const {
return ConvertTime(finished_at_, time_unit_, time_unit); return ConvertTime(finished_at_, time_unit_, time_unit);
} }
template<> template<>
const inline time_t IEvent::GetFinishedAt<time_t>(EventTimeUnit time_unit) const { const inline time_t IEvent::GetFinishedAt<time_t>(EventTimeUnit time_unit) const {
return static_cast<time_t>(GetFinishedAt<double>(time_unit)); return static_cast<time_t>(GetFinishedAt<double>(time_unit));
} }
template<> template<>
const inline double IEvent::GetDuration<double>(EventTimeUnit time_unit) const { const inline double IEvent::GetDuration<double>(EventTimeUnit time_unit) const {
return GetFinishedAt<double>(time_unit) - GetStartedAt<double>(time_unit); return GetFinishedAt<double>(time_unit) - GetStartedAt<double>(time_unit);
} }
template<> template<>
const inline time_t IEvent::GetDuration<time_t>(EventTimeUnit time_unit) const { const inline time_t IEvent::GetDuration<time_t>(EventTimeUnit time_unit) const {
return static_cast<time_t>(GetDuration<double>(time_unit)); return static_cast<time_t>(GetDuration<double>(time_unit));
} }
class CustomEvent final : public IEvent { class CustomEvent final : public IEvent {
public: public:
friend class ProfileManager; friend class ProfileManager;
std::string Key() override; std::string Key() override;
nlohmann::json ToJson() override; nlohmann::json ToJson() override;
static std::shared_ptr<CustomEvent> Create(const std::string& name, static std::shared_ptr<CustomEvent> Create(const std::string& name,
CustomEventType type = CustomEventType::kDefault); CustomEventType type = CustomEventType::kDefault);
private: private:
CustomEventType type_; CustomEventType type_;
CustomEvent(const std::string& custom_name, CustomEventType type) CustomEvent(const std::string& custom_name, CustomEventType type)
: IEvent(custom_name, : IEvent(custom_name,
type == CustomEventType::kDefault ? EventTimeUnit::kNS : EventTimeUnit::kUS), type == CustomEventType::kDefault ? EventTimeUnit::kNS : EventTimeUnit::kUS),
type_(type) {} type_(type) {}
}; };
class KernelEvent final : public IEvent { class KernelEvent final : public IEvent {
public: public:
std::string Key() override; std::string Key() override;
nlohmann::json ToJson() override; nlohmann::json ToJson() override;
static std::shared_ptr<KernelEvent> Create( static std::shared_ptr<KernelEvent> Create(
const std::string& name, const std::function<std::vector<ShapeView>(void)>& shape_getter); const std::string& name, const std::function<std::vector<Shape>(void)>& shape_getter);
void RecordShape(const ShapeView& shape); #if defined(WITH_CUDA) || defined(WITH_ROCM)
void SetMemorySize(int64_t memory_size) { memory_size_ = memory_size; }
#if defined(WITH_CUDA) void AddChildEvent(const std::shared_ptr<IEvent>& e) { children_.emplace(e); }
void SetMemorySize(int64_t memory_size) { memory_size_ = memory_size; } bool AddChildEventIfSo(const std::shared_ptr<IEvent>& e) {
void AddChildEvent(const std::shared_ptr<IEvent>& e) { children_.emplace(e); } if (e->IsChildOf(dynamic_cast<IEvent*>(this))) {
bool AddChildEventIfSo(const std::shared_ptr<IEvent>& e) { children_.emplace(e);
if (e->IsChildOf(dynamic_cast<IEvent*>(this))) { return true;
children_.emplace(e); }
return true; return false;
} }
return false; bool HasChildEvent(const std::shared_ptr<IEvent>& e) { return children_.count(e); }
} void WalkAmongChildren(const std::function<void(const std::shared_ptr<IEvent>& e)>& f) const {
bool HasChildEvent(const std::shared_ptr<IEvent>& e) { return children_.count(e); } for (const auto& x : children_) { f(x); }
void WalkAmongChildren(const std::function<void(const std::shared_ptr<IEvent>& e)>& f) const { }
for (const auto& x : children_) { f(x); } #endif // WITH_CUDA
}
#endif // WITH_CUDA private:
KernelEvent(const std::string& kernel_name,
private: const std::function<std::vector<Shape>(void)>& shape_getter)
KernelEvent(const std::string& kernel_name, : IEvent(kernel_name, EventTimeUnit::kNS) {
const std::function<std::vector<ShapeView>(void)>& shape_getter) if (shape_getter) { input_shapes_ = shape_getter(); }
: IEvent(kernel_name, EventTimeUnit::kNS) { }
if (shape_getter) { input_shapes_ = shape_getter(); }
} #if defined(WITH_CUDA) || defined(WITH_ROCM)
int64_t memory_size_ = -1;
#if defined(WITH_CUDA) std::set<std::shared_ptr<IEvent>> children_;
int64_t memory_size_ = -1; #endif // WITH_CUDA
std::set<std::shared_ptr<IEvent>> children_;
#endif // WITH_CUDA std::vector<Shape> input_shapes_;
std::string GetFormatedInputShapes(size_t max_num_to_format = 4);
std::vector<ShapeView> input_shapes_; };
std::string GetFormatedInputShapes(size_t max_num_to_format = 4);
}; } // namespace profiler
} // namespace oneflow
} // namespace profiler
} // namespace oneflow namespace nlohmann {
namespace nlohmann { inline void to_json(json& j, const std::shared_ptr<::oneflow::profiler::IEvent>& event) {
j = event->ToJson();
inline void to_json(json& j, const std::shared_ptr<::oneflow::profiler::IEvent>& event) { }
j = event->ToJson();
} } // namespace nlohmann
} // namespace nlohmann #endif // ONEFLOW_CORE_PROFILER_EVENT_H_
#endif // ONEFLOW_CORE_PROFILER_EVENT_H_
...@@ -32,13 +32,13 @@ std::shared_ptr<EventRecorder> EventRecorder::CreateCustomEventRecorder(const st ...@@ -32,13 +32,13 @@ std::shared_ptr<EventRecorder> EventRecorder::CreateCustomEventRecorder(const st
Maybe<EventRecorder> EventRecorder::CreateKernelEventRecorder( Maybe<EventRecorder> EventRecorder::CreateKernelEventRecorder(
const std::string& name, const std::string& name,
#if defined(WITH_CUDA) #if defined(WITH_CUDA) || defined(WITH_ROCM)
const std::function<int64_t()>& memory_size_getter, const std::function<int64_t()>& memory_size_getter,
#endif #endif
const ShapeGetterFuncType& shape_getter) { const ShapeGetterFuncType& shape_getter) {
auto pmgr = Singleton<ProfileManager>::Get(); auto pmgr = Singleton<ProfileManager>::Get();
if (pmgr) { if (pmgr) {
#if defined(WITH_CUDA) #if defined(WITH_CUDA) || defined(WITH_ROCM)
if (pmgr->use_cpu_ || pmgr->use_cuda_) { if (pmgr->use_cpu_ || pmgr->use_cuda_) {
auto event = KernelEvent::Create(name, pmgr->record_shapes_ ? shape_getter : nullptr); auto event = KernelEvent::Create(name, pmgr->record_shapes_ ? shape_getter : nullptr);
if (pmgr->use_cuda_) { if (pmgr->use_cuda_) {
......
/* /*
Copyright 2020 The OneFlow Authors. All rights reserved. Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
#ifndef ONEFLOW_CORE_PROFILER_EVENT_RECORDER_H_ #ifndef ONEFLOW_CORE_PROFILER_EVENT_RECORDER_H_
#define ONEFLOW_CORE_PROFILER_EVENT_RECORDER_H_ #define ONEFLOW_CORE_PROFILER_EVENT_RECORDER_H_
#include "oneflow/core/common/util.h" #include "oneflow/core/common/util.h"
#include "oneflow/core/profiler/event.h" #include "oneflow/core/profiler/event.h"
namespace oneflow { namespace oneflow {
namespace profiler { namespace profiler {
class EventRecorder { class EventRecorder {
public: public:
using ShapeGetterFuncType = std::function<std::vector<ShapeView>(void)>; using ShapeGetterFuncType = std::function<std::vector<Shape>(void)>;
OF_DISALLOW_COPY_AND_MOVE(EventRecorder); OF_DISALLOW_COPY_AND_MOVE(EventRecorder);
explicit EventRecorder(const std::shared_ptr<IEvent>& event) : event_(event) { explicit EventRecorder(const std::shared_ptr<IEvent>& event) : event_(event) {
CHECK_JUST(RegisterEventToProfileManager(event)); CHECK_JUST(RegisterEventToProfileManager(event));
event_->Start(); event_->Start();
} }
Maybe<void> RegisterEventToProfileManager(const std::shared_ptr<IEvent>& event); Maybe<void> RegisterEventToProfileManager(const std::shared_ptr<IEvent>& event);
~EventRecorder() { ~EventRecorder() {
if (event_) { if (event_) {
event_->Finish(); event_->Finish();
event_.reset(); event_.reset();
} }
} }
static std::shared_ptr<EventRecorder> CreateCustomEventRecorder(const std::string& name); static std::shared_ptr<EventRecorder> CreateCustomEventRecorder(const std::string& name);
static Maybe<EventRecorder> CreateKernelEventRecorder( static Maybe<EventRecorder> CreateKernelEventRecorder(
const std::string& name, const std::string& name,
#if defined(WITH_CUDA) #if defined(WITH_CUDA) || defined(WITH_ROCM)
const std::function<int64_t()>& memory_size_getter, const std::function<int64_t()>& memory_size_getter,
#endif #endif
const ShapeGetterFuncType& shape_getter); const ShapeGetterFuncType& shape_getter);
private: private:
std::shared_ptr<IEvent> event_; std::shared_ptr<IEvent> event_;
}; };
} // namespace profiler } // namespace profiler
} // namespace oneflow } // namespace oneflow
#endif // ONEFLOW_CORE_PROFILER_EVENT_RECORDER_H_ #endif // ONEFLOW_CORE_PROFILER_EVENT_RECORDER_H_
...@@ -17,7 +17,11 @@ limitations under the License. ...@@ -17,7 +17,11 @@ limitations under the License.
#include "oneflow/core/profiler/kernel.h" #include "oneflow/core/profiler/kernel.h"
#include "oneflow/core/profiler/profiler.h" #include "oneflow/core/profiler/profiler.h"
#include "oneflow/core/kernel/kernel.h" #include "oneflow/core/kernel/kernel.h"
#ifdef WITH_ROCM
#include "oneflow/core/ep/rocm/cuda_stream.h"
#else
#include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/ep/cuda/cuda_stream.h"
#endif
#include "oneflow/core/lazy/actor/actor_context.h" #include "oneflow/core/lazy/actor/actor_context.h"
namespace oneflow { namespace oneflow {
...@@ -43,6 +47,11 @@ thread_local cudaEvent_t cuda_memory_bandwidth_profile_start_event = nullptr; ...@@ -43,6 +47,11 @@ thread_local cudaEvent_t cuda_memory_bandwidth_profile_start_event = nullptr;
thread_local cudaEvent_t cuda_memory_bandwidth_profile_end_event = nullptr; thread_local cudaEvent_t cuda_memory_bandwidth_profile_end_event = nullptr;
#endif // WITH_CUDA #endif // WITH_CUDA
#if defined(WITH_ROCM)
thread_local hipEvent_t cuda_memory_bandwidth_profile_start_event = nullptr;
thread_local hipEvent_t cuda_memory_bandwidth_profile_end_event = nullptr;
#endif // WITH_ROCM
} // namespace } // namespace
void TraceKernelForwardDataContentStart(KernelContext* kernel_ctx, const Kernel* kernel) { void TraceKernelForwardDataContentStart(KernelContext* kernel_ctx, const Kernel* kernel) {
...@@ -61,6 +70,22 @@ void TraceKernelForwardDataContentStart(KernelContext* kernel_ctx, const Kernel* ...@@ -61,6 +70,22 @@ void TraceKernelForwardDataContentStart(KernelContext* kernel_ctx, const Kernel*
} }
if (profile_kernel_forward_range) { OF_PROFILER_RANGE_PUSH(kernel->op_conf().name()); } if (profile_kernel_forward_range) { OF_PROFILER_RANGE_PUSH(kernel->op_conf().name()); }
#endif // WITH_CUDA #endif // WITH_CUDA
#if defined(WITH_ROCM)
if (profile_cuda_memory_bandwidth) {
auto* actor_context_provider = dynamic_cast<ActorContextProvider*>(kernel_ctx);
auto* cuda_stream = dynamic_cast<ep::CudaStream*>(kernel_ctx->stream());
if (cuda_stream != nullptr && actor_context_provider != nullptr) {
CHECK(cuda_memory_bandwidth_profile_start_event == nullptr);
CHECK(cuda_memory_bandwidth_profile_end_event == nullptr);
OF_CUDA_CHECK(hipEventCreate(&cuda_memory_bandwidth_profile_start_event));
OF_CUDA_CHECK(hipEventCreate(&cuda_memory_bandwidth_profile_end_event));
OF_CUDA_CHECK(
hipEventRecord(cuda_memory_bandwidth_profile_start_event, cuda_stream->cuda_stream()));
}
}
if (profile_kernel_forward_range) { OF_PROFILER_RANGE_PUSH(kernel->op_conf().name()); }
#endif // WITH_ROCM
} }
void TraceKernelForwardDataContentEnd(KernelContext* kernel_ctx, const Kernel* kernel) { void TraceKernelForwardDataContentEnd(KernelContext* kernel_ctx, const Kernel* kernel) {
...@@ -103,6 +128,45 @@ void TraceKernelForwardDataContentEnd(KernelContext* kernel_ctx, const Kernel* k ...@@ -103,6 +128,45 @@ void TraceKernelForwardDataContentEnd(KernelContext* kernel_ctx, const Kernel* k
} }
} }
#endif // WITH_CUDA #endif // WITH_CUDA
#if defined(WITH_ROCM)
if (profile_kernel_forward_range) { OF_PROFILER_RANGE_POP(); }
// The memory bandwidth profiler only works in lazy mode.
if (profile_cuda_memory_bandwidth) {
auto* cuda_stream = dynamic_cast<ep::CudaStream*>(kernel_ctx->stream());
auto* actor_context_provider = dynamic_cast<ActorContextProvider*>(kernel_ctx);
if (cuda_stream != nullptr && actor_context_provider != nullptr) {
hipEvent_t start_event = cuda_memory_bandwidth_profile_start_event;
hipEvent_t end_event = cuda_memory_bandwidth_profile_end_event;
cuda_memory_bandwidth_profile_start_event = nullptr;
cuda_memory_bandwidth_profile_end_event = nullptr;
CHECK_NOTNULL(start_event);
CHECK_NOTNULL(end_event);
OF_CUDA_CHECK(hipEventRecord(end_event, cuda_stream->cuda_stream()));
int64_t memory_size = 0;
for (const auto& bn : kernel->op_attribute().input_bns()) {
const Blob* blob = kernel_ctx->BnInOp2Blob(bn);
if (blob) { memory_size += blob->ByteSizeOfBlobBody(); }
}
for (const auto& bn : kernel->op_attribute().output_bns()) {
const Blob* blob = kernel_ctx->BnInOp2Blob(bn);
if (blob) { memory_size += blob->ByteSizeOfBlobBody(); }
}
const std::string op_name = kernel->op_conf().name();
actor_context_provider->GetActorContext()->AddCallback(
[start_event, end_event, memory_size, op_name]() {
float elapsed_ms = 0;
OF_CUDA_CHECK(hipEventElapsedTime(&elapsed_ms, start_event, end_event));
OF_CUDA_CHECK(hipEventDestroy(start_event));
OF_CUDA_CHECK(hipEventDestroy(end_event));
double bandwidth =
static_cast<double>(memory_size) / (1024.0 * 1024.0 * 1024.0) / (elapsed_ms / 1000);
LOG(INFO) << "PROFILER::KERNEL::CUDA_MEMORY_BANDWIDTH op_name: " << op_name
<< " elapsed(ms): " << elapsed_ms << " memory_size(Byte): " << memory_size
<< " bandwidth(GB/s): " << bandwidth;
});
}
}
#endif // WITH_ROCM
} }
} // namespace profiler } // namespace profiler
......
...@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and ...@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
#if defined(WITH_CUDA) #if defined(WITH_CUDA) || defined(WITH_ROCM)
#include "oneflow/core/profiler/kineto_shim.h" #include "oneflow/core/profiler/kineto_shim.h"
#include "libkineto.h" #include "libkineto.h"
......
...@@ -16,7 +16,7 @@ limitations under the License. ...@@ -16,7 +16,7 @@ limitations under the License.
#ifndef ONEFLOW_CORE_PROFILER_KINETO_SHIM_H_ #ifndef ONEFLOW_CORE_PROFILER_KINETO_SHIM_H_
#define ONEFLOW_CORE_PROFILER_KINETO_SHIM_H_ #define ONEFLOW_CORE_PROFILER_KINETO_SHIM_H_
#if defined(WITH_CUDA) #if defined(WITH_CUDA) || defined(WITH_ROCM)
#include <string> #include <string>
#include <memory> #include <memory>
......
...@@ -15,12 +15,12 @@ limitations under the License. ...@@ -15,12 +15,12 @@ limitations under the License.
*/ */
#include <memory> #include <memory>
#include <unordered_map> #include <unordered_map>
// #include "fmt/core.h" #include "fmt/core.h"
#include "nlohmann/json.hpp" #include "nlohmann/json.hpp"
#include "oneflow/core/profiler/kineto_shim.h" #include "oneflow/core/profiler/kineto_shim.h"
#include "oneflow/core/profiler/profile_manager.h" #include "oneflow/core/profiler/profile_manager.h"
#include "oneflow/core/profiler/event.h" #include "oneflow/core/profiler/event.h"
#if defined(WITH_CUDA) #if defined(WITH_CUDA) || defined(WITH_ROCM)
#include <libkineto.h> #include <libkineto.h>
#endif // WITH_CUDA #endif // WITH_CUDA
...@@ -48,7 +48,7 @@ std::string ProfileManager::DumpResultsJson() { ...@@ -48,7 +48,7 @@ std::string ProfileManager::DumpResultsJson() {
} }
std::vector<std::shared_ptr<IEvent>> ProfileManager::ExportEvents() { std::vector<std::shared_ptr<IEvent>> ProfileManager::ExportEvents() {
#if defined(WITH_CUDA) #if defined(WITH_CUDA) || defined(WITH_ROCM)
auto trace = StopTrace(); auto trace = StopTrace();
const auto& kineto_events = *(trace.get()->activities()); const auto& kineto_events = *(trace.get()->activities());
std::set<std::shared_ptr<IEvent>> custom_events; std::set<std::shared_ptr<IEvent>> custom_events;
...@@ -77,7 +77,7 @@ std::vector<std::shared_ptr<IEvent>> ProfileManager::ExportEvents() { ...@@ -77,7 +77,7 @@ std::vector<std::shared_ptr<IEvent>> ProfileManager::ExportEvents() {
while (!events_.empty()) { while (!events_.empty()) {
auto evt = events_.front(); auto evt = events_.front();
events_.pop(); events_.pop();
#if defined(WITH_CUDA) #if defined(WITH_CUDA) || defined(WITH_ROCM)
auto evt_kernel = std::dynamic_pointer_cast<KernelEvent>(evt); auto evt_kernel = std::dynamic_pointer_cast<KernelEvent>(evt);
if (evt_kernel) { if (evt_kernel) {
std::set<int64_t> current_corr_ids; std::set<int64_t> current_corr_ids;
...@@ -106,8 +106,7 @@ std::string ProfileManager::GetNextEventRecorderKey(const std::string& name) { ...@@ -106,8 +106,7 @@ std::string ProfileManager::GetNextEventRecorderKey(const std::string& name) {
} else { } else {
event_recorders_last_id_[name]++; event_recorders_last_id_[name]++;
} }
// return fmt::format("{}.{}", name, event_recorders_last_id_[name]); return fmt::format("{}.{}", name, event_recorders_last_id_[name]);
return "yuguo";
} }
} // namespace profiler } // namespace profiler
......
...@@ -37,7 +37,7 @@ class ProfileManager { ...@@ -37,7 +37,7 @@ class ProfileManager {
use_cuda_(use_cuda), use_cuda_(use_cuda),
record_shapes_(record_shapes), record_shapes_(record_shapes),
record_bandwidth_(record_bandwidth) { record_bandwidth_(record_bandwidth) {
#if defined(WITH_CUDA) #if defined(WITH_CUDA) || defined(WITH_ROCM)
std::set<ActivityType> activities{}; std::set<ActivityType> activities{};
if (use_cpu) { activities.insert(ActivityType::CPU); } if (use_cpu) { activities.insert(ActivityType::CPU); }
if (use_cuda) { activities.insert(ActivityType::CUDA); } if (use_cuda) { activities.insert(ActivityType::CUDA); }
......
...@@ -20,11 +20,20 @@ limitations under the License. ...@@ -20,11 +20,20 @@ limitations under the License.
#include "oneflow/core/profiler/event_recorder.h" #include "oneflow/core/profiler/event_recorder.h"
#include "oneflow/core/vm/vm_util.h" #include "oneflow/core/vm/vm_util.h"
#ifdef OF_ENABLE_PROFILER #ifdef OF_ENABLE_PROFILER
#ifdef WITH_ROCM
#include <hip/hip_runtime.h>
#include <hip/hip_profile.h>
#include <roctracer_roctx.h>
#include <sys/syscall.h>
#include <iostream>
#include "oneflow/core/device/cuda_util.h"
#else
#include <nvtx3/nvToolsExt.h> #include <nvtx3/nvToolsExt.h>
#include <sys/syscall.h> #include <sys/syscall.h>
#include <iostream> #include <iostream>
#include <cuda_profiler_api.h> #include <cuda_profiler_api.h>
#include "oneflow/core/device/cuda_util.h" #include "oneflow/core/device/cuda_util.h"
#endif
#endif // OF_ENABLE_PROFILER #endif // OF_ENABLE_PROFILER
namespace oneflow { namespace oneflow {
...@@ -33,6 +42,16 @@ namespace profiler { ...@@ -33,6 +42,16 @@ namespace profiler {
void NameThisHostThread(const std::string& name) {
#ifdef OF_ENABLE_PROFILER
  // The optional prefix is read from the environment once per thread and cached.
  static thread_local std::unique_ptr<std::string> thread_name_prefix;
  if (!thread_name_prefix) {
    thread_name_prefix.reset(
        new std::string(GetStringFromEnv("ONEFLOW_PROFILER_HOST_THREAD_NAME_PREFIX", "")));
  }
  const std::string name_with_prefix = *thread_name_prefix + name;
#ifdef WITH_ROCM
  // NOTE(review): roctracer has no direct equivalent of nvtxNameOsThreadA, so this
  // emits a marker carrying the thread name instead of renaming the OS thread —
  // confirm this is the intended trade-off for DCU traces.
  // nvtxNameOsThreadA(syscall(SYS_gettid), name_with_prefix.c_str());
  roctxMarkA(name_with_prefix.c_str());
#else
  nvtxNameOsThreadA(syscall(SYS_gettid), name_with_prefix.c_str());
#endif
#endif  // OF_ENABLE_PROFILER
}
// Open a named profiler range on this thread (paired with RangePop).
// No-op unless the build enables OF_ENABLE_PROFILER.
void RangePush(const std::string& name) {
#ifdef OF_ENABLE_PROFILER
#ifdef WITH_ROCM
  roctxRangePushA(name.c_str());
#else
  nvtxRangePushA(name.c_str());
#endif
#endif  // OF_ENABLE_PROFILER
}
// Close the most recently pushed profiler range on this thread.
void RangePop() {
#ifdef OF_ENABLE_PROFILER
#ifdef WITH_ROCM
  roctxRangePop();
#else
  nvtxRangePop();
#endif
#endif  // OF_ENABLE_PROFILER
}
...@@ -82,13 +110,21 @@ void LogHostMemoryUsage(const std::string& name) { ...@@ -82,13 +110,21 @@ void LogHostMemoryUsage(const std::string& name) {
// Begin driver-level profiling collection (hip/cuda ProfilerStart).
// NOTE(review): hipProfilerStart is deprecated in recent ROCm releases — verify it
// still takes effect on the targeted DCU toolchain.
void ProfilerStart() {
#ifdef OF_ENABLE_PROFILER
#ifdef WITH_ROCM
  OF_CUDA_CHECK(hipProfilerStart());
#else
  OF_CUDA_CHECK(cudaProfilerStart());
#endif
#endif  // OF_ENABLE_PROFILER
}
// Stop driver-level profiling collection (hip/cuda ProfilerStop).
void ProfilerStop() {
#ifdef OF_ENABLE_PROFILER
#ifdef WITH_ROCM
  OF_CUDA_CHECK(hipProfilerStop());
#else
  OF_CUDA_CHECK(cudaProfilerStop());
#endif
#endif  // OF_ENABLE_PROFILER
}
...@@ -105,6 +141,9 @@ Maybe<std::string> DisableProfilerAndReturnResult() { ...@@ -105,6 +141,9 @@ Maybe<std::string> DisableProfilerAndReturnResult() {
#if defined(WITH_CUDA) #if defined(WITH_CUDA)
OF_CUDA_CHECK(cudaDeviceSynchronize()); OF_CUDA_CHECK(cudaDeviceSynchronize());
#endif // WITH_CUDA #endif // WITH_CUDA
#if defined(WITH_ROCM)
OF_CUDA_CHECK(hipDeviceSynchronize());
#endif // WITH_ROCM
auto* pmgr = JUST(SingletonMaybe<ProfileManager>()); auto* pmgr = JUST(SingletonMaybe<ProfileManager>());
std::string results = pmgr->DumpResultsJson(); std::string results = pmgr->DumpResultsJson();
Singleton<ProfileManager>::Delete(); Singleton<ProfileManager>::Delete();
......
/* /*
Copyright 2020 The OneFlow Authors. All rights reserved. Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
#ifndef ONEFLOW_USER_KERNELS_MATH_UNARY_ELEMENTWISE_FUNC_H_ #ifndef ONEFLOW_USER_KERNELS_MATH_UNARY_ELEMENTWISE_FUNC_H_
#define ONEFLOW_USER_KERNELS_MATH_UNARY_ELEMENTWISE_FUNC_H_ #define ONEFLOW_USER_KERNELS_MATH_UNARY_ELEMENTWISE_FUNC_H_
#include "oneflow/core/common/util.h" #include "oneflow/core/common/util.h"
#include "oneflow/core/common/data_type.h" #include "oneflow/core/common/data_type.h"
#include "oneflow/user/ops/math_unary_elementwise_seq.h" #include "oneflow/user/ops/math_unary_elementwise_seq.h"
#include "oneflow/core/device/cuda_pseudo_half.h" #include "oneflow/core/device/cuda_pseudo_half.h"
#if defined(__CUDACC__) #if defined(__CUDACC__)
#include <cuda_fp16.h> #include <cuda_fp16.h>
#define MATH_FUNC_F(name, x) name##f(x) #define MATH_FUNC_F(name, x) name##f(x)
#define MATH_FUNC_D(name, x) name(x) #define MATH_FUNC_D(name, x) name(x)
#elif defined(__HIPCC__) #elif defined(__HIPCC__)
#include <cmath> #include <cmath>
#include <hip/hip_fp16.h> #include <hip/hip_fp16.h>
#if defined(__HIP_DEVICE_COMPILE__) #if defined(__HIP_DEVICE_COMPILE__)
#define MATH_FUNC_F(name, x) name##f(x) #define MATH_FUNC_F(name, x) name##f(x)
#define MATH_FUNC_D(name, x) name(x) #define MATH_FUNC_D(name, x) name(x)
#else #else
#define MATH_FUNC_F(name, x) std::name(x) #define MATH_FUNC_F(name, x) std::name(x)
#define MATH_FUNC_D(name, x) std::name(x) #define MATH_FUNC_D(name, x) std::name(x)
#endif #endif
#else #else
#include <cmath> #include <cmath>
#define MATH_FUNC_F(name, x) std::name(x) #define MATH_FUNC_F(name, x) std::name(x)
#define MATH_FUNC_D(name, x) std::name(x) #define MATH_FUNC_D(name, x) std::name(x)
#endif #endif
namespace oneflow { namespace oneflow {
// Forward-declare one functor template per entry of MATH_UNARY_ELEMENTWISE_FUNC_SEQ;
// the explicit float/double specializations below supply the implementations.
#define DECLARE_UNARY_FUNCTOR(math_unary_elementwise_type, func_prefix) \
  template<typename T>                                                  \
  struct func_prefix##Functor;

OF_PP_FOR_EACH_TUPLE(DECLARE_UNARY_FUNCTOR, MATH_UNARY_ELEMENTWISE_FUNC_SEQ)
template<typename T> template<typename T>
struct AbsFunctor { struct AbsFunctor {
static OF_DEVICE_FUNC T Forward(const T x) { static OF_DEVICE_FUNC T Forward(const T x) {
if (x == T(0)) if (x == T(0))
return T(0); return T(0);
else else
return x < T(0) ? -x : x; return x < T(0) ? -x : x;
} }
static OF_DEVICE_FUNC T Backward(const T x, const T dy) { static OF_DEVICE_FUNC T Backward(const T x, const T dy) {
if (x == T(0)) if (x == T(0))
return T(0); return T(0);
else else
return x < T(0) ? -dy : dy; return x < T(0) ? -dy : dy;
} }
}; };
template<typename T> template<typename T>
struct SignFunctor { struct SignFunctor {
static OF_DEVICE_FUNC T Forward(const T x) { return (T(0) < x) - (x < T(0)); } static OF_DEVICE_FUNC T Forward(const T x) { return (T(0) < x) - (x < T(0)); }
static OF_DEVICE_FUNC T Backward(const T x, const T dy) { return T(0); } static OF_DEVICE_FUNC T Backward(const T x, const T dy) { return T(0); }
}; };
template<> template<>
struct RsqrtFunctor<float> { struct RsqrtFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { static OF_DEVICE_FUNC float Forward(const float x) {
#if defined(__CUDACC__) #if defined(__CUDACC__)
return rsqrtf(x); return rsqrtf(x);
#elif defined(__HIP_DEVICE_COMPILE__) #elif defined(__HIP_DEVICE_COMPILE__)
return rsqrtf(x); return rsqrtf(x);
#else #else
return 1.0f / std::sqrt(x); return 1.0f / std::sqrt(x);
#endif #endif
} }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * (-1.0f / (2.0f * MATH_FUNC_F(sqrt, x * x * x))); return dy * (-1.0f / (2.0f * MATH_FUNC_F(sqrt, x * x * x)));
} }
}; };
template<> template<>
struct RsqrtFunctor<double> { struct RsqrtFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { static OF_DEVICE_FUNC double Forward(const double x) {
#if defined(__CUDACC__) #if defined(__CUDACC__)
return rsqrt(x); return rsqrt(x);
#elif defined(__HIP_DEVICE_COMPILE__) #elif defined(__HIP_DEVICE_COMPILE__)
return rsqrt(x); return rsqrt(x);
#else #else
return 1.0 / std::sqrt(x); return 1.0 / std::sqrt(x);
#endif #endif
} }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) { static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * (-1.0 / (2.0 * MATH_FUNC_D(sqrt, x * x * x))); return dy * (-1.0 / (2.0 * MATH_FUNC_D(sqrt, x * x * x)));
} }
}; };
// float version // float version
template<> template<>
struct AcosFunctor<float> { struct AcosFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(acos, x); } static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(acos, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * -RsqrtFunctor<float>::Forward(1.0f - x * x); return dy * -RsqrtFunctor<float>::Forward(1.0f - x * x);
} }
}; };
template<> template<>
struct AcoshFunctor<float> { struct AcoshFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(acosh, x); } static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(acosh, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * RsqrtFunctor<float>::Forward(x * x - 1.0f); return dy * RsqrtFunctor<float>::Forward(x * x - 1.0f);
} }
}; };
template<> template<>
struct AsinFunctor<float> { struct AsinFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(asin, x); } static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(asin, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * RsqrtFunctor<float>::Forward(1.0f - x * x); return dy * RsqrtFunctor<float>::Forward(1.0f - x * x);
} }
}; };
template<> template<>
struct AsinhFunctor<float> { struct AsinhFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(asinh, x); } static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(asinh, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * RsqrtFunctor<float>::Forward(1.0f + x * x); return dy * RsqrtFunctor<float>::Forward(1.0f + x * x);
} }
}; };
template<> template<>
struct AtanFunctor<float> { struct AtanFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(atan, x); } static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(atan, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * (1.0f / (1.0f + x * x)); return dy * (1.0f / (1.0f + x * x));
} }
}; };
template<> template<>
struct AtanhFunctor<float> { struct AtanhFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(atanh, x); } static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(atanh, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * (1.0f / (1.0f - x * x)); return dy * (1.0f / (1.0f - x * x));
} }
}; };
template<> template<>
struct NotEqualZeroFunctor<float> { struct NotEqualZeroFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return x != 0; } static OF_DEVICE_FUNC float Forward(const float x) { return x != 0; }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; }
}; };
template<> template<>
struct CeilFunctor<float> { struct CeilFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(ceil, x); } static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(ceil, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; }
}; };
template<> template<>
struct CosFunctor<float> { struct CosFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(cos, x); } static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(cos, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * (-MATH_FUNC_F(sin, x)); return dy * (-MATH_FUNC_F(sin, x));
} }
}; };
template<> template<>
struct CoshFunctor<float> { struct CoshFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(cosh, x); } static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(cosh, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * MATH_FUNC_F(sinh, x); return dy * MATH_FUNC_F(sinh, x);
} }
}; };
template<> template<>
struct ErfFunctor<float> { struct ErfFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(erf, x); } static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(erf, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * 2.0f * RsqrtFunctor<float>::Forward(M_PI) * expf(-x * x); return dy * 2.0f * RsqrtFunctor<float>::Forward(M_PI) * expf(-x * x);
} }
}; };
template<> template<>
struct ErfcFunctor<float> { struct ErfcFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(erfc, x); } static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(erfc, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * -2.0f * RsqrtFunctor<float>::Forward(M_PI) * expf(-x * x); return dy * -2.0f * RsqrtFunctor<float>::Forward(M_PI) * expf(-x * x);
} }
}; };
template<> template<>
struct ExpFunctor<float> { struct ExpFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(exp, x); } static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(exp, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * MATH_FUNC_F(exp, x); return dy * MATH_FUNC_F(exp, x);
} }
}; };
template<> template<>
struct Expm1Functor<float> { struct Expm1Functor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(expm1, x); } static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(expm1, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * MATH_FUNC_F(exp, x); return dy * MATH_FUNC_F(exp, x);
} }
}; };
template<> template<>
struct FloorFunctor<float> { struct FloorFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(floor, x); } static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(floor, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; }
}; };
template<> template<>
struct LgammaFunctor<float> { struct LgammaFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(lgamma, x); } static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(lgamma, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
// TODO(chengcheng): return: dy * digamma(x) // TODO(chengcheng): return: dy * digamma(x)
assert(false); // assert(false);
return 0.0f; return 0.0f;
} }
}; };
template<> template<>
struct LogFunctor<float> { struct LogFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(log, x); } static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(log, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return dy * (1.0f / x); } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return dy * (1.0f / x); }
}; };
template<> template<>
struct Log2Functor<float> { struct Log2Functor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(log2, x); } static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(log2, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * (1.0f / (x * MATH_FUNC_F(log, 2.0f))); return dy * (1.0f / (x * MATH_FUNC_F(log, 2.0f)));
} }
}; };
template<> template<>
struct Log1pFunctor<float> { struct Log1pFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(log1p, x); } static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(log1p, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * (1.0f / (x + 1.0f)); return dy * (1.0f / (x + 1.0f));
} }
}; };
template<> template<>
struct LogSigmoidFunctor<float> { struct LogSigmoidFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { static OF_DEVICE_FUNC float Forward(const float x) {
return -MATH_FUNC_F(log, (1.0f + MATH_FUNC_F(exp, -x))); return -MATH_FUNC_F(log, (1.0f + MATH_FUNC_F(exp, -x)));
} }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * (1.0f / (MATH_FUNC_F(exp, x) + 1.0f)); return dy * (1.0f / (MATH_FUNC_F(exp, x) + 1.0f));
} }
}; };
template<> template<>
struct NegativeFunctor<float> { struct NegativeFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return -x; } static OF_DEVICE_FUNC float Forward(const float x) { return -x; }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return -dy; } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return -dy; }
}; };
template<> template<>
struct ReciprocalFunctor<float> { struct ReciprocalFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return 1.0f / x; } static OF_DEVICE_FUNC float Forward(const float x) { return 1.0f / x; }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * (-1.0f / (x * x)); return dy * (-1.0f / (x * x));
} }
}; };
template<> template<>
struct ReciprocalNoNanFunctor<float> { struct ReciprocalNoNanFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { static OF_DEVICE_FUNC float Forward(const float x) {
if (fabsf(x) <= 0.0f) { return 0.0f; } if (fabsf(x) <= 0.0f) { return 0.0f; }
return 1.0f / x; return 1.0f / x;
} }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
if (fabsf(x) <= 0.0f) { return 0.0f; } if (fabsf(x) <= 0.0f) { return 0.0f; }
return dy * (-1.0f / (x * x)); return dy * (-1.0f / (x * x));
} }
}; };
template<> template<>
struct RintFunctor<float> { struct RintFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(rint, x); } static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(rint, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; }
}; };
template<> template<>
struct RoundFunctor<float> { struct RoundFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(nearbyint, x); } static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(nearbyint, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; }
}; };
template<> template<>
struct SigmoidFunctor<float> { struct SigmoidFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { static OF_DEVICE_FUNC float Forward(const float x) {
return 1.0f / (1.0f + MATH_FUNC_F(exp, -x)); return 1.0f / (1.0f + MATH_FUNC_F(exp, -x));
} }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
float y = 1.0f / (1.0f + MATH_FUNC_F(exp, -x)); float y = 1.0f / (1.0f + MATH_FUNC_F(exp, -x));
return dy * (y * (1.0f - y)); return dy * (y * (1.0f - y));
} }
}; };
template<> template<>
struct SinFunctor<float> { struct SinFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(sin, x); } static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(sin, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * MATH_FUNC_F(cos, x); return dy * MATH_FUNC_F(cos, x);
} }
}; };
template<> template<>
struct SinhFunctor<float> { struct SinhFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(sinh, x); } static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(sinh, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * MATH_FUNC_F(cosh, x); return dy * MATH_FUNC_F(cosh, x);
} }
}; };
template<> template<>
struct SqrtFunctor<float> { struct SqrtFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(sqrt, x); } static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(sqrt, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * 0.5f / MATH_FUNC_F(sqrt, x); return dy * 0.5f / MATH_FUNC_F(sqrt, x);
} }
}; };
template<> template<>
struct SquareFunctor<float> { struct SquareFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return x * x; } static OF_DEVICE_FUNC float Forward(const float x) { return x * x; }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return dy * 2.0f * x; } static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return dy * 2.0f * x; }
}; };
template<> template<>
struct TanFunctor<float> { struct TanFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(tan, x); } static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(tan, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * (1.0f / (MATH_FUNC_F(cos, x) * MATH_FUNC_F(cos, x))); return dy * (1.0f / (MATH_FUNC_F(cos, x) * MATH_FUNC_F(cos, x)));
} }
}; };
// double version // double version
template<> template<>
struct AcosFunctor<double> { struct AcosFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(acos, x); } static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(acos, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) { static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * -RsqrtFunctor<double>::Forward(1.0 - x * x); return dy * -RsqrtFunctor<double>::Forward(1.0 - x * x);
} }
}; };
template<> template<>
struct AcoshFunctor<double> { struct AcoshFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(acosh, x); } static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(acosh, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) { static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * -RsqrtFunctor<double>::Forward(x * x - 1.0); return dy * -RsqrtFunctor<double>::Forward(x * x - 1.0);
} }
}; };
template<> template<>
struct AsinFunctor<double> { struct AsinFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(asin, x); } static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(asin, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) { static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * RsqrtFunctor<double>::Forward(1.0 - x * x); return dy * RsqrtFunctor<double>::Forward(1.0 - x * x);
} }
}; };
template<> template<>
struct AsinhFunctor<double> { struct AsinhFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(asinh, x); } static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(asinh, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) { static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * RsqrtFunctor<double>::Forward(1.0 + x * x); return dy * RsqrtFunctor<double>::Forward(1.0 + x * x);
} }
}; };
template<> template<>
struct AtanFunctor<double> { struct AtanFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(atan, x); } static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(atan, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) { static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * (1.0 / (1.0 + x * x)); return dy * (1.0 / (1.0 + x * x));
} }
}; };
template<> template<>
struct AtanhFunctor<double> { struct AtanhFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(atanh, x); } static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(atanh, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) { static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * (1.0 / (1.0 - x * x)); return dy * (1.0 / (1.0 - x * x));
} }
}; };
template<> template<>
struct NotEqualZeroFunctor<double> { struct NotEqualZeroFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return x != 0; } static OF_DEVICE_FUNC double Forward(const double x) { return x != 0; }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return 0.0f; } static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return 0.0f; }
}; };
template<> template<>
struct CeilFunctor<double> { struct CeilFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(ceil, x); } static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(ceil, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return 0.0; } static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return 0.0; }
}; };
template<> template<>
struct CosFunctor<double> { struct CosFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(cos, x); } static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(cos, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) { static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * (-MATH_FUNC_D(sin, x)); return dy * (-MATH_FUNC_D(sin, x));
} }
}; };
template<> template<>
struct CoshFunctor<double> { struct CoshFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(cosh, x); } static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(cosh, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) { static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * MATH_FUNC_D(sinh, x); return dy * MATH_FUNC_D(sinh, x);
} }
}; };
template<> template<>
struct ErfFunctor<double> { struct ErfFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(erf, x); } static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(erf, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) { static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * 2.0 * RsqrtFunctor<double>::Forward(M_PI) * expf(-x * x); return dy * 2.0 * RsqrtFunctor<double>::Forward(M_PI) * expf(-x * x);
} }
}; };
template<> template<>
struct ErfcFunctor<double> { struct ErfcFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(erfc, x); } static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(erfc, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) { static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * -2.0 * RsqrtFunctor<double>::Forward(M_PI) * expf(-x * x); return dy * -2.0 * RsqrtFunctor<double>::Forward(M_PI) * expf(-x * x);
} }
}; };
// exp(x); gradient: dy * exp(x).
template<>
struct ExpFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(exp, x); }
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * MATH_FUNC_D(exp, x);
  }
};
// expm1(x) = exp(x) - 1, numerically accurate near x = 0; gradient: dy * exp(x).
template<>
struct Expm1Functor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(expm1, x); }
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * MATH_FUNC_D(exp, x);
  }
};
// floor(x) is piecewise constant, so its gradient is defined as 0 everywhere.
template<>
struct FloorFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(floor, x); }
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return 0.0; }
};
// lgamma(x) = log|Gamma(x)|. The true gradient is dy * digamma(x), which is
// not implemented.
template<>
struct LgammaFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(lgamma, x); }
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    // TODO(chengcheng): return: dy * digamma(x)
    // NOTE(review): the assert below was disabled, so backward now silently
    // returns 0 instead of failing — confirm no caller relies on lgamma's
    // gradient being correct.
    // assert(false);
    return 0.0;
  }
};
// Natural logarithm; gradient: dy / x.
template<>
struct LogFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(log, x); }
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return dy * (1.0 / x); }
};
// Base-2 logarithm; d/dx log2(x) = 1 / (x * ln(2)).
template<>
struct Log2Functor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(log2, x); }
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * (1.0 / (x * MATH_FUNC_D(log, 2.0)));
  }
};
// log1p(x) = log(1 + x), numerically accurate near x = 0; gradient: dy / (x + 1).
template<>
struct Log1pFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(log1p, x); }
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * (1.0 / (x + 1.0));
  }
};
// log(sigmoid(x)) = -log(1 + exp(-x)); gradient: dy * sigmoid(-x) = dy / (exp(x) + 1).
template<>
struct LogSigmoidFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) {
    return -MATH_FUNC_D(log, (1.0 + MATH_FUNC_D(exp, -x)));
  }
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * (1.0 / (MATH_FUNC_D(exp, x) + 1.0));
  }
};
// Unary minus; gradient: -dy.
template<>
struct NegativeFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return -x; }
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return -dy; }
};
// 1/x; gradient: -dy / x^2.
template<>
struct ReciprocalFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return 1.0 / x; }
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * (-1.0 / (x * x));
  }
};
// Like reciprocal, but maps x == 0 to 0 in both forward and backward instead
// of producing inf/nan.
template<>
struct ReciprocalNoNanFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) {
    if (fabs(x) <= 0.0) { return 0.0; }
    return 1.0 / x;
  }
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    if (fabs(x) <= 0.0) { return 0.0; }
    return dy * (-1.0 / (x * x));
  }
};
// rint rounds to the nearest integer; piecewise constant, so gradient is 0.
template<>
struct RintFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(rint, x); }
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return 0.0; }
};
// nearbyint rounds using the current rounding mode; gradient is 0.
template<>
struct RoundFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(nearbyint, x); }
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return 0.0; }
};
// Logistic sigmoid y = 1 / (1 + exp(-x)); gradient: dy * y * (1 - y).
template<>
struct SigmoidFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) {
    return 1.0 / (1.0 + MATH_FUNC_D(exp, -x));
  }
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    // Recompute y from x; only x is available to the backward pass.
    double y = 1.0 / (1.0 + MATH_FUNC_D(exp, -x));
    return dy * (y * (1.0 - y));
  }
};
// sin; gradient: dy * cos(x).
template<>
struct SinFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(sin, x); }
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * MATH_FUNC_D(cos, x);
  }
};
// sinh; gradient: dy * cosh(x).
template<>
struct SinhFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(sinh, x); }
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * MATH_FUNC_D(cosh, x);
  }
};
// sqrt; gradient: dy * 0.5 / sqrt(x).
template<>
struct SqrtFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(sqrt, x); }
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * (double)0.5 / MATH_FUNC_D(sqrt, x);
  }
};
// x^2; gradient: dy * 2x.
template<>
struct SquareFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return x * x; }
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return dy * 2.0 * x; }
};
// tan; gradient: dy / cos^2(x) = dy * sec^2(x).
template<>
struct TanFunctor<double> {
  static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(tan, x); }
  static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
    return dy * (1.0 / (MATH_FUNC_D(cos, x) * MATH_FUNC_D(cos, x)));
  }
};
#if defined(__CUDACC__) || defined(__HIPCC__)
// half version
// Half-precision specializations; compiled only in CUDA/HIP device passes.
#define OF_HALF_FUNC __device__ __forceinline__
// No native half math library for most functions: round-trip through float
// (half -> float -> float math function -> half).
#define MATH_FUNC_H(name, x) __float2half(name##f(__half2float(x)))
#define HALF_VAL_HALF __float2half(0.5f)
#define HALF_VAL_TWO __float2half(2.0f)
// 2 / sqrt(pi), the constant in d/dx erf(x).
#define HALF_VAL_2RSQRT_PI __float2half(1.1283791671f)
// |x|; gradient passes dy through with the sign of x.
template<>
struct AbsFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) {
    return __hlt(x, GetZeroVal<half>()) ? __hneg(x) : x;
  }
  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hlt(x, GetZeroVal<half>()) ? __hneg(dy) : dy;
  }
};
// acos; d/dx = -1 / sqrt(1 - x^2).
template<>
struct AcosFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(acos, x); }
  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, __hneg(hrsqrt(__hsub(GetOneVal<half>(), __hmul(x, x)))));
  }
};
// acosh; d/dx = 1 / sqrt(x^2 - 1).
template<>
struct AcoshFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(acosh, x); }
  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, hrsqrt(__hsub(__hmul(x, x), GetOneVal<half>())));
  }
};
// asin; d/dx = 1 / sqrt(1 - x^2).
template<>
struct AsinFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(asin, x); }
  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, hrsqrt(__hsub(GetOneVal<half>(), __hmul(x, x))));
  }
};
// asinh; d/dx = 1 / sqrt(1 + x^2).
template<>
struct AsinhFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(asinh, x); }
  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, hrsqrt(__hadd(GetOneVal<half>(), __hmul(x, x))));
  }
};
// atan; d/dx = 1 / (1 + x^2).
template<>
struct AtanFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(atan, x); }
  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, __hdiv(GetOneVal<half>(), __hadd(GetOneVal<half>(), __hmul(x, x))));
  }
};
// atanh; d/dx = 1 / (1 - x^2).
template<>
struct AtanhFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(atanh, x); }
  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, __hdiv(GetOneVal<half>(), __hsub(GetOneVal<half>(), __hmul(x, x))));
  }
};
// ceil is piecewise constant; gradient is 0.
template<>
struct CeilFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return hceil(x); }
  static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal<half>(); }
};
// 1 if x != 0 else 0 (bool converted to half); gradient is 0.
template<>
struct NotEqualZeroFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return x != static_cast<half>(0.0); }
  static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal<half>(); }
};
// cos (native half intrinsic); d/dx = -sin(x).
template<>
struct CosFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return hcos(x); }
  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, __hneg(hsin(x)));
  }
};
// cosh via float round-trip; d/dx = sinh(x).
template<>
struct CoshFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(cosh, x); }
  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, MATH_FUNC_H(sinh, x));
  }
};
// erf; d/dx = 2/sqrt(pi) * exp(-x^2).
template<>
struct ErfFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(erf, x); }
  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, __hmul(HALF_VAL_2RSQRT_PI, hexp(__hmul(__hneg(x), x))));
  }
};
// erfc = 1 - erf; d/dx = -2/sqrt(pi) * exp(-x^2).
template<>
struct ErfcFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(erfc, x); }
  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, __hneg(__hmul(HALF_VAL_2RSQRT_PI, hexp(__hmul(__hneg(x), x)))));
  }
};
// exp (native half intrinsic); gradient: dy * exp(x).
template<>
struct ExpFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return hexp(x); }
  static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, hexp(x)); }
};
// expm1 = exp(x) - 1 (accurate near 0); gradient: dy * exp(x).
template<>
struct Expm1Functor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(expm1, x); }
  static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, hexp(x)); }
};
// floor is piecewise constant; gradient is 0.
template<>
struct FloorFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return hfloor(x); }
  static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal<half>(); }
};
// lgamma; the true gradient (dy * digamma(x)) is not implemented.
template<>
struct LgammaFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(lgamma, x); }
  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    // TODO(chengcheng): return: dy * digamma(x)
    // NOTE(review): assert disabled — backward silently returns 0; confirm no
    // caller relies on lgamma's gradient.
    // assert(false);
    return GetZeroVal<half>();
  }
};
// Natural log (native half intrinsic); gradient: dy / x (hrcp = reciprocal).
template<>
struct LogFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return hlog(x); }
  static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, hrcp(x)); }
};
// Base-2 log; d/dx = 1 / (x * ln(2)).
template<>
struct Log2Functor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return hlog2(x); }
  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, hrcp(__hmul(x, hlog(HALF_VAL_TWO))));
  }
};
// log1p = log(1 + x); gradient: dy / (x + 1).
template<>
struct Log1pFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(log1p, x); }
  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, hrcp(__hadd(x, GetOneVal<half>())));
  }
};
// log(sigmoid(x)) = -log(1 + exp(-x)); gradient: dy / (exp(x) + 1).
template<>
struct LogSigmoidFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) {
    return __hneg(hlog(__hadd(GetOneVal<half>(), hexp(__hneg(x)))));
  }
  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, hrcp(__hadd(hexp(x), GetOneVal<half>())));
  }
};
// Unary minus; gradient: -dy.
template<>
struct NegativeFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return __hneg(x); }
  static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hneg(dy); }
};
// 1/x; gradient: -dy / x^2.
template<>
struct ReciprocalFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return hrcp(x); }
  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, __hneg(hrcp(__hmul(x, x))));
  }
};
// Like reciprocal, but maps x == 0 to 0 in both forward and backward.
template<>
struct ReciprocalNoNanFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) {
    if (__heq(GetZeroVal<half>(), x)) { return GetZeroVal<half>(); }
    return hrcp(x);
  }
  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    if (__heq(GetZeroVal<half>(), x)) { return GetZeroVal<half>(); }
    return __hmul(dy, __hneg(hrcp(__hmul(x, x))));
  }
};
// Round to nearest integer; gradient is 0.
template<>
struct RintFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return hrint(x); }
  static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal<half>(); }
};
// nearbyint rounds using the current rounding mode; gradient is 0.
template<>
struct RoundFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(nearbyint, x); }
  static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal<half>(); }
};
// 1/sqrt(x); d/dx = -1 / (2 * x^(3/2)).
template<>
struct RsqrtFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return hrsqrt(x); }
  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, __hneg(hrcp(__hmul(HALF_VAL_TWO, hsqrt(__hmul(x, __hmul(x, x)))))));
  }
};
// Logistic sigmoid y = 1 / (1 + exp(-x)); gradient: dy * y * (1 - y).
template<>
struct SigmoidFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) {
    return hrcp(__hadd(GetOneVal<half>(), hexp(__hneg(x))));
  }
  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    // Recompute y from x; only x is available to the backward pass.
    half y = hrcp(__hadd(GetOneVal<half>(), hexp(__hneg(x))));
    return __hmul(dy, __hmul(y, __hsub(GetOneVal<half>(), y)));
  }
};
// sign(x) in {-1, 0, 1}; piecewise constant, so gradient is 0.
template<>
struct SignFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) {
    if (__hgt(x, GetZeroVal<half>())) { return GetOneVal<half>(); }
    if (__hlt(x, GetZeroVal<half>())) { return __hneg(GetOneVal<half>()); }
    return GetZeroVal<half>();
  }
  static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal<half>(); }
};
// sin; gradient: dy * cos(x).
template<>
struct SinFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return hsin(x); }
  static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, hcos(x)); }
};
// sinh via float round-trip; gradient: dy * cosh(x).
template<>
struct SinhFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(sinh, x); }
  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, MATH_FUNC_H(cosh, x));
  }
};
// sqrt; gradient: dy * 0.5 / sqrt(x).
template<>
struct SqrtFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return hsqrt(x); }
  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, __hdiv(HALF_VAL_HALF, hsqrt(x)));
  }
};
// x^2; gradient: dy * 2x.
template<>
struct SquareFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return __hmul(x, x); }
  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, __hmul(HALF_VAL_TWO, x));
  }
};
// tan = sin/cos (no native htan); gradient: dy / cos^2(x).
template<>
struct TanFunctor<half> {
  static OF_HALF_FUNC half Forward(const half x) { return __hdiv(hsin(x), hcos(x)); }
  static OF_HALF_FUNC half Backward(const half x, const half dy) {
    return __hmul(dy, hrcp(__hmul(hcos(x), hcos(x))));
  }
};
#endif #endif
} // namespace oneflow } // namespace oneflow
#endif // ONEFLOW_USER_KERNELS_MATH_UNARY_ELEMENTWISE_FUNC_H_ #endif // ONEFLOW_USER_KERNELS_MATH_UNARY_ELEMENTWISE_FUNC_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/kernel/new_kernel_util.h"
#ifdef OF_ENABLE_PROFILER
#include <roctracer_roctx.h>
#endif // OF_ENABLE_PROFILER
namespace oneflow {
namespace {
#ifdef OF_ENABLE_PROFILER
// Maps a range mark ("<prefix>-<counter>") opened by the nvtx_start kernel to
// the rocTX range id so the matching nvtx_end kernel can stop it.
// thread_local: assumes start/end run on the same thread — TODO confirm for
// multi-stream schedules.
static thread_local HashMap<std::string, roctx_range_id_t> mark2range_id;
#endif
}  // namespace
// Per-op state: counts invocations so each profiling range gets a unique mark.
class NvtxOpKernelState final : public user_op::OpKernelState {
 public:
  NvtxOpKernelState() {
#ifndef OF_ENABLE_PROFILER
    LOG(WARNING) << "To use NVTX, run cmake with -DBUILD_PROFILER=ON";
#endif
  }
  ~NvtxOpKernelState() override = default;
  // Number of completed Compute() invocations so far.
  int64_t counter() const { return counter_; }
  // Bump after each invocation.
  void IncreaseCount() { ++counter_; }
 private:
  int64_t counter_{0};
};
// Identity op that additionally opens a rocTX profiling range named
// "<mark_prefix>-<invocation counter>"; the matching nvtx_end op closes it.
class NvtxStartKernel final : public user_op::OpKernel {
 public:
  NvtxStartKernel() = default;
  ~NvtxStartKernel() override = default;
  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(
      user_op::KernelInitContext* ctx) const override {
    return std::make_shared<NvtxOpKernelState>();
  }
 private:
  using user_op::OpKernel::Compute;
  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,
               const user_op::OpKernelCache*) const override {
    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0);
    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0);
    const ShapeView& in_shape = in->shape_view();
    CHECK_EQ(out->shape_view(), in_shape);
    const DataType in_data_type = in->data_type();
    CHECK_EQ(out->data_type(), in_data_type);
    // Forward the data unchanged (identity semantics), profiler on or off.
    Memcpy<DeviceType::kCUDA>(ctx->stream(), out->mut_dptr<void>(), in->dptr<void>(),
                              in_shape.elem_cnt() * GetSizeOfDataType(in_data_type));
#ifdef OF_ENABLE_PROFILER
    auto* kernel_state = dynamic_cast<NvtxOpKernelState*>(state);
    const std::string mark_prefix = ctx->Attr<std::string>("mark_prefix");
    // Mark is unique per invocation so nvtx_end can find the right range.
    const std::string mark = mark_prefix + "-" + std::to_string(kernel_state->counter());
    roctx_range_id_t range_id = roctxRangeStartA(mark.c_str());
    // A duplicate mark means start/end got out of sync — fail fast.
    CHECK(mark2range_id.emplace(mark, range_id).second);
    kernel_state->IncreaseCount();
#endif  // OF_ENABLE_PROFILER
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
// Registered for the kCUDA device type; proposes (but does not require)
// computing in place, since the op is an identity.
REGISTER_USER_KERNEL("nvtx_start")
    .SetCreateFn<NvtxStartKernel>()
    .SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCUDA)
    .SetInplaceProposalFn([](const user_op::InferContext&,
                             user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> {
      OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "in", 0, false));
      return Maybe<void>::Ok();
    });
// Identity op that closes the rocTX range opened by the matching nvtx_start
// op (same "<mark_prefix>-<counter>" mark) and forwards its input unchanged.
class NvtxEndKernel final : public user_op::OpKernel {
 public:
  NvtxEndKernel() = default;
  ~NvtxEndKernel() override = default;
  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(
      user_op::KernelInitContext* ctx) const override {
    return std::make_shared<NvtxOpKernelState>();
  }
 private:
  using user_op::OpKernel::Compute;
  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,
               const user_op::OpKernelCache*) const override {
    const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0);
    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0);
    const ShapeView& in_shape = in->shape_view();
    CHECK_EQ(out->shape_view(), in_shape);
    const DataType in_data_type = in->data_type();
    CHECK_EQ(out->data_type(), in_data_type);
#ifdef OF_ENABLE_PROFILER
    auto* kernel_state = dynamic_cast<NvtxOpKernelState*>(state);
    const std::string mark_prefix = ctx->Attr<std::string>("mark_prefix");
    const std::string mark = mark_prefix + "-" + std::to_string(kernel_state->counter());
    // Look up the range opened by nvtx_start under the same mark (no c_str()
    // round-trip: find() takes the std::string key directly) and close it.
    auto it = mark2range_id.find(mark);
    CHECK(it != mark2range_id.end());
    roctx_range_id_t range_id = it->second;
    mark2range_id.erase(it);
    roctxRangeStop(range_id);
    kernel_state->IncreaseCount();
#endif  // OF_ENABLE_PROFILER
    // Always forward the data, profiler on or off. The original copied only
    // inside the OF_ENABLE_PROFILER branch, leaving `out` unwritten in
    // non-profiler builds (the inplace proposal below is a hint, not a
    // guarantee); nvtx_start copies unconditionally, and this now matches.
    Memcpy<DeviceType::kCUDA>(ctx->stream(), out->mut_dptr<void>(), in->dptr<void>(),
                              in_shape.elem_cnt() * GetSizeOfDataType(in_data_type));
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
// Registered for the kCUDA device type; proposes (but does not require)
// computing in place, since the op is an identity.
REGISTER_USER_KERNEL("nvtx_end")
    .SetCreateFn<NvtxEndKernel>()
    .SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCUDA)
    .SetInplaceProposalFn([](const user_op::InferContext&,
                             user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> {
      OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "in", 0, false));
      return Maybe<void>::Ok();
    });
} // namespace oneflow
/* /*
Copyright 2020 The OneFlow Authors. All rights reserved. Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
#include "oneflow/user/kernels/stateful_opkernel.h"
#include "oneflow/core/framework/attr_value_accessor.h"
#include "oneflow/core/framework/user_op_conf.h"
#include "oneflow/core/framework/user_op_registry_manager.h"
#include "oneflow/core/eager/eager_blob_object.h"
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/rpc/include/global_process_ctx.h"
#include "oneflow/core/framework/consistent_tensor_infer_cache.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/profiler/profiler.h"
#include "oneflow/core/profiler/profile_manager.h"
#include "oneflow/core/profiler/event_recorder.h"
#include "oneflow/core/eager/call_context.h"
namespace oneflow {
namespace one {

class ConsistentTensorInferResult;

using ArgVec = std::vector<std::pair<std::string, int32_t>>;

using EagerBlobObjectListRawPtr = const std::vector<std::shared_ptr<vm::EagerBlobObject>>*;
using ConsistentTensorInferResultRawPtr = const ConsistentTensorInferResult*;
class ZeroCopyBaseContextHelper { class ZeroCopyBaseContextHelper {
public: public:
ZeroCopyBaseContextHelper(const std::shared_ptr<const ArgTuple>& input_arg_tuple, ZeroCopyBaseContextHelper(const std::shared_ptr<const ArgTuple>& input_arg_tuple,
const std::shared_ptr<const ArgTuple>& output_arg_tuple) const std::shared_ptr<const ArgTuple>& output_arg_tuple)
: input_arg_tuple_(input_arg_tuple), output_arg_tuple_(output_arg_tuple) {} : input_arg_tuple_(input_arg_tuple), output_arg_tuple_(output_arg_tuple) {}
#define RETURN_IF_FOUND(inputs, outputs, post_action) \ #define RETURN_IF_FOUND(inputs, outputs, post_action) \
int32_t i = TryGetTensorTupleIndex(input_arg_tuple_->arg_name2bn_index2tensor_tuple_index(), \ int32_t i = TryGetTensorTupleIndex(input_arg_tuple_->arg_name2bn_index2tensor_tuple_index(), \
arg_name, index); \ arg_name, index); \
if (i >= 0) { return (inputs).at(i) post_action; } \ if (i >= 0) { return (inputs).at(i) post_action; } \
i = TryGetTensorTupleIndex(output_arg_tuple_->arg_name2bn_index2tensor_tuple_index(), arg_name, \ i = TryGetTensorTupleIndex(output_arg_tuple_->arg_name2bn_index2tensor_tuple_index(), arg_name, \
index); \ index); \
if (i >= 0) { return (outputs).at(i) post_action; } if (i >= 0) { return (outputs).at(i) post_action; }
user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
const std::string& arg_name, const std::string& arg_name,
const int32_t index) const { const int32_t index) const {
RETURN_IF_FOUND(*call_ctx->inputs(), *call_ctx->outputs(), .get()); RETURN_IF_FOUND(*call_ctx->inputs(), *call_ctx->outputs(), .get());
return nullptr; return nullptr;
} }
user_op::Tensor* Tensor4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, user_op::Tensor* Tensor4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
const int32_t index) const { const int32_t index) const {
RETURN_IF_FOUND(*call_ctx->inputs(), *call_ctx->outputs(), .get()); RETURN_IF_FOUND(*call_ctx->inputs(), *call_ctx->outputs(), .get());
if (arg_name == "tmp_buffer" && index == 0) { return call_ctx->mut_tmp_tensor(); } if (arg_name == "tmp_buffer" && index == 0) { return call_ctx->mut_tmp_tensor(); }
return nullptr; return nullptr;
} }
const ConsistentTensorMeta* ConsistentTensorMeta4ArgNameAndIndex(eager::CallContext* call_ctx, const ConsistentTensorMeta* ConsistentTensorMeta4ArgNameAndIndex(eager::CallContext* call_ctx,
const std::string& arg_name, const std::string& arg_name,
const int32_t index) const { const int32_t index) const {
const auto& consistent_tensor_infer_result = call_ctx->consistent_tensor_infer_result(); const auto& consistent_tensor_infer_result = call_ctx->consistent_tensor_infer_result();
RETURN_IF_FOUND(consistent_tensor_infer_result->input_tensor_metas(), RETURN_IF_FOUND(consistent_tensor_infer_result->input_tensor_metas(),
consistent_tensor_infer_result->output_tensor_metas(), consistent_tensor_infer_result->output_tensor_metas(),
.shared_from_symbol().get()); .shared_from_symbol().get());
return nullptr; return nullptr;
} }
Optional<Symbol<ParallelDesc>> parallel_desc(eager::CallContext* call_ctx) const { Optional<Symbol<ParallelDesc>> parallel_desc(eager::CallContext* call_ctx) const {
const auto& consistent_tensor_infer_result = call_ctx->consistent_tensor_infer_result(); const auto& consistent_tensor_infer_result = call_ctx->consistent_tensor_infer_result();
if (!consistent_tensor_infer_result) { return Optional<Symbol<ParallelDesc>>(); } if (!consistent_tensor_infer_result) { return Optional<Symbol<ParallelDesc>>(); }
if (!consistent_tensor_infer_result->input_tensor_metas().empty()) { if (!consistent_tensor_infer_result->input_tensor_metas().empty()) {
return consistent_tensor_infer_result->input_tensor_metas().at(0)->parallel_desc(); return consistent_tensor_infer_result->input_tensor_metas().at(0)->parallel_desc();
} else if (!consistent_tensor_infer_result->output_tensor_metas().empty()) { } else if (!consistent_tensor_infer_result->output_tensor_metas().empty()) {
return consistent_tensor_infer_result->output_tensor_metas().at(0)->parallel_desc(); return consistent_tensor_infer_result->output_tensor_metas().at(0)->parallel_desc();
} else { } else {
UNIMPLEMENTED(); UNIMPLEMENTED();
return Optional<Symbol<ParallelDesc>>(); return Optional<Symbol<ParallelDesc>>();
} }
} }
const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const { const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const {
const auto& parallel_desc = this->parallel_desc(call_ctx); const auto& parallel_desc = this->parallel_desc(call_ctx);
if (parallel_desc.has_value()) { if (parallel_desc.has_value()) {
const auto& parallel_desc_symbol = CHECK_JUST(parallel_desc); const auto& parallel_desc_symbol = CHECK_JUST(parallel_desc);
return *CHECK_JUST(GetParallelContext4CurrentProcessCtx(parallel_desc_symbol)); return *CHECK_JUST(GetParallelContext4CurrentProcessCtx(parallel_desc_symbol));
} else { } else {
static ParallelContext single_device_parallel_ctx(MakeSingleDeviceParallelCtx()); static ParallelContext single_device_parallel_ctx(MakeSingleDeviceParallelCtx());
return single_device_parallel_ctx; return single_device_parallel_ctx;
} }
} }
const ArgVec& inputs() const { return input_arg_tuple_->indexed_arg_name_and_index(); } const ArgVec& inputs() const { return input_arg_tuple_->indexed_arg_name_and_index(); }
const ArgVec& outputs() const { return output_arg_tuple_->indexed_arg_name_and_index(); } const ArgVec& outputs() const { return output_arg_tuple_->indexed_arg_name_and_index(); }
private: private:
static int32_t TryGetTensorTupleIndex(const std::unordered_map<std::string, std::vector<int32_t>>& static int32_t TryGetTensorTupleIndex(const std::unordered_map<std::string, std::vector<int32_t>>&
arg_name2bn_index2tensor_tuple_index, arg_name2bn_index2tensor_tuple_index,
const std::string& arg_name, const int32_t arg_index) { const std::string& arg_name, const int32_t arg_index) {
auto it = arg_name2bn_index2tensor_tuple_index.find(arg_name); auto it = arg_name2bn_index2tensor_tuple_index.find(arg_name);
if (it != arg_name2bn_index2tensor_tuple_index.end()) { return it->second.at(arg_index); } if (it != arg_name2bn_index2tensor_tuple_index.end()) { return it->second.at(arg_index); }
return -1; return -1;
} }
static ParallelContext MakeSingleDeviceParallelCtx() { static ParallelContext MakeSingleDeviceParallelCtx() {
ParallelContext single_device_parallel_ctx; ParallelContext single_device_parallel_ctx;
single_device_parallel_ctx.set_parallel_id(0); single_device_parallel_ctx.set_parallel_id(0);
single_device_parallel_ctx.set_parallel_num(1); single_device_parallel_ctx.set_parallel_num(1);
return single_device_parallel_ctx; return single_device_parallel_ctx;
} }
std::shared_ptr<const ArgTuple> input_arg_tuple_; std::shared_ptr<const ArgTuple> input_arg_tuple_;
std::shared_ptr<const ArgTuple> output_arg_tuple_; std::shared_ptr<const ArgTuple> output_arg_tuple_;
}; };
class UserKernelBaseContextHelper final : public ZeroCopyBaseContextHelper { class UserKernelBaseContextHelper final : public ZeroCopyBaseContextHelper {
public: public:
UserKernelBaseContextHelper(DeviceType device_type, UserKernelBaseContextHelper(DeviceType device_type,
const std::shared_ptr<const ArgTuple>& input_arg_tuple, const std::shared_ptr<const ArgTuple>& input_arg_tuple,
const std::shared_ptr<const ArgTuple>& output_arg_tuple) const std::shared_ptr<const ArgTuple>& output_arg_tuple)
: ZeroCopyBaseContextHelper(input_arg_tuple, output_arg_tuple), device_type_(device_type) {} : ZeroCopyBaseContextHelper(input_arg_tuple, output_arg_tuple), device_type_(device_type) {}
~UserKernelBaseContextHelper() = default; ~UserKernelBaseContextHelper() = default;
DeviceType device_type() const { return device_type_; } DeviceType device_type() const { return device_type_; }
const JobDesc& job_desc() const { const JobDesc& job_desc() const {
UNIMPLEMENTED(); UNIMPLEMENTED();
return *(const JobDesc*)nullptr; return *(const JobDesc*)nullptr;
} }
private: private:
const DeviceType device_type_; const DeviceType device_type_;
}; };
class UserOpInferContextHelper final { class UserOpInferContextHelper final {
public: public:
UserOpInferContextHelper(const user_op::UserOpConfWrapper* user_op_conf, UserOpInferContextHelper(const user_op::UserOpConfWrapper* user_op_conf,
const std::shared_ptr<const ArgTuple>& input_arg_tuple, const std::shared_ptr<const ArgTuple>& input_arg_tuple,
const std::shared_ptr<const ArgTuple>& output_arg_tuple) const std::shared_ptr<const ArgTuple>& output_arg_tuple)
: user_op_conf_(user_op_conf), : user_op_conf_(user_op_conf),
zero_copy_base_ctx_helper_(input_arg_tuple, output_arg_tuple) {} zero_copy_base_ctx_helper_(input_arg_tuple, output_arg_tuple) {}
~UserOpInferContextHelper() = default; ~UserOpInferContextHelper() = default;
const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
const std::string& arg_name, const std::string& arg_name,
int32_t index) const { int32_t index) const {
UNIMPLEMENTED(); UNIMPLEMENTED();
return nullptr; return nullptr;
} }
const user_op::TensorDesc& InputTensorDesc(eager::CallContext* call_ctx, const user_op::TensorDesc& InputTensorDesc(eager::CallContext* call_ctx,
const std::string& arg_name, int32_t index) const { const std::string& arg_name, int32_t index) const {
return *CHECK_NOTNULL(TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)); return *CHECK_NOTNULL(TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index));
} }
user_op::TensorDesc* OutputTensorDesc(eager::CallContext* call_ctx, const std::string& arg_name, user_op::TensorDesc* OutputTensorDesc(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const { int32_t index) const {
return TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); return TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);
} }
user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
const std::string& arg_name, const std::string& arg_name,
int32_t index) const { int32_t index) const {
return zero_copy_base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); return zero_copy_base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);
} }
const Shape& InputShape(eager::CallContext* call_ctx, const std::string& arg_name, const Shape& InputShape(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const { int32_t index) const {
return *Shape4ArgNameAndIndex(call_ctx, arg_name, index); return *Shape4ArgNameAndIndex(call_ctx, arg_name, index);
} }
Shape* OutputShape(eager::CallContext* call_ctx, const std::string& arg_name, Shape* OutputShape(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const { int32_t index) const {
return Shape4ArgNameAndIndex(call_ctx, arg_name, index); return Shape4ArgNameAndIndex(call_ctx, arg_name, index);
} }
Shape* Shape4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, Shape* Shape4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const { int32_t index) const {
return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_shape(); return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_shape();
} }
const Stride& InputStride(eager::CallContext* call_ctx, const std::string& arg_name, const Stride& InputStride(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const { int32_t index) const {
return *Stride4ArgNameAndIndex(call_ctx, arg_name, index); return *Stride4ArgNameAndIndex(call_ctx, arg_name, index);
} }
Stride* OutputStride(eager::CallContext* call_ctx, const std::string& arg_name, Stride* OutputStride(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const { int32_t index) const {
return Stride4ArgNameAndIndex(call_ctx, arg_name, index); return Stride4ArgNameAndIndex(call_ctx, arg_name, index);
} }
Stride* Stride4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, Stride* Stride4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const { int32_t index) const {
return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_stride(); return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_stride();
} }
const DataType& InputDType(eager::CallContext* call_ctx, const std::string& arg_name, const DataType& InputDType(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const { int32_t index) const {
return *Dtype4ArgNameAndIndex(call_ctx, arg_name, index); return *Dtype4ArgNameAndIndex(call_ctx, arg_name, index);
} }
DataType* OutputDType(eager::CallContext* call_ctx, const std::string& arg_name, DataType* OutputDType(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const { int32_t index) const {
return Dtype4ArgNameAndIndex(call_ctx, arg_name, index); return Dtype4ArgNameAndIndex(call_ctx, arg_name, index);
} }
DataType* Dtype4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, DataType* Dtype4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const { int32_t index) const {
return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_data_type(); return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_data_type();
} }
bool InputIsDynamic(eager::CallContext* call_ctx, const std::string& arg_name, bool InputIsDynamic(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const { int32_t index) const {
return *IsDynamic4ArgNameAndIndex(call_ctx, arg_name, index); return *IsDynamic4ArgNameAndIndex(call_ctx, arg_name, index);
} }
bool* OutputIsDynamic(eager::CallContext* call_ctx, const std::string& arg_name, bool* OutputIsDynamic(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const { int32_t index) const {
return IsDynamic4ArgNameAndIndex(call_ctx, arg_name, index); return IsDynamic4ArgNameAndIndex(call_ctx, arg_name, index);
} }
bool* IsDynamic4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, bool* IsDynamic4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const { int32_t index) const {
return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_is_dynamic(); return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_is_dynamic();
} }
const ArgVec& inputs() const { return zero_copy_base_ctx_helper_.inputs(); } const ArgVec& inputs() const { return zero_copy_base_ctx_helper_.inputs(); }
const ArgVec& outputs() const { return zero_copy_base_ctx_helper_.outputs(); } const ArgVec& outputs() const { return zero_copy_base_ctx_helper_.outputs(); }
const JobDesc* job_desc() const { const JobDesc* job_desc() const {
UNIMPLEMENTED(); UNIMPLEMENTED();
return nullptr; return nullptr;
} }
const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const { const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const {
return zero_copy_base_ctx_helper_.parallel_ctx(call_ctx); return zero_copy_base_ctx_helper_.parallel_ctx(call_ctx);
} }
const ParallelDesc& parallel_desc(eager::CallContext* call_ctx) const { const ParallelDesc& parallel_desc(eager::CallContext* call_ctx) const {
return *CHECK_JUST(zero_copy_base_ctx_helper_.parallel_desc(call_ctx)); return *CHECK_JUST(zero_copy_base_ctx_helper_.parallel_desc(call_ctx));
} }
const SbpParallel& SbpParallel4ArgNameAndIndex(eager::CallContext* call_ctx, const SbpParallel& SbpParallel4ArgNameAndIndex(eager::CallContext* call_ctx,
const std::string& arg_name, int32_t index) const { const std::string& arg_name, int32_t index) const {
const auto& nd_sbp = NdSbp4ArgNameAndIndex(call_ctx, arg_name, index); const auto& nd_sbp = NdSbp4ArgNameAndIndex(call_ctx, arg_name, index);
CHECK_EQ(nd_sbp.sbp_parallel_size(), 1); CHECK_EQ(nd_sbp.sbp_parallel_size(), 1);
return nd_sbp.sbp_parallel(0); return nd_sbp.sbp_parallel(0);
} }
const NdSbp& NdSbp4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, const NdSbp& NdSbp4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const { int32_t index) const {
return *CHECK_NOTNULL(zero_copy_base_ctx_helper_.ConsistentTensorMeta4ArgNameAndIndex( return *CHECK_NOTNULL(zero_copy_base_ctx_helper_.ConsistentTensorMeta4ArgNameAndIndex(
call_ctx, arg_name, index)) call_ctx, arg_name, index))
->nd_sbp(); ->nd_sbp();
} }
int64_t parallel_num(eager::CallContext* call_ctx) const { int64_t parallel_num(eager::CallContext* call_ctx) const {
return parallel_ctx(call_ctx).parallel_num(); return parallel_ctx(call_ctx).parallel_num();
} }
const std::string& input(const std::string& arg_name, int32_t index) const { const std::string& input(const std::string& arg_name, int32_t index) const {
return user_op_conf().input(arg_name, index); return user_op_conf().input(arg_name, index);
} }
const std::string& output(const std::string& arg_name, int32_t index) const { const std::string& output(const std::string& arg_name, int32_t index) const {
return user_op_conf().output(arg_name, index); return user_op_conf().output(arg_name, index);
} }
bool has_input(const std::string& arg_name, int32_t index) const { bool has_input(const std::string& arg_name, int32_t index) const {
return user_op_conf().has_input(arg_name, index); return user_op_conf().has_input(arg_name, index);
} }
bool has_output(const std::string& arg_name, int32_t index) const { bool has_output(const std::string& arg_name, int32_t index) const {
return user_op_conf().has_output(arg_name, index); return user_op_conf().has_output(arg_name, index);
} }
int32_t input_size(const std::string& arg_name) const { int32_t input_size(const std::string& arg_name) const {
return user_op_conf().input_size(arg_name); return user_op_conf().input_size(arg_name);
} }
int32_t output_size(const std::string& arg_name) const { int32_t output_size(const std::string& arg_name) const {
return user_op_conf().output_size(arg_name); return user_op_conf().output_size(arg_name);
} }
const std::string& op_name() const { return user_op_conf().op_name(); } const std::string& op_name() const { return user_op_conf().op_name(); }
const std::string& op_type_name() const { return user_op_conf().op_type_name(); } const std::string& op_type_name() const { return user_op_conf().op_type_name(); }
const std::string& op_loc() const { return user_op_conf_->op_conf().loc(); } const std::string& op_loc() const { return user_op_conf_->op_conf().loc(); }
const user_op::UserOpConfWrapper& user_op_conf() const { return *user_op_conf_; } const user_op::UserOpConfWrapper& user_op_conf() const { return *user_op_conf_; }
const std::shared_ptr<const user_op::AttrVal>& Attr4Name(eager::CallContext* call_ctx, const std::shared_ptr<const user_op::AttrVal>& Attr4Name(eager::CallContext* call_ctx,
const std::string& attr_name) const { const std::string& attr_name) const {
return call_ctx->composed_attrs().Attr4Name(attr_name); return call_ctx->composed_attrs().Attr4Name(attr_name);
} }
private: private:
user_op::TensorDesc* NonNullTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, user_op::TensorDesc* NonNullTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
const std::string& arg_name, const std::string& arg_name,
int32_t index) const { int32_t index) const {
user_op::TensorDesc* tensor_desc = TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); user_op::TensorDesc* tensor_desc = TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);
if (!tensor_desc) { LOG(FATAL) << "Arg (" << arg_name << "," << index << ") is not found"; } if (!tensor_desc) { LOG(FATAL) << "Arg (" << arg_name << "," << index << ") is not found"; }
return tensor_desc; return tensor_desc;
} }
const user_op::UserOpConfWrapper* user_op_conf_; const user_op::UserOpConfWrapper* user_op_conf_;
ZeroCopyBaseContextHelper zero_copy_base_ctx_helper_; ZeroCopyBaseContextHelper zero_copy_base_ctx_helper_;
}; };
class UserOpInferContext : public user_op::InferContext { class UserOpInferContext : public user_op::InferContext {
public: public:
UserOpInferContext(const UserOpInferContextHelper* helper, eager::CallContext* call_ctx) UserOpInferContext(const UserOpInferContextHelper* helper, eager::CallContext* call_ctx)
: helper_(helper), call_ctx_(call_ctx) {} : helper_(helper), call_ctx_(call_ctx) {}
~UserOpInferContext() override = default; ~UserOpInferContext() override = default;
const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& arg_name, const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& arg_name,
int32_t index) const override { int32_t index) const override {
return helper_->LogicalTensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index); return helper_->LogicalTensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);
} }
const user_op::TensorDesc& InputTensorDesc(const std::string& arg_name, const user_op::TensorDesc& InputTensorDesc(const std::string& arg_name,
int32_t index) const override { int32_t index) const override {
return helper_->InputTensorDesc(call_ctx_, arg_name, index); return helper_->InputTensorDesc(call_ctx_, arg_name, index);
} }
user_op::TensorDesc* OutputTensorDesc(const std::string& arg_name, int32_t index) override { user_op::TensorDesc* OutputTensorDesc(const std::string& arg_name, int32_t index) override {
return helper_->OutputTensorDesc(call_ctx_, arg_name, index); return helper_->OutputTensorDesc(call_ctx_, arg_name, index);
} }
user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) { user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) {
return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index); return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);
} }
const Shape& InputShape(const std::string& arg_name, int32_t index) const override { const Shape& InputShape(const std::string& arg_name, int32_t index) const override {
return helper_->InputShape(call_ctx_, arg_name, index); return helper_->InputShape(call_ctx_, arg_name, index);
} }
Shape* OutputShape(const std::string& arg_name, int32_t index) override { Shape* OutputShape(const std::string& arg_name, int32_t index) override {
return helper_->OutputShape(call_ctx_, arg_name, index); return helper_->OutputShape(call_ctx_, arg_name, index);
} }
Shape* Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { Shape* Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) override {
return helper_->Shape4ArgNameAndIndex(call_ctx_, arg_name, index); return helper_->Shape4ArgNameAndIndex(call_ctx_, arg_name, index);
} }
const Stride& InputStride(const std::string& arg_name, int32_t index) const override { const Stride& InputStride(const std::string& arg_name, int32_t index) const override {
return helper_->InputStride(call_ctx_, arg_name, index); return helper_->InputStride(call_ctx_, arg_name, index);
} }
Stride* OutputStride(const std::string& arg_name, int32_t index) override { Stride* OutputStride(const std::string& arg_name, int32_t index) override {
return helper_->OutputStride(call_ctx_, arg_name, index); return helper_->OutputStride(call_ctx_, arg_name, index);
} }
Stride* Stride4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { Stride* Stride4ArgNameAndIndex(const std::string& arg_name, int32_t index) override {
return helper_->Stride4ArgNameAndIndex(call_ctx_, arg_name, index); return helper_->Stride4ArgNameAndIndex(call_ctx_, arg_name, index);
} }
const DataType& InputDType(const std::string& arg_name, int32_t index) const override { const DataType& InputDType(const std::string& arg_name, int32_t index) const override {
return helper_->InputDType(call_ctx_, arg_name, index); return helper_->InputDType(call_ctx_, arg_name, index);
} }
DataType* OutputDType(const std::string& arg_name, int32_t index) override { DataType* OutputDType(const std::string& arg_name, int32_t index) override {
return helper_->OutputDType(call_ctx_, arg_name, index); return helper_->OutputDType(call_ctx_, arg_name, index);
} }
DataType* Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { DataType* Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) override {
return helper_->Dtype4ArgNameAndIndex(call_ctx_, arg_name, index); return helper_->Dtype4ArgNameAndIndex(call_ctx_, arg_name, index);
} }
bool InputIsDynamic(const std::string& arg_name, int32_t index) const override { bool InputIsDynamic(const std::string& arg_name, int32_t index) const override {
return helper_->InputIsDynamic(call_ctx_, arg_name, index); return helper_->InputIsDynamic(call_ctx_, arg_name, index);
} }
bool* OutputIsDynamic(const std::string& arg_name, int32_t index) override { bool* OutputIsDynamic(const std::string& arg_name, int32_t index) override {
return helper_->OutputIsDynamic(call_ctx_, arg_name, index); return helper_->OutputIsDynamic(call_ctx_, arg_name, index);
} }
bool* IsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { bool* IsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) override {
return helper_->IsDynamic4ArgNameAndIndex(call_ctx_, arg_name, index); return helper_->IsDynamic4ArgNameAndIndex(call_ctx_, arg_name, index);
} }
const ArgVec& inputs() const override { return helper_->inputs(); } const ArgVec& inputs() const override { return helper_->inputs(); }
const ArgVec& outputs() const override { return helper_->outputs(); } const ArgVec& outputs() const override { return helper_->outputs(); }
const JobDesc* job_desc() const override { return helper_->job_desc(); } const JobDesc* job_desc() const override { return helper_->job_desc(); }
const ParallelContext& parallel_ctx() const override { return helper_->parallel_ctx(call_ctx_); } const ParallelContext& parallel_ctx() const override { return helper_->parallel_ctx(call_ctx_); }
const ParallelDesc& parallel_desc() const override { return helper_->parallel_desc(call_ctx_); } const ParallelDesc& parallel_desc() const override { return helper_->parallel_desc(call_ctx_); }
const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string& arg_name, const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string& arg_name,
int32_t index) const override { int32_t index) const override {
return helper_->SbpParallel4ArgNameAndIndex(call_ctx_, arg_name, index); return helper_->SbpParallel4ArgNameAndIndex(call_ctx_, arg_name, index);
} }
const NdSbp& NdSbp4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override { const NdSbp& NdSbp4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {
return helper_->NdSbp4ArgNameAndIndex(call_ctx_, arg_name, index); return helper_->NdSbp4ArgNameAndIndex(call_ctx_, arg_name, index);
} }
int64_t parallel_num() const override { return helper_->parallel_num(call_ctx_); } int64_t parallel_num() const override { return helper_->parallel_num(call_ctx_); }
const std::string& input(const std::string& arg_name, int32_t index) const override { const std::string& input(const std::string& arg_name, int32_t index) const override {
return helper_->input(arg_name, index); return helper_->input(arg_name, index);
} }
const std::string& output(const std::string& arg_name, int32_t index) const override { const std::string& output(const std::string& arg_name, int32_t index) const override {
return helper_->output(arg_name, index); return helper_->output(arg_name, index);
} }
bool has_input(const std::string& arg_name, int32_t index) const override { bool has_input(const std::string& arg_name, int32_t index) const override {
return helper_->has_input(arg_name, index); return helper_->has_input(arg_name, index);
} }
bool has_output(const std::string& arg_name, int32_t index) const override { bool has_output(const std::string& arg_name, int32_t index) const override {
return helper_->has_output(arg_name, index); return helper_->has_output(arg_name, index);
} }
int32_t input_size(const std::string& arg_name) const override { int32_t input_size(const std::string& arg_name) const override {
return helper_->input_size(arg_name); return helper_->input_size(arg_name);
} }
int32_t output_size(const std::string& arg_name) const override { int32_t output_size(const std::string& arg_name) const override {
return helper_->output_size(arg_name); return helper_->output_size(arg_name);
} }
const std::string& op_name() const override { return helper_->op_name(); } const std::string& op_name() const override { return helper_->op_name(); }
const std::string& op_type_name() const override { return helper_->op_type_name(); } const std::string& op_type_name() const override { return helper_->op_type_name(); }
const std::string& op_loc() const override { return helper_->op_loc(); } const std::string& op_loc() const override { return helper_->op_loc(); }
private: private:
const std::shared_ptr<const user_op::AttrVal>& Attr4Name( const std::shared_ptr<const user_op::AttrVal>& Attr4Name(
const std::string& attr_name) const override { const std::string& attr_name) const override {
return helper_->Attr4Name(call_ctx_, attr_name); return helper_->Attr4Name(call_ctx_, attr_name);
} }
const UserOpInferContextHelper* helper_; const UserOpInferContextHelper* helper_;
eager::CallContext* call_ctx_; eager::CallContext* call_ctx_;
}; };
class UserKernelComputeContextHelper final { class UserKernelComputeContextHelper final {
public: public:
UserKernelComputeContextHelper(DeviceType device_type, UserKernelComputeContextHelper(DeviceType device_type,
const user_op::UserOpConfWrapper* user_op_conf, const user_op::UserOpConfWrapper* user_op_conf,
const std::shared_ptr<const ArgTuple>& input_arg_tuple, const std::shared_ptr<const ArgTuple>& input_arg_tuple,
const std::shared_ptr<const ArgTuple>& output_arg_tuple) const std::shared_ptr<const ArgTuple>& output_arg_tuple)
: user_op_conf_(user_op_conf), : user_op_conf_(user_op_conf),
base_ctx_helper_(device_type, input_arg_tuple, output_arg_tuple) {} base_ctx_helper_(device_type, input_arg_tuple, output_arg_tuple) {}
~UserKernelComputeContextHelper() = default; ~UserKernelComputeContextHelper() = default;
const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
const std::string& arg_name, const std::string& arg_name,
int32_t index) const { int32_t index) const {
return base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); return base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);
} }
user_op::Tensor* Tensor4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name, user_op::Tensor* Tensor4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const { int32_t index) const {
return base_ctx_helper_.Tensor4ArgNameAndIndex(call_ctx, arg_name, index); return base_ctx_helper_.Tensor4ArgNameAndIndex(call_ctx, arg_name, index);
} }
ep::Stream* stream(DeviceCtx* device_ctx) const { ep::Stream* stream(DeviceCtx* device_ctx) const {
CHECK(device_ctx); CHECK(device_ctx);
return device_ctx->stream(); return device_ctx->stream();
} }
DeviceType device_type() const { return base_ctx_helper_.device_type(); } DeviceType device_type() const { return base_ctx_helper_.device_type(); }
const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const { const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const {
return base_ctx_helper_.parallel_ctx(call_ctx); return base_ctx_helper_.parallel_ctx(call_ctx);
} }
const ArgVec& inputs() const { return base_ctx_helper_.inputs(); } const ArgVec& inputs() const { return base_ctx_helper_.inputs(); }
const ArgVec& outputs() const { return base_ctx_helper_.outputs(); } const ArgVec& outputs() const { return base_ctx_helper_.outputs(); }
const user_op::UserOpConfWrapper& user_op_conf() const { return *user_op_conf_; } const user_op::UserOpConfWrapper& user_op_conf() const { return *user_op_conf_; }
const std::shared_ptr<const user_op::AttrVal>& Attr4Name(eager::CallContext* call_ctx, const std::shared_ptr<const user_op::AttrVal>& Attr4Name(eager::CallContext* call_ctx,
const std::string& attr_name) const { const std::string& attr_name) const {
return call_ctx->composed_attrs().Attr4Name(attr_name); return call_ctx->composed_attrs().Attr4Name(attr_name);
} }
private: private:
const user_op::UserOpConfWrapper* user_op_conf_; const user_op::UserOpConfWrapper* user_op_conf_;
UserKernelBaseContextHelper base_ctx_helper_; UserKernelBaseContextHelper base_ctx_helper_;
}; };
class UserKernelComputeContext final : public user_op::KernelComputeContext { class UserKernelComputeContext final : public user_op::KernelComputeContext {
public: public:
UserKernelComputeContext(const UserKernelComputeContextHelper* helper, UserKernelComputeContext(const UserKernelComputeContextHelper* helper,
eager::CallContext* call_ctx, DeviceCtx* device_ctx) eager::CallContext* call_ctx, DeviceCtx* device_ctx)
: helper_(helper), call_ctx_(call_ctx), device_ctx_(device_ctx) {} : helper_(helper), call_ctx_(call_ctx), device_ctx_(device_ctx) {}
~UserKernelComputeContext() = default; ~UserKernelComputeContext() = default;
const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,
int32_t index) const override { int32_t index) const override {
return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index); return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);
} }
user_op::Tensor* Tensor4ArgNameAndIndex(const std::string& arg_name, int32_t index) override { user_op::Tensor* Tensor4ArgNameAndIndex(const std::string& arg_name, int32_t index) override {
return helper_->Tensor4ArgNameAndIndex(call_ctx_, arg_name, index); return helper_->Tensor4ArgNameAndIndex(call_ctx_, arg_name, index);
} }
ep::Stream* stream() override { return helper_->stream(device_ctx_); } ep::Stream* stream() override { return helper_->stream(device_ctx_); }
DeviceType device_type() const override { return helper_->device_type(); } DeviceType device_type() const override { return helper_->device_type(); }
const ParallelContext& parallel_ctx() const override { return helper_->parallel_ctx(call_ctx_); } const ParallelContext& parallel_ctx() const override { return helper_->parallel_ctx(call_ctx_); }
const ArgVec& inputs() const override { return helper_->inputs(); } const ArgVec& inputs() const override { return helper_->inputs(); }
const ArgVec& outputs() const override { return helper_->outputs(); } const ArgVec& outputs() const override { return helper_->outputs(); }
private: private:
const user_op::UserOpConfWrapper& user_op_conf() const override { const user_op::UserOpConfWrapper& user_op_conf() const override {
return helper_->user_op_conf(); return helper_->user_op_conf();
} }
const std::shared_ptr<const user_op::AttrVal>& Attr4Name( const std::shared_ptr<const user_op::AttrVal>& Attr4Name(
const std::string& attr_name) const override { const std::string& attr_name) const override {
return helper_->Attr4Name(call_ctx_, attr_name); return helper_->Attr4Name(call_ctx_, attr_name);
} }
const UserKernelComputeContextHelper* helper_; const UserKernelComputeContextHelper* helper_;
eager::CallContext* call_ctx_; eager::CallContext* call_ctx_;
DeviceCtx* device_ctx_; DeviceCtx* device_ctx_;
}; };
class UserKernelRegContextHelper final { class UserKernelRegContextHelper final {
public: public:
UserKernelRegContextHelper(DeviceType device_type, const user_op::UserOpConfWrapper* user_op_conf, UserKernelRegContextHelper(DeviceType device_type, const user_op::UserOpConfWrapper* user_op_conf,
const std::shared_ptr<const ArgTuple>& input_arg_tuple, const std::shared_ptr<const ArgTuple>& input_arg_tuple,
const std::shared_ptr<const ArgTuple>& output_arg_tuple) const std::shared_ptr<const ArgTuple>& output_arg_tuple)
: user_op_conf_(user_op_conf), : user_op_conf_(user_op_conf),
base_ctx_helper_(device_type, input_arg_tuple, output_arg_tuple) {} base_ctx_helper_(device_type, input_arg_tuple, output_arg_tuple) {}
~UserKernelRegContextHelper() = default; ~UserKernelRegContextHelper() = default;
DeviceType device_type() const { return base_ctx_helper_.device_type(); } DeviceType device_type() const { return base_ctx_helper_.device_type(); }
const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const { const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const {
return base_ctx_helper_.parallel_ctx(call_ctx); return base_ctx_helper_.parallel_ctx(call_ctx);
} }
const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx, const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
const std::string& arg_name, const std::string& arg_name,
int32_t index) const { int32_t index) const {
return base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index); return base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);
} }
const ArgVec& inputs() const { return base_ctx_helper_.inputs(); } const ArgVec& inputs() const { return base_ctx_helper_.inputs(); }
const ArgVec& outputs() const { return base_ctx_helper_.outputs(); } const ArgVec& outputs() const { return base_ctx_helper_.outputs(); }
const user_op::UserOpConfWrapper& user_op_conf() const { return *user_op_conf_; } const user_op::UserOpConfWrapper& user_op_conf() const { return *user_op_conf_; }
const std::shared_ptr<const user_op::AttrVal>& Attr4Name(eager::CallContext* call_ctx, const std::shared_ptr<const user_op::AttrVal>& Attr4Name(eager::CallContext* call_ctx,
const std::string& attr_name) const { const std::string& attr_name) const {
return call_ctx->composed_attrs().Attr4Name(attr_name); return call_ctx->composed_attrs().Attr4Name(attr_name);
} }
private: private:
const user_op::UserOpConfWrapper* user_op_conf_; const user_op::UserOpConfWrapper* user_op_conf_;
UserKernelBaseContextHelper base_ctx_helper_; UserKernelBaseContextHelper base_ctx_helper_;
}; };
class UserKernelRegContext final : public user_op::KernelRegContext { class UserKernelRegContext final : public user_op::KernelRegContext {
public: public:
UserKernelRegContext(const UserKernelRegContextHelper* helper, eager::CallContext* call_ctx) UserKernelRegContext(const UserKernelRegContextHelper* helper, eager::CallContext* call_ctx)
: helper_(helper), call_ctx_(call_ctx) {} : helper_(helper), call_ctx_(call_ctx) {}
~UserKernelRegContext() = default; ~UserKernelRegContext() = default;
DeviceType device_type() const override { return helper_->device_type(); } DeviceType device_type() const override { return helper_->device_type(); }
const ParallelContext& parallel_ctx() const override { return helper_->parallel_ctx(call_ctx_); } const ParallelContext& parallel_ctx() const override { return helper_->parallel_ctx(call_ctx_); }
const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,
int32_t index) const override { int32_t index) const override {
return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index); return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);
} }
const ArgVec& inputs() const override { return helper_->inputs(); } const ArgVec& inputs() const override { return helper_->inputs(); }
const ArgVec& outputs() const override { return helper_->outputs(); } const ArgVec& outputs() const override { return helper_->outputs(); }
const user_op::UserOpConfWrapper& user_op_conf() const override { const user_op::UserOpConfWrapper& user_op_conf() const override {
return helper_->user_op_conf(); return helper_->user_op_conf();
} }
private: private:
const std::shared_ptr<const user_op::AttrVal>& Attr4Name( const std::shared_ptr<const user_op::AttrVal>& Attr4Name(
const std::string& attr_name) const override { const std::string& attr_name) const override {
return helper_->Attr4Name(call_ctx_, attr_name); return helper_->Attr4Name(call_ctx_, attr_name);
} }
const UserKernelRegContextHelper* helper_; const UserKernelRegContextHelper* helper_;
eager::CallContext* call_ctx_; eager::CallContext* call_ctx_;
}; };
// Stateless, per-op helper shared by every UserKernelInitAndCacheContext
// instance. All per-call state is passed in explicitly through the
// eager::CallContext* argument.
class UserKernelInitAndCacheContextHelper final {
 public:
  UserKernelInitAndCacheContextHelper(DeviceType device_type,
                                      const user_op::UserOpConfWrapper* user_op_conf,
                                      const std::shared_ptr<const ArgTuple>& input_arg_tuple,
                                      const std::shared_ptr<const ArgTuple>& output_arg_tuple)
      : user_op_conf_(user_op_conf),
        base_ctx_helper_(device_type, input_arg_tuple, output_arg_tuple) {}
  ~UserKernelInitAndCacheContextHelper() = default;
  // Stream of the given device context; device_ctx must not be null.
  ep::Stream* stream(DeviceCtx* device_ctx) const {
    CHECK(device_ctx);
    return device_ctx->stream();
  }
  DeviceType device_type() const { return base_ctx_helper_.device_type(); }
  const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const {
    return base_ctx_helper_.parallel_ctx(call_ctx);
  }
  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
                                                        const std::string& arg_name,
                                                        int32_t index) const {
    return base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  // The "logical" (global) tensor desc is taken from the consistent tensor meta.
  const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
                                                               const std::string& arg_name,
                                                               int32_t index) const {
    return base_ctx_helper_.ConsistentTensorMeta4ArgNameAndIndex(call_ctx, arg_name, index);
  }
  // Only valid when the argument's NdSbp is 1-dimensional; CHECK-fails otherwise.
  const SbpParallel& SbpParallel4ArgNameAndIndex(eager::CallContext* call_ctx,
                                                 const std::string& arg_name, int32_t index) const {
    const auto& nd_sbp = NdSbp4ArgNameAndIndex(call_ctx, arg_name, index);
    CHECK_EQ(nd_sbp.sbp_parallel_size(), 1);
    return nd_sbp.sbp_parallel(0);
  }
  // CHECK-fails if the argument has no consistent tensor meta.
  const NdSbp& NdSbp4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
                                     int32_t index) const {
    return *CHECK_NOTNULL(
                base_ctx_helper_.ConsistentTensorMeta4ArgNameAndIndex(call_ctx, arg_name, index))
                ->nd_sbp();
  }
  const ArgVec& inputs() const { return base_ctx_helper_.inputs(); }
  const ArgVec& outputs() const { return base_ctx_helper_.outputs(); }
  // Parallel desc of the call; CHECK-fails if absent.
  const ParallelDesc& parallel_desc(eager::CallContext* call_ctx) const {
    return *CHECK_JUST(base_ctx_helper_.parallel_desc(call_ctx));
  }
  // Attrs are resolved against the call's composed attrs.
  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(eager::CallContext* call_ctx,
                                                           const std::string& attr_name) const {
    return call_ctx->composed_attrs().Attr4Name(attr_name);
  }
  const user_op::UserOpConfWrapper& user_op_conf() const { return *user_op_conf_; }

 private:
  const user_op::UserOpConfWrapper* user_op_conf_;  // not owned
  UserKernelBaseContextHelper base_ctx_helper_;
};
// Serves as BOTH the KernelInitContext (for CreateOpKernelState) and the
// KernelCacheContext (for InitOpKernelCacheWithFlags); the two interfaces
// share the same helper/call-context/device-context state.
class UserKernelInitAndCacheContext final : public user_op::KernelInitContext,
                                            public user_op::KernelCacheContext {
 public:
  UserKernelInitAndCacheContext(const UserKernelInitAndCacheContextHelper* helper,
                                eager::CallContext* call_ctx, DeviceCtx* device_ctx)
      : helper_(helper), call_ctx_(call_ctx), device_ctx_(device_ctx) {}
  ~UserKernelInitAndCacheContext() override = default;
  ep::Stream* stream() override { return helper_->stream(device_ctx_); }
  DeviceType device_type() const override { return helper_->device_type(); }
  const ParallelContext& parallel_ctx() const override { return helper_->parallel_ctx(call_ctx_); }
  const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,
                                                        int32_t index) const override {
    return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& arg_name,
                                                               int32_t index) const override {
    return helper_->LogicalTensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  // Delegates; CHECK-fails in the helper when the NdSbp is not 1-D.
  const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string& arg_name,
                                                 int32_t index) const override {
    return helper_->SbpParallel4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  const NdSbp& NdSbp4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {
    return helper_->NdSbp4ArgNameAndIndex(call_ctx_, arg_name, index);
  }
  const ArgVec& inputs() const override { return helper_->inputs(); }
  const ArgVec& outputs() const override { return helper_->outputs(); }
  const ParallelDesc& parallel_desc() const override { return helper_->parallel_desc(call_ctx_); }

 private:
  const std::shared_ptr<const user_op::AttrVal>& Attr4Name(
      const std::string& attr_name) const override {
    return helper_->Attr4Name(call_ctx_, attr_name);
  }
  const user_op::UserOpConfWrapper& user_op_conf() const override {
    return helper_->user_op_conf();
  }
  // All pointers below are non-owning.
  const UserKernelInitAndCacheContextHelper* helper_;
  eager::CallContext* call_ctx_;
  DeviceCtx* device_ctx_;
};
namespace {
// Partitions the op's indexed inputs/outputs into the four index lists used by
// the kernel launcher:
//   - const ibns: inputs whose arg modifier is not mutable (read-only);
//   - mut ibns:   inputs whose arg modifier is mutable (may be written in place);
//   - mut obns:   outputs with header_infered_before_compute() == true;
//   - mut2 obns:  outputs whose header is only known after Compute().
// The classification is obtained by running the op's registered input/output
// arg-modifier functions over a default-initialized ArgModifierSignature.
// Returns an error if the op type is not registered or a modifier fn fails.
Maybe<void> InitTensorTupleIndexes4Bns(const std::shared_ptr<const OperatorConf>& op_conf,
                                       const ArgVec& indexed_input_pairs,
                                       const ArgVec& indexed_output_pairs,
                                       std::vector<int64_t>* input_tuple_indexes4const_ibns,
                                       std::vector<int64_t>* input_tuple_indexes4mut_ibns,
                                       std::vector<int64_t>* output_tuple_indexes4mut_obns,
                                       std::vector<int64_t>* output_tuple_indexes4mut2_obns) {
  const auto* op_reg_val =
      user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_conf->user_conf().op_type_name());
  CHECK_NOTNULL_OR_RETURN(op_reg_val);
  // Seed the signature with a default modifier for every input/output blob name.
  ArgModifierSignature arg_modifier_signature;
  for (const auto& pair : indexed_input_pairs) {
    const std::string ibn = GenRepeatedBn(pair.first, pair.second);
    arg_modifier_signature.mutable_ibn2input_blob_modifier()->insert(
        {ibn, user_op::InputArgModifier()});
  }
  for (const auto& pair : indexed_output_pairs) {
    const std::string obn = GenRepeatedBn(pair.first, pair.second);
    arg_modifier_signature.mutable_obn2output_blob_modifier()->insert(
        {obn, user_op::OutputArgModifier()});
  }
  user_op::UserOpConfWrapper op_conf_wrapper(op_conf);
  // Let the op registration customize the input modifiers.
  if (op_reg_val->input_arg_modify_fn) {
    user_op::GetInputArgModifier GetInputArgModifierFn =
        [&arg_modifier_signature](const std::string& in_arg_name,
                                  int32_t in_arg_index) -> user_op::InputArgModifier* {
      const std::string ibn = GenRepeatedBn(in_arg_name, in_arg_index);
      auto* map = arg_modifier_signature.mutable_ibn2input_blob_modifier();
      return &map->at(ibn);
    };
    JUST(op_reg_val->input_arg_modify_fn(GetInputArgModifierFn, op_conf_wrapper));
  }
  // Let the op registration customize the output modifiers.
  if (op_reg_val->output_arg_modify_fn) {
    user_op::GetOutputArgModifier GetOutputArgModifierFn =
        [&arg_modifier_signature](const std::string& out_arg_name,
                                  int32_t out_arg_index) -> user_op::OutputArgModifier* {
      const std::string obn = GenRepeatedBn(out_arg_name, out_arg_index);
      auto* map = arg_modifier_signature.mutable_obn2output_blob_modifier();
      return &map->at(obn);
    };
    JUST(op_reg_val->output_arg_modify_fn(GetOutputArgModifierFn, op_conf_wrapper));
  }
  // size_t index avoids the signed/unsigned comparison of `int i < .size()`.
  for (size_t i = 0; i < indexed_input_pairs.size(); i++) {
    const auto& pair = indexed_input_pairs.at(i);
    const std::string ibn = GenRepeatedBn(pair.first, pair.second);
    if (arg_modifier_signature.ibn2input_blob_modifier().at(ibn).is_mutable()) {
      input_tuple_indexes4mut_ibns->emplace_back(i);
    } else {
      input_tuple_indexes4const_ibns->emplace_back(i);
    }
  }
  for (size_t i = 0; i < indexed_output_pairs.size(); i++) {
    const auto& pair = indexed_output_pairs.at(i);
    const std::string obn = GenRepeatedBn(pair.first, pair.second);
    if (arg_modifier_signature.obn2output_blob_modifier().at(obn).header_infered_before_compute()) {
      output_tuple_indexes4mut_obns->emplace_back(i);
    } else {
      output_tuple_indexes4mut2_obns->emplace_back(i);
    }
  }
  return Maybe<void>::Ok();
}
}  // namespace
// Factory for a StatefulOpKernel. Pre-builds every per-op context helper once
// so that the per-call context objects created later are cheap. Returns
// UnimplementedError for ops without a logical tensor-desc infer fn.
/* static */ Maybe<StatefulOpKernel> StatefulOpKernel::New(
    const std::shared_ptr<OperatorConf>& op_conf, const Symbol<Stream>& stream,
    const AttrMap& base_attrs, const std::shared_ptr<const ParallelDesc>& parallel_desc,
    const std::shared_ptr<const ArgTuple>& input_arg_tuple,
    const std::shared_ptr<const ArgTuple>& output_arg_tuple) {
  auto opkernel = std::shared_ptr<StatefulOpKernel>(new StatefulOpKernel());
  opkernel->base_attrs_ = base_attrs;
  opkernel->op_conf_ = op_conf;
  opkernel->user_op_conf_.reset(new user_op::UserOpConfWrapper(op_conf));
  opkernel->stream_ = stream;
  opkernel->input_arg_tuple_ = input_arg_tuple;
  opkernel->output_arg_tuple_ = output_arg_tuple;
  opkernel->need_check_mem_case_ = true;
  const DeviceType device_type = CHECK_JUST(DeviceType4DeviceTag(op_conf->device_tag()));
  const user_op::UserOpConfWrapper* user_op_conf = opkernel->user_op_conf_.get();
  // One helper per context flavor: shape/dtype inference, kernel state/cache
  // initialization, compute, and kernel-registry matching.
  opkernel->op_infer_ctx_helper_.reset(
      new UserOpInferContextHelper(user_op_conf, input_arg_tuple, output_arg_tuple));
  opkernel->init_and_cache_ctx_helper_.reset(new UserKernelInitAndCacheContextHelper(
      device_type, opkernel->user_op_conf_.get(), opkernel->input_arg_tuple_,
      opkernel->output_arg_tuple_));
  opkernel->compute_ctx_helper_.reset(new UserKernelComputeContextHelper(
      device_type, user_op_conf, input_arg_tuple, output_arg_tuple));
  opkernel->reg_ctx_helper_.reset(
      new UserKernelRegContextHelper(device_type, user_op_conf, input_arg_tuple, output_arg_tuple));
  const auto* op_reg_val =
      user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(user_op_conf->op_type_name());
  CHECK_NOTNULL_OR_RETURN(op_reg_val);
  // A logical tensor-desc infer fn is mandatory here; bail out otherwise.
  if (op_reg_val->logical_tensor_desc_infer_fn) {
    opkernel->tensor_desc_infer_fn_ = op_reg_val->logical_tensor_desc_infer_fn;
  } else {
    return Error::UnimplementedError();
  }
  opkernel->data_type_infer_fn_ = op_reg_val->data_type_infer_fn;
  // Classify inputs/outputs (const vs mutable; header inferred before vs after
  // compute) into the four index lists.
  JUST(InitTensorTupleIndexes4Bns(
      op_conf, input_arg_tuple->indexed_arg_name_and_index(),
      output_arg_tuple->indexed_arg_name_and_index(), &opkernel->input_tuple_indexes4const_ibns_,
      &opkernel->input_tuple_indexes4mut_ibns_, &opkernel->output_tuple_indexes4mut_obns_,
      &opkernel->output_tuple_indexes4mut2_obns_));
  return opkernel;
}
// Out-of-line defaulted dtor (NOTE(review): presumably so smart-pointer members
// of types only declared in the header can be destroyed here — confirm).
StatefulOpKernel::~StatefulOpKernel() = default;
size_t StatefulOpKernel::InferTmpSize(eager::CallContext* call_ctx, size_t StatefulOpKernel::InferTmpSize(eager::CallContext* call_ctx,
const user_op::OpKernel* user_opkernel) const { const user_op::OpKernel* user_opkernel) const {
UserOpInferContext op_infer_ctx(op_infer_ctx_helper_.get(), call_ctx); UserOpInferContext op_infer_ctx(op_infer_ctx_helper_.get(), call_ctx);
const auto& InferTmpSizeFn = GetInferTmpSizeFn(user_opkernel); const auto& InferTmpSizeFn = GetInferTmpSizeFn(user_opkernel);
return InferTmpSizeFn(&op_infer_ctx); return InferTmpSizeFn(&op_infer_ctx);
} }
// Selects (and lazily creates + caches) the user_op::OpKernel matching this
// call. Outputs via *user_opkernel / *need_temp_storage.
Maybe<void> StatefulOpKernel::ChooseOpKernel(eager::CallContext* call_ctx,
                                             const user_op::OpKernel** user_opkernel,
                                             bool* need_temp_storage) {
  OF_PROFILER_RANGE_GUARD("ChooseOpKernel");
  // Cached kernels are bucketed by the dtype of the first input (or, failing
  // that, the first output); ops with neither fall into the kInvalidDataType
  // bucket.
  DataType primary_dtype = kInvalidDataType;
  const auto& inputs = call_ctx->inputs();
  const auto& outputs = call_ctx->outputs();
  if (likely(!inputs->empty())) {
    primary_dtype = (*inputs)[0]->data_type();
  } else if (likely(!outputs->empty())) {
    primary_dtype = (*outputs)[0]->data_type();
  } else {
    // do nothing
  }
  UserKernelRegContext reg_ctx(reg_ctx_helper_.get(), call_ctx);
  // Fast path: reuse a previously created kernel whose registration matches
  // this call's reg context.
  for (const auto& pair : dtype2cached_kernels_[primary_dtype]) {
    if (likely(pair.first->is_matched_hob->get(reg_ctx))) {
      *need_temp_storage = pair.first->need_temp_storage;
      *user_opkernel = pair.second.get();
      return Maybe<void>::Ok();
    }
  }
  // Slow path: registry lookup, create the kernel, and cache it (plus its
  // InferTmpSize fn) for subsequent calls with the same primary dtype.
  OF_PROFILER_RANGE_GUARD("fallback");
  const auto& op_type_name = user_op_conf_->op_type_name();
  const auto* kernel_reg_val =
      JUST(user_op::UserOpRegistryMgr::Get().GetOpKernelRegistryResult(op_type_name, reg_ctx));
  CHECK_NOTNULL(kernel_reg_val);
  auto* kernel = kernel_reg_val->create_fn();
  dtype2cached_kernels_[primary_dtype].push_back(
      {kernel_reg_val, std::shared_ptr<const user_op::OpKernel>(kernel)});
  infer_tmp_size_fn_map_.emplace(kernel, &kernel_reg_val->infer_tmp_size_fn);
  *need_temp_storage = kernel_reg_val->need_temp_storage;
  *user_opkernel = kernel;
  return Maybe<void>::Ok();
}
void StatefulOpKernel::TryInitOpKernelStateAndCache(eager::CallContext* call_ctx, void StatefulOpKernel::TryInitOpKernelStateAndCache(eager::CallContext* call_ctx,
DeviceCtx* device_ctx, DeviceCtx* device_ctx,
const user_op::OpKernel* op_kernel, const user_op::OpKernel* op_kernel,
user_op::OpKernelState** state, user_op::OpKernelState** state,
user_op::OpKernelCache** cache) { user_op::OpKernelCache** cache) {
UserKernelInitAndCacheContext init_and_cache_ctx(init_and_cache_ctx_helper_.get(), call_ctx, UserKernelInitAndCacheContext init_and_cache_ctx(init_and_cache_ctx_helper_.get(), call_ctx,
device_ctx); device_ctx);
if (state != nullptr) { if (state != nullptr) {
auto it = op_kernel_state_map_.find(op_kernel); auto it = op_kernel_state_map_.find(op_kernel);
if (it != op_kernel_state_map_.end()) { if (it != op_kernel_state_map_.end()) {
*state = it->second.get(); *state = it->second.get();
} else { } else {
auto created_state = op_kernel->CreateOpKernelState(&init_and_cache_ctx); auto created_state = op_kernel->CreateOpKernelState(&init_and_cache_ctx);
op_kernel_state_map_.emplace(op_kernel, created_state); op_kernel_state_map_.emplace(op_kernel, created_state);
*state = created_state.get(); *state = created_state.get();
} }
} }
{ {
auto& cache_in_map = op_kernel_cache_map_[op_kernel]; auto& cache_in_map = op_kernel_cache_map_[op_kernel];
op_kernel->InitOpKernelCacheWithFlags(&init_and_cache_ctx, op_kernel->InitOpKernelCacheWithFlags(&init_and_cache_ctx,
user_op::OpKernelCache::kAllMayChanged, &cache_in_map); user_op::OpKernelCache::kAllMayChanged, &cache_in_map);
*cache = cache_in_map.get(); *cache = cache_in_map.get();
} }
} }
// Looks up the temp-buffer size-inference function registered for `op_kernel`.
// The kernel must already have been cached (the map entry is inserted when the
// kernel is chosen); std::map::at throws std::out_of_range otherwise.
const user_op::InferTmpSizeFn& StatefulOpKernel::GetInferTmpSizeFn(
    const user_op::OpKernel* op_kernel) const {
  return *infer_tmp_size_fn_map_.at(op_kernel);
}
// Accessor: the tensor-desc inference callback configured for this op.
user_op::TensorDescInferFn StatefulOpKernel::TensorDescInferFn() const {
  return tensor_desc_infer_fn_;
}
// Accessor: the data-type inference callback configured for this op.
user_op::DataTypeInferFn StatefulOpKernel::DataTypeInferFn() const {
  return data_type_infer_fn_;
}
void StatefulOpKernel::Compute(eager::CallContext* call_ctx, DeviceCtx* device_ctx, void StatefulOpKernel::Compute(eager::CallContext* call_ctx, DeviceCtx* device_ctx,
const user_op::OpKernel* user_opkernel, const user_op::OpKernel* user_opkernel,
user_op::OpKernelState* state, user_op::OpKernelState* state,
const user_op::OpKernelCache* cache) const { const user_op::OpKernelCache* cache) const {
UserKernelComputeContext compute_context(compute_ctx_helper_.get(), call_ctx, device_ctx); UserKernelComputeContext compute_context(compute_ctx_helper_.get(), call_ctx, device_ctx);
auto* compute_ctx = &compute_context; auto* compute_ctx = &compute_context;
OF_PROFILER_RANGE_GUARD("Compute"); OF_PROFILER_RANGE_GUARD("Compute");
if (Singleton<profiler::ProfileManager>::Get()) { if (Singleton<profiler::ProfileManager>::Get()) {
#if defined(WITH_CUDA) #if defined(WITH_CUDA) || defined(WITH_ROCM)
const auto CalMemorySize = [compute_ctx](const one::ArgVec& args) -> int64_t { const auto CalMemorySize = [compute_ctx](const one::ArgVec& args) -> int64_t {
const auto Func = [compute_ctx](int64_t mem_size, const auto& pair) { const auto Func = [compute_ctx](int64_t mem_size, const auto& pair) {
const auto tensor = compute_ctx->Tensor4ArgNameAndIndex(pair.first, pair.second); const auto tensor = compute_ctx->Tensor4ArgNameAndIndex(pair.first, pair.second);
return mem_size + tensor->shape_view().elem_cnt() * GetSizeOfDataType(tensor->data_type()); return mem_size + tensor->shape_view().elem_cnt() * GetSizeOfDataType(tensor->data_type());
}; };
return std::accumulate(args.begin(), args.end(), static_cast<int64_t>(0), Func); return std::accumulate(args.begin(), args.end(), static_cast<int64_t>(0), Func);
}; };
#endif #endif
auto er_guard = CHECK_JUST(profiler::EventRecorder::CreateKernelEventRecorder( auto er_guard = CHECK_JUST(profiler::EventRecorder::CreateKernelEventRecorder(
op_type_name(), op_type_name(),
#if defined(WITH_CUDA) #if defined(WITH_CUDA) || defined(WITH_ROCM)
[compute_ctx, CalMemorySize]() -> int64_t { [compute_ctx, CalMemorySize]() -> int64_t {
return CalMemorySize(compute_ctx->inputs()) + CalMemorySize(compute_ctx->outputs()); return CalMemorySize(compute_ctx->inputs()) + CalMemorySize(compute_ctx->outputs());
}, },
#endif #endif
[compute_ctx]() -> std::vector<ShapeView> { [compute_ctx]() -> std::vector<Shape> {
std::vector<ShapeView> shapes; std::vector<Shape> shapes;
for (const auto& pair : compute_ctx->inputs()) { for (const auto& pair : compute_ctx->inputs()) {
shapes.emplace_back( shapes.emplace_back(
compute_ctx->TensorDesc4ArgNameAndIndex(pair.first, pair.second)->shape()); compute_ctx->TensorDesc4ArgNameAndIndex(pair.first, pair.second)->shape());
} }
return shapes; return shapes;
})); }));
user_opkernel->Compute(compute_ctx, state, cache); user_opkernel->Compute(compute_ctx, state, cache);
} else { } else {
user_opkernel->Compute(compute_ctx, state, cache); user_opkernel->Compute(compute_ctx, state, cache);
} }
} }
} // namespace one } // namespace one
} // namespace oneflow } // namespace oneflow
import numpy as np
import oneflow as flow
def fused_dot_feature_interaction(x,
                                  y,
                                  self_interaction=False,
                                  output_padding=0,
                                  output_concat=None,
                                  dtype=flow.float32
                                  ):
    """Reference (non-fused) dot feature interaction, DLRM-style.

    Args:
        x: dense feature tensor of shape (bs, es).
        y: embedded sparse feature tensor of shape (bs, dims, es).
        self_interaction: when True, include the diagonal (i == j) pairs.
        output_padding: number of zero columns appended to the result.
        output_concat: optional tensor concatenated before the interactions.
        dtype: dtype of the interaction part of the output.

    Returns:
        Tensor of shape (bs, num_pairs [+ output_concat cols] [+ output_padding]).
    """
    (bs, dims, es) = y.shape
    offset = 1 if self_interaction else 0
    # Lower-triangular (i, j) index pairs over the (dims + 1) stacked rows.
    # Keep the index tensors on the input's device so the advanced indexing
    # below does not rely on implicit host-to-device transfer.
    li = flow.tensor(
        [i for i in range(dims + 1) for j in range(i + offset)], device=x.device
    )
    lj = flow.tensor(
        [j for i in range(dims + 1) for j in range(i + offset)], device=x.device
    )
    T = flow.cat(
        [
            flow.reshape(x, (bs, 1, es)),
            y,
        ],
        dim=1,
    )
    Z = flow.matmul(T, T, transpose_b=True)
    # gather_nd does not support half precision, so gather in float32.
    Z = flow.cast(Z, flow.float32)
    Zflat = Z[:, li, lj]
    Zflat = flow.cast(Zflat, dtype)
    if output_concat is not None:
        R = flow.cat([output_concat, Zflat], dim=1)
    else:
        R = Zflat
    if output_padding != 0:
        # Fix: follow the result's device and the requested dtype instead of
        # hard-coding CUDA/float32, so CPU inputs and half precision also work.
        padding_tensor = flow.zeros((bs, output_padding), dtype=dtype, device=R.device)
        R = flow.cat([R, padding_tensor], dim=1)
    return R
""" """
Copyright 2020 The OneFlow Authors. All rights reserved. Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
import os import os
import unittest import unittest
import oneflow.unittest import oneflow.unittest
import oneflow as flow import oneflow as flow
import oneflow.nn as nn import oneflow.nn as nn
import oneflow.nn.functional as F import oneflow.nn.functional as F
import oneflow.profiler import oneflow.profiler
from oneflow.profiler.events import CustomEvent, KernelEvent from oneflow.profiler.events import CustomEvent, KernelEvent
class LeNet(nn.Module):
    """LeNet-5-style CNN for 3x32x32 inputs with 10 output classes."""

    def __init__(self):
        super(LeNet, self).__init__()
        # Two conv stages followed by three fully-connected layers.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # conv -> relu -> 2x2 max-pool, twice.
        features = F.max_pool2d(F.relu(self.conv1(x)), 2)
        features = F.max_pool2d(F.relu(self.conv2(features)), 2)
        # Flatten all but the batch dimension for the classifier head.
        flat = features.view(features.size(0), -1)
        hidden = F.relu(self.fc2(F.relu(self.fc1(flat))))
        return self.fc3(hidden)
def get_event(events, name: str, input_shapes: str = "-"):
    """Return the first event in ``events`` matching ``name``.

    CustomEvents match on name alone; KernelEvents must also match
    ``input_shapes``. Returns None when nothing matches.
    """
    for event in events:
        if isinstance(event, CustomEvent) and event.name == name:
            return event
        if (
            isinstance(event, KernelEvent)
            and event.name == name
            and event.input_shapes == input_shapes
        ):
            return event
    return None
def _test_lenet(
    test_case,
    on_cuda: bool,
    record_shapes: bool,
    record_bandwidth_for_cuda: bool = False,
):
    """Profile a LeNet forward+backward pass and check the recorded events.

    Verifies that the conv2d and relu_grad kernel events (and the two custom
    record_function ranges) are present with plausible timings and counts.
    """

    def check_event(event, expected_count):
        # Shared assertions for one kernel event, preserving the original
        # assertion order: presence, cpu times, (cuda times), count, bandwidth.
        test_case.assertIsNotNone(event)
        test_case.assertGreater(event.cpu_time, 0.0)
        test_case.assertGreater(event.cpu_time_total, 0.0)
        if on_cuda:
            test_case.assertGreater(event.cuda_time, 0.0)
            test_case.assertGreater(event.cuda_time_total, 0.0)
        test_case.assertEqual(event.count, expected_count)
        if record_bandwidth_for_cuda and on_cuda:
            test_case.assertNotEqual(event.bandwidth, -1)

    x = flow.randn(2, 3, 32, 32)
    lenet = LeNet()
    if on_cuda:
        x = x.to("cuda")
        lenet.to("cuda")
    activities = [oneflow.profiler.ProfilerActivity.CPU]
    if on_cuda:
        activities.append(oneflow.profiler.ProfilerActivity.CUDA)
    with oneflow.profiler.profile(
        activities=activities,
        record_shapes=record_shapes,
        record_bandwidth_for_cuda=record_bandwidth_for_cuda,
    ) as prof:
        with oneflow.profiler.record_function("lenet_forward_total_time") as f:
            for _ in range(2):
                eager_res = lenet(x)
        with oneflow.profiler.record_function("lenet_backward_total_time") as f:
            eager_res.sum().backward()
    events = prof.key_averages(group_by_input_shape=True)
    print(events)
    # Without shape grouping the two forward passes (and backward kernels)
    # collapse into one entry with a larger count.
    check_event(
        get_event(events, "conv2d", "[(2,3,32,32), (6,3,5,5)]" if record_shapes else "-"),
        2 if record_shapes else 4,
    )
    check_event(
        get_event(events, "relu_grad", "[(2,6,28,28), (2,6,28,28)]" if record_shapes else "-"),
        1 if record_shapes else 4,
    )
    test_case.assertIsNotNone(get_event(events, "lenet_forward_total_time"))
    test_case.assertIsNotNone(get_event(events, "lenet_backward_total_time"))
class TestProfileLenet(flow.unittest.TestCase):
    """Exercises the profiler on LeNet across device/option combinations."""

    def test_lenet_cpu(test_case):
        for shapes in (True, False):
            _test_lenet(test_case, on_cuda=False, record_shapes=shapes)

    @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
    def test_lenet_cuda(test_case):
        # Same combination order as before: bandwidth off then on, and for
        # each, record_shapes on then off.
        for bandwidth in (False, True):
            for shapes in (True, False):
                _test_lenet(
                    test_case,
                    on_cuda=True,
                    record_shapes=shapes,
                    record_bandwidth_for_cuda=bandwidth,
                )
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment