Commit f262efc9 authored by yuguo's avatar yuguo
Browse files

Surpport profiler for DCU, surpport debug compiler

parent 3f56062c
......@@ -265,9 +265,9 @@ set(ROBIN_HOOD_HASHING_URL
use_mirror(VARIABLE ROBIN_HOOD_HASHING_URL URL ${ROBIN_HOOD_HASHING_URL})
set(ROBIN_HOOD_HASHING_MD5 a78bd30a7582f25984f8592652836467)
set(FMT_URL https://github.com/fmtlib/fmt/archive/48b7e3dafb27ece02cd6addc8bd1041c79d59c2c.zip)
set(FMT_URL https://github.com/fmtlib/fmt/archive/fc07217d85e6dcec52878807d6bbd89a9d9156a5.zip)
use_mirror(VARIABLE FMT_URL URL ${FMT_URL})
set(FMT_MD5 45925a979ed7195e0c88a70be691de09)
set(FMT_MD5 7d9bb2ececc9ede29cd35bdc42a7e22c)
set(KINETO_URL
https://github.com/pytorch/kineto/archive/ff8dba20499a660650632952be76450bd70a52a6.zip)
......
......@@ -175,6 +175,8 @@ if (BUILD_ROCM)
add_definitions(-D__HIP_PLATFORM_HCC__)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__HIP_PLATFORM_HCC__ --gpu-max-threads-per-block=1024")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -D__HIP_PLATFORM_HCC__ --gpu-max-threads-per-block=1024")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -mcmodel=large")
set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -mcmodel=large")
list(APPEND oneflow_third_party_libs hip::device)
list(APPEND oneflow_third_party_libs roc::hipblas)
list(APPEND oneflow_third_party_libs hip::hipcub)
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// #include "fmt/core.h"
// #include "fmt/format.h"
#include "oneflow/core/profiler/event.h"
#include "oneflow/core/profiler/util.h"
using json = nlohmann::json;
namespace oneflow {
namespace profiler {
nlohmann::json IEvent::ToJson() {
return json{{"name", name_}, {"time", GetDuration<double>()}, {"input_shapes", "-"}};
}
void IEvent::SetStartedAt(double t) { started_at_ = t; }
void IEvent::SetFinishedAt(double t) { finished_at_ = t; }
void IEvent::Start() { SetStartedAt(GetTimeNow()); }
void IEvent::Finish() { SetFinishedAt(GetTimeNow()); }
bool IEvent::IsChildOf(const IEvent* e) {
if (!e) { return false; }
if (this == e) { return false; }
return GetStartedAt<double>() >= e->GetStartedAt<double>()
&& GetFinishedAt<double>() <= e->GetFinishedAt<double>();
}
const std::string& IEvent::GetName() const { return name_; }
std::string CustomEvent::Key() { return name_; }
nlohmann::json CustomEvent::ToJson() {
auto j = IEvent::ToJson();
j["type"] = EventType::kCustom;
j["custom_type"] = type_;
return j;
}
std::shared_ptr<CustomEvent> CustomEvent::Create(const std::string& name, CustomEventType type) {
return std::shared_ptr<CustomEvent>(new CustomEvent(name, type));
}
// std::string KernelEvent::Key() { return fmt::format("{}.{}", name_, GetFormatedInputShapes()); }
std::string KernelEvent::Key() { return "yuguo"; }
nlohmann::json KernelEvent::ToJson() {
auto j = IEvent::ToJson();
j["type"] = EventType::kOneflowKernel;
j["input_shapes"] = GetFormatedInputShapes();
#if defined(WITH_CUDA)
j["memory_size"] = memory_size_;
if (!children_.empty()) { j["children"] = children_; }
#endif // WITH_CUDA
return j;
}
std::shared_ptr<KernelEvent> KernelEvent::Create(
const std::string& name, const std::function<std::vector<ShapeView>(void)>& shape_getter) {
return std::shared_ptr<KernelEvent>(new KernelEvent(name, shape_getter));
}
void KernelEvent::RecordShape(const ShapeView& shape) { input_shapes_.emplace_back(shape); }
std::string KernelEvent::GetFormatedInputShapes(size_t max_num_to_format) {
if (input_shapes_.size() == 0) { return "-"; }
std::vector<std::string> shapes_formated(std::min(input_shapes_.size(), max_num_to_format));
for (auto i = 0; i < shapes_formated.size(); ++i) {
const std::string current_shape = input_shapes_[i].ToString();
shapes_formated[i] = current_shape == "()" ? "scalar" : current_shape;
}
if (input_shapes_.size() > max_num_to_format) { shapes_formated.emplace_back("..."); }
// return fmt::format("[{}]", fmt::join(shapes_formated, ", "));
return "yuguo";
}
} // namespace profiler
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "fmt/core.h"
#include "fmt/format.h"
#include "oneflow/core/profiler/event.h"
#include "oneflow/core/profiler/util.h"
using json = nlohmann::json;
namespace oneflow {
namespace profiler {
nlohmann::json IEvent::ToJson() {
return json{{"name", name_}, {"time", GetDuration<double>()}, {"input_shapes", "-"}};
}
void IEvent::SetStartedAt(double t) { started_at_ = t; }
void IEvent::SetFinishedAt(double t) { finished_at_ = t; }
void IEvent::Start() { SetStartedAt(GetTimeNow()); }
void IEvent::Finish() { SetFinishedAt(GetTimeNow()); }
bool IEvent::IsChildOf(const IEvent* e) {
if (!e) { return false; }
if (this == e) { return false; }
return GetStartedAt<double>() >= e->GetStartedAt<double>()
&& GetFinishedAt<double>() <= e->GetFinishedAt<double>();
}
const std::string& IEvent::GetName() const { return name_; }
std::string CustomEvent::Key() { return name_; }
nlohmann::json CustomEvent::ToJson() {
auto j = IEvent::ToJson();
j["type"] = EventType::kCustom;
j["custom_type"] = type_;
return j;
}
std::shared_ptr<CustomEvent> CustomEvent::Create(const std::string& name, CustomEventType type) {
return std::shared_ptr<CustomEvent>(new CustomEvent(name, type));
}
std::string KernelEvent::Key() { return fmt::format("{}.{}", name_, GetFormatedInputShapes()); }
nlohmann::json KernelEvent::ToJson() {
auto j = IEvent::ToJson();
j["type"] = EventType::kOneflowKernel;
j["input_shapes"] = GetFormatedInputShapes();
#if defined(WITH_CUDA) || defined(WITH_ROCM)
j["memory_size"] = memory_size_;
if (!children_.empty()) { j["children"] = children_; }
#endif // WITH_CUDA
return j;
}
std::shared_ptr<KernelEvent> KernelEvent::Create(
const std::string& name, const std::function<std::vector<Shape>(void)>& shape_getter) {
return std::shared_ptr<KernelEvent>(new KernelEvent(name, shape_getter));
}
std::string KernelEvent::GetFormatedInputShapes(size_t max_num_to_format) {
if (input_shapes_.size() == 0) { return "-"; }
std::vector<std::string> shapes_formated(std::min(input_shapes_.size(), max_num_to_format));
for (auto i = 0; i < shapes_formated.size(); ++i) {
const std::string current_shape = input_shapes_[i].ToString();
shapes_formated[i] = current_shape == "()" ? "scalar" : current_shape;
}
if (input_shapes_.size() > max_num_to_format) { shapes_formated.emplace_back("..."); }
return fmt::format("[{}]", fmt::join(shapes_formated, ", "));
}
} // namespace profiler
} // namespace oneflow
\ No newline at end of file
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_PROFILER_EVENT_H_
#define ONEFLOW_CORE_PROFILER_EVENT_H_
#include <functional>
#include <memory>
#include <vector>
#include "nlohmann/json.hpp"
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/shape_view.h"
namespace oneflow {
namespace profiler {
class ProfileManager;
enum class EventType {
kCustom, // has three kinds
kOneflowKernel // OneFlow cpu/cuda kernel
};
enum class CustomEventType {
kDefault, // for record_function
kCudaKernel, // cuda kernel
kCudaRuntime // something like cudaLaunchKernel
};
enum class EventTimeUnit { kNS, kUS };
class IEvent {
public:
OF_DISALLOW_COPY_AND_MOVE(IEvent);
IEvent() = delete;
IEvent(const std::string& name, EventTimeUnit time_unit) : name_(name), time_unit_(time_unit) {}
virtual std::string Key() = 0;
virtual nlohmann::json ToJson();
virtual ~IEvent() = default;
virtual void Start();
virtual void Finish();
bool IsChildOf(const IEvent* e);
const std::string& GetName() const;
template<typename T>
const T GetDuration(EventTimeUnit time_unit = EventTimeUnit::kUS) const;
template<typename T>
const T GetStartedAt(EventTimeUnit time_unit = EventTimeUnit::kUS) const;
template<typename T>
const T GetFinishedAt(EventTimeUnit time_unit = EventTimeUnit::kUS) const;
protected:
virtual void SetStartedAt(double t);
virtual void SetFinishedAt(double t);
std::string name_;
EventTimeUnit time_unit_;
double started_at_ = 0;
double finished_at_ = 0;
};
inline double ConvertTime(double time_, EventTimeUnit src_time_unit, EventTimeUnit dst_time_unit) {
if (src_time_unit == EventTimeUnit::kNS && dst_time_unit == EventTimeUnit::kUS) {
return time_ / 1000;
}
if (src_time_unit == EventTimeUnit::kUS && dst_time_unit == EventTimeUnit::kNS) {
return time_ * 1000;
}
return time_;
}
template<>
const inline double IEvent::GetStartedAt<double>(EventTimeUnit time_unit) const {
return ConvertTime(started_at_, time_unit_, time_unit);
}
template<>
const inline time_t IEvent::GetStartedAt<time_t>(EventTimeUnit time_unit) const {
return static_cast<time_t>(GetStartedAt<double>(time_unit));
}
template<>
const inline double IEvent::GetFinishedAt<double>(EventTimeUnit time_unit) const {
return ConvertTime(finished_at_, time_unit_, time_unit);
}
template<>
const inline time_t IEvent::GetFinishedAt<time_t>(EventTimeUnit time_unit) const {
return static_cast<time_t>(GetFinishedAt<double>(time_unit));
}
template<>
const inline double IEvent::GetDuration<double>(EventTimeUnit time_unit) const {
return GetFinishedAt<double>(time_unit) - GetStartedAt<double>(time_unit);
}
template<>
const inline time_t IEvent::GetDuration<time_t>(EventTimeUnit time_unit) const {
return static_cast<time_t>(GetDuration<double>(time_unit));
}
class CustomEvent final : public IEvent {
public:
friend class ProfileManager;
std::string Key() override;
nlohmann::json ToJson() override;
static std::shared_ptr<CustomEvent> Create(const std::string& name,
CustomEventType type = CustomEventType::kDefault);
private:
CustomEventType type_;
CustomEvent(const std::string& custom_name, CustomEventType type)
: IEvent(custom_name,
type == CustomEventType::kDefault ? EventTimeUnit::kNS : EventTimeUnit::kUS),
type_(type) {}
};
class KernelEvent final : public IEvent {
public:
std::string Key() override;
nlohmann::json ToJson() override;
static std::shared_ptr<KernelEvent> Create(
const std::string& name, const std::function<std::vector<ShapeView>(void)>& shape_getter);
void RecordShape(const ShapeView& shape);
#if defined(WITH_CUDA)
void SetMemorySize(int64_t memory_size) { memory_size_ = memory_size; }
void AddChildEvent(const std::shared_ptr<IEvent>& e) { children_.emplace(e); }
bool AddChildEventIfSo(const std::shared_ptr<IEvent>& e) {
if (e->IsChildOf(dynamic_cast<IEvent*>(this))) {
children_.emplace(e);
return true;
}
return false;
}
bool HasChildEvent(const std::shared_ptr<IEvent>& e) { return children_.count(e); }
void WalkAmongChildren(const std::function<void(const std::shared_ptr<IEvent>& e)>& f) const {
for (const auto& x : children_) { f(x); }
}
#endif // WITH_CUDA
private:
KernelEvent(const std::string& kernel_name,
const std::function<std::vector<ShapeView>(void)>& shape_getter)
: IEvent(kernel_name, EventTimeUnit::kNS) {
if (shape_getter) { input_shapes_ = shape_getter(); }
}
#if defined(WITH_CUDA)
int64_t memory_size_ = -1;
std::set<std::shared_ptr<IEvent>> children_;
#endif // WITH_CUDA
std::vector<ShapeView> input_shapes_;
std::string GetFormatedInputShapes(size_t max_num_to_format = 4);
};
} // namespace profiler
} // namespace oneflow
namespace nlohmann {
inline void to_json(json& j, const std::shared_ptr<::oneflow::profiler::IEvent>& event) {
j = event->ToJson();
}
} // namespace nlohmann
#endif // ONEFLOW_CORE_PROFILER_EVENT_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_PROFILER_EVENT_H_
#define ONEFLOW_CORE_PROFILER_EVENT_H_
#include <functional>
#include <memory>
#include <vector>
#include "nlohmann/json.hpp"
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/shape_view.h"
namespace oneflow {
namespace profiler {
class ProfileManager;
enum class EventType {
kCustom, // has three kinds
kOneflowKernel // OneFlow cpu/cuda kernel
};
enum class CustomEventType {
kDefault, // for record_function
kCudaKernel, // cuda kernel
kCudaRuntime // something like cudaLaunchKernel
};
enum class EventTimeUnit { kNS, kUS };
class IEvent {
public:
OF_DISALLOW_COPY_AND_MOVE(IEvent);
IEvent() = delete;
IEvent(const std::string& name, EventTimeUnit time_unit) : name_(name), time_unit_(time_unit) {}
virtual std::string Key() = 0;
virtual nlohmann::json ToJson();
virtual ~IEvent() = default;
virtual void Start();
virtual void Finish();
bool IsChildOf(const IEvent* e);
const std::string& GetName() const;
template<typename T>
const T GetDuration(EventTimeUnit time_unit = EventTimeUnit::kUS) const;
template<typename T>
const T GetStartedAt(EventTimeUnit time_unit = EventTimeUnit::kUS) const;
template<typename T>
const T GetFinishedAt(EventTimeUnit time_unit = EventTimeUnit::kUS) const;
protected:
virtual void SetStartedAt(double t);
virtual void SetFinishedAt(double t);
std::string name_;
EventTimeUnit time_unit_;
double started_at_ = 0;
double finished_at_ = 0;
};
inline double ConvertTime(double time_, EventTimeUnit src_time_unit, EventTimeUnit dst_time_unit) {
if (src_time_unit == EventTimeUnit::kNS && dst_time_unit == EventTimeUnit::kUS) {
return time_ / 1000;
}
if (src_time_unit == EventTimeUnit::kUS && dst_time_unit == EventTimeUnit::kNS) {
return time_ * 1000;
}
return time_;
}
template<>
const inline double IEvent::GetStartedAt<double>(EventTimeUnit time_unit) const {
return ConvertTime(started_at_, time_unit_, time_unit);
}
template<>
const inline time_t IEvent::GetStartedAt<time_t>(EventTimeUnit time_unit) const {
return static_cast<time_t>(GetStartedAt<double>(time_unit));
}
template<>
const inline double IEvent::GetFinishedAt<double>(EventTimeUnit time_unit) const {
return ConvertTime(finished_at_, time_unit_, time_unit);
}
template<>
const inline time_t IEvent::GetFinishedAt<time_t>(EventTimeUnit time_unit) const {
return static_cast<time_t>(GetFinishedAt<double>(time_unit));
}
template<>
const inline double IEvent::GetDuration<double>(EventTimeUnit time_unit) const {
return GetFinishedAt<double>(time_unit) - GetStartedAt<double>(time_unit);
}
template<>
const inline time_t IEvent::GetDuration<time_t>(EventTimeUnit time_unit) const {
return static_cast<time_t>(GetDuration<double>(time_unit));
}
class CustomEvent final : public IEvent {
public:
friend class ProfileManager;
std::string Key() override;
nlohmann::json ToJson() override;
static std::shared_ptr<CustomEvent> Create(const std::string& name,
CustomEventType type = CustomEventType::kDefault);
private:
CustomEventType type_;
CustomEvent(const std::string& custom_name, CustomEventType type)
: IEvent(custom_name,
type == CustomEventType::kDefault ? EventTimeUnit::kNS : EventTimeUnit::kUS),
type_(type) {}
};
class KernelEvent final : public IEvent {
public:
std::string Key() override;
nlohmann::json ToJson() override;
static std::shared_ptr<KernelEvent> Create(
const std::string& name, const std::function<std::vector<Shape>(void)>& shape_getter);
#if defined(WITH_CUDA) || defined(WITH_ROCM)
void SetMemorySize(int64_t memory_size) { memory_size_ = memory_size; }
void AddChildEvent(const std::shared_ptr<IEvent>& e) { children_.emplace(e); }
bool AddChildEventIfSo(const std::shared_ptr<IEvent>& e) {
if (e->IsChildOf(dynamic_cast<IEvent*>(this))) {
children_.emplace(e);
return true;
}
return false;
}
bool HasChildEvent(const std::shared_ptr<IEvent>& e) { return children_.count(e); }
void WalkAmongChildren(const std::function<void(const std::shared_ptr<IEvent>& e)>& f) const {
for (const auto& x : children_) { f(x); }
}
#endif // WITH_CUDA
private:
KernelEvent(const std::string& kernel_name,
const std::function<std::vector<Shape>(void)>& shape_getter)
: IEvent(kernel_name, EventTimeUnit::kNS) {
if (shape_getter) { input_shapes_ = shape_getter(); }
}
#if defined(WITH_CUDA) || defined(WITH_ROCM)
int64_t memory_size_ = -1;
std::set<std::shared_ptr<IEvent>> children_;
#endif // WITH_CUDA
std::vector<Shape> input_shapes_;
std::string GetFormatedInputShapes(size_t max_num_to_format = 4);
};
} // namespace profiler
} // namespace oneflow
namespace nlohmann {
inline void to_json(json& j, const std::shared_ptr<::oneflow::profiler::IEvent>& event) {
j = event->ToJson();
}
} // namespace nlohmann
#endif // ONEFLOW_CORE_PROFILER_EVENT_H_
......@@ -32,13 +32,13 @@ std::shared_ptr<EventRecorder> EventRecorder::CreateCustomEventRecorder(const st
Maybe<EventRecorder> EventRecorder::CreateKernelEventRecorder(
const std::string& name,
#if defined(WITH_CUDA)
#if defined(WITH_CUDA) || defined(WITH_ROCM)
const std::function<int64_t()>& memory_size_getter,
#endif
const ShapeGetterFuncType& shape_getter) {
auto pmgr = Singleton<ProfileManager>::Get();
if (pmgr) {
#if defined(WITH_CUDA)
#if defined(WITH_CUDA) || defined(WITH_ROCM)
if (pmgr->use_cpu_ || pmgr->use_cuda_) {
auto event = KernelEvent::Create(name, pmgr->record_shapes_ ? shape_getter : nullptr);
if (pmgr->use_cuda_) {
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_PROFILER_EVENT_RECORDER_H_
#define ONEFLOW_CORE_PROFILER_EVENT_RECORDER_H_
#include "oneflow/core/common/util.h"
#include "oneflow/core/profiler/event.h"
namespace oneflow {
namespace profiler {
class EventRecorder {
public:
using ShapeGetterFuncType = std::function<std::vector<ShapeView>(void)>;
OF_DISALLOW_COPY_AND_MOVE(EventRecorder);
explicit EventRecorder(const std::shared_ptr<IEvent>& event) : event_(event) {
CHECK_JUST(RegisterEventToProfileManager(event));
event_->Start();
}
Maybe<void> RegisterEventToProfileManager(const std::shared_ptr<IEvent>& event);
~EventRecorder() {
if (event_) {
event_->Finish();
event_.reset();
}
}
static std::shared_ptr<EventRecorder> CreateCustomEventRecorder(const std::string& name);
static Maybe<EventRecorder> CreateKernelEventRecorder(
const std::string& name,
#if defined(WITH_CUDA)
const std::function<int64_t()>& memory_size_getter,
#endif
const ShapeGetterFuncType& shape_getter);
private:
std::shared_ptr<IEvent> event_;
};
} // namespace profiler
} // namespace oneflow
#endif // ONEFLOW_CORE_PROFILER_EVENT_RECORDER_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_PROFILER_EVENT_RECORDER_H_
#define ONEFLOW_CORE_PROFILER_EVENT_RECORDER_H_
#include "oneflow/core/common/util.h"
#include "oneflow/core/profiler/event.h"
namespace oneflow {
namespace profiler {
class EventRecorder {
public:
using ShapeGetterFuncType = std::function<std::vector<Shape>(void)>;
OF_DISALLOW_COPY_AND_MOVE(EventRecorder);
explicit EventRecorder(const std::shared_ptr<IEvent>& event) : event_(event) {
CHECK_JUST(RegisterEventToProfileManager(event));
event_->Start();
}
Maybe<void> RegisterEventToProfileManager(const std::shared_ptr<IEvent>& event);
~EventRecorder() {
if (event_) {
event_->Finish();
event_.reset();
}
}
static std::shared_ptr<EventRecorder> CreateCustomEventRecorder(const std::string& name);
static Maybe<EventRecorder> CreateKernelEventRecorder(
const std::string& name,
#if defined(WITH_CUDA) || defined(WITH_ROCM)
const std::function<int64_t()>& memory_size_getter,
#endif
const ShapeGetterFuncType& shape_getter);
private:
std::shared_ptr<IEvent> event_;
};
} // namespace profiler
} // namespace oneflow
#endif // ONEFLOW_CORE_PROFILER_EVENT_RECORDER_H_
......@@ -17,7 +17,11 @@ limitations under the License.
#include "oneflow/core/profiler/kernel.h"
#include "oneflow/core/profiler/profiler.h"
#include "oneflow/core/kernel/kernel.h"
#ifdef WITH_ROCM
#include "oneflow/core/ep/rocm/cuda_stream.h"
#else
#include "oneflow/core/ep/cuda/cuda_stream.h"
#endif
#include "oneflow/core/lazy/actor/actor_context.h"
namespace oneflow {
......@@ -43,6 +47,11 @@ thread_local cudaEvent_t cuda_memory_bandwidth_profile_start_event = nullptr;
thread_local cudaEvent_t cuda_memory_bandwidth_profile_end_event = nullptr;
#endif // WITH_CUDA
#if defined(WITH_ROCM)
thread_local hipEvent_t cuda_memory_bandwidth_profile_start_event = nullptr;
thread_local hipEvent_t cuda_memory_bandwidth_profile_end_event = nullptr;
#endif // WITH_ROCM
} // namespace
void TraceKernelForwardDataContentStart(KernelContext* kernel_ctx, const Kernel* kernel) {
......@@ -61,6 +70,22 @@ void TraceKernelForwardDataContentStart(KernelContext* kernel_ctx, const Kernel*
}
if (profile_kernel_forward_range) { OF_PROFILER_RANGE_PUSH(kernel->op_conf().name()); }
#endif // WITH_CUDA
#if defined(WITH_ROCM)
if (profile_cuda_memory_bandwidth) {
auto* actor_context_provider = dynamic_cast<ActorContextProvider*>(kernel_ctx);
auto* cuda_stream = dynamic_cast<ep::CudaStream*>(kernel_ctx->stream());
if (cuda_stream != nullptr && actor_context_provider != nullptr) {
CHECK(cuda_memory_bandwidth_profile_start_event == nullptr);
CHECK(cuda_memory_bandwidth_profile_end_event == nullptr);
OF_CUDA_CHECK(hipEventCreate(&cuda_memory_bandwidth_profile_start_event));
OF_CUDA_CHECK(hipEventCreate(&cuda_memory_bandwidth_profile_end_event));
OF_CUDA_CHECK(
hipEventRecord(cuda_memory_bandwidth_profile_start_event, cuda_stream->cuda_stream()));
}
}
if (profile_kernel_forward_range) { OF_PROFILER_RANGE_PUSH(kernel->op_conf().name()); }
#endif // WITH_ROCM
}
void TraceKernelForwardDataContentEnd(KernelContext* kernel_ctx, const Kernel* kernel) {
......@@ -103,6 +128,45 @@ void TraceKernelForwardDataContentEnd(KernelContext* kernel_ctx, const Kernel* k
}
}
#endif // WITH_CUDA
#if defined(WITH_ROCM)
if (profile_kernel_forward_range) { OF_PROFILER_RANGE_POP(); }
// The memory bandwidth profiler only works in lazy mode.
if (profile_cuda_memory_bandwidth) {
auto* cuda_stream = dynamic_cast<ep::CudaStream*>(kernel_ctx->stream());
auto* actor_context_provider = dynamic_cast<ActorContextProvider*>(kernel_ctx);
if (cuda_stream != nullptr && actor_context_provider != nullptr) {
hipEvent_t start_event = cuda_memory_bandwidth_profile_start_event;
hipEvent_t end_event = cuda_memory_bandwidth_profile_end_event;
cuda_memory_bandwidth_profile_start_event = nullptr;
cuda_memory_bandwidth_profile_end_event = nullptr;
CHECK_NOTNULL(start_event);
CHECK_NOTNULL(end_event);
OF_CUDA_CHECK(hipEventRecord(end_event, cuda_stream->cuda_stream()));
int64_t memory_size = 0;
for (const auto& bn : kernel->op_attribute().input_bns()) {
const Blob* blob = kernel_ctx->BnInOp2Blob(bn);
if (blob) { memory_size += blob->ByteSizeOfBlobBody(); }
}
for (const auto& bn : kernel->op_attribute().output_bns()) {
const Blob* blob = kernel_ctx->BnInOp2Blob(bn);
if (blob) { memory_size += blob->ByteSizeOfBlobBody(); }
}
const std::string op_name = kernel->op_conf().name();
actor_context_provider->GetActorContext()->AddCallback(
[start_event, end_event, memory_size, op_name]() {
float elapsed_ms = 0;
OF_CUDA_CHECK(hipEventElapsedTime(&elapsed_ms, start_event, end_event));
OF_CUDA_CHECK(hipEventDestroy(start_event));
OF_CUDA_CHECK(hipEventDestroy(end_event));
double bandwidth =
static_cast<double>(memory_size) / (1024.0 * 1024.0 * 1024.0) / (elapsed_ms / 1000);
LOG(INFO) << "PROFILER::KERNEL::CUDA_MEMORY_BANDWIDTH op_name: " << op_name
<< " elapsed(ms): " << elapsed_ms << " memory_size(Byte): " << memory_size
<< " bandwidth(GB/s): " << bandwidth;
});
}
}
#endif // WITH_ROCM
}
} // namespace profiler
......
......@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#if defined(WITH_CUDA)
#if defined(WITH_CUDA) || defined(WITH_ROCM)
#include "oneflow/core/profiler/kineto_shim.h"
#include "libkineto.h"
......
......@@ -16,7 +16,7 @@ limitations under the License.
#ifndef ONEFLOW_CORE_PROFILER_KINETO_SHIM_H_
#define ONEFLOW_CORE_PROFILER_KINETO_SHIM_H_
#if defined(WITH_CUDA)
#if defined(WITH_CUDA) || defined(WITH_ROCM)
#include <string>
#include <memory>
......
......@@ -15,12 +15,12 @@ limitations under the License.
*/
#include <memory>
#include <unordered_map>
// #include "fmt/core.h"
#include "fmt/core.h"
#include "nlohmann/json.hpp"
#include "oneflow/core/profiler/kineto_shim.h"
#include "oneflow/core/profiler/profile_manager.h"
#include "oneflow/core/profiler/event.h"
#if defined(WITH_CUDA)
#if defined(WITH_CUDA) || defined(WITH_ROCM)
#include <libkineto.h>
#endif // WITH_CUDA
......@@ -48,7 +48,7 @@ std::string ProfileManager::DumpResultsJson() {
}
std::vector<std::shared_ptr<IEvent>> ProfileManager::ExportEvents() {
#if defined(WITH_CUDA)
#if defined(WITH_CUDA) || defined(WITH_ROCM)
auto trace = StopTrace();
const auto& kineto_events = *(trace.get()->activities());
std::set<std::shared_ptr<IEvent>> custom_events;
......@@ -77,7 +77,7 @@ std::vector<std::shared_ptr<IEvent>> ProfileManager::ExportEvents() {
while (!events_.empty()) {
auto evt = events_.front();
events_.pop();
#if defined(WITH_CUDA)
#if defined(WITH_CUDA) || defined(WITH_ROCM)
auto evt_kernel = std::dynamic_pointer_cast<KernelEvent>(evt);
if (evt_kernel) {
std::set<int64_t> current_corr_ids;
......@@ -106,8 +106,7 @@ std::string ProfileManager::GetNextEventRecorderKey(const std::string& name) {
} else {
event_recorders_last_id_[name]++;
}
// return fmt::format("{}.{}", name, event_recorders_last_id_[name]);
return "yuguo";
return fmt::format("{}.{}", name, event_recorders_last_id_[name]);
}
} // namespace profiler
......
......@@ -37,7 +37,7 @@ class ProfileManager {
use_cuda_(use_cuda),
record_shapes_(record_shapes),
record_bandwidth_(record_bandwidth) {
#if defined(WITH_CUDA)
#if defined(WITH_CUDA) || defined(WITH_ROCM)
std::set<ActivityType> activities{};
if (use_cpu) { activities.insert(ActivityType::CPU); }
if (use_cuda) { activities.insert(ActivityType::CUDA); }
......
......@@ -20,11 +20,20 @@ limitations under the License.
#include "oneflow/core/profiler/event_recorder.h"
#include "oneflow/core/vm/vm_util.h"
#ifdef OF_ENABLE_PROFILER
#ifdef WITH_ROCM
#include <hip/hip_runtime.h>
#include <hip/hip_profile.h>
#include <roctracer_roctx.h>
#include <sys/syscall.h>
#include <iostream>
#include "oneflow/core/device/cuda_util.h"
#else
#include <nvtx3/nvToolsExt.h>
#include <sys/syscall.h>
#include <iostream>
#include <cuda_profiler_api.h>
#include "oneflow/core/device/cuda_util.h"
#endif
#endif // OF_ENABLE_PROFILER
namespace oneflow {
......@@ -33,6 +42,16 @@ namespace profiler {
void NameThisHostThread(const std::string& name) {
#ifdef OF_ENABLE_PROFILER
#ifdef WITH_ROCM
static thread_local std::unique_ptr<std::string> thread_name_prefix;
if (!thread_name_prefix) {
thread_name_prefix.reset(
new std::string(GetStringFromEnv("ONEFLOW_PROFILER_HOST_THREAD_NAME_PREFIX", "")));
}
const std::string name_with_prefix = *thread_name_prefix + name;
// nvtxNameOsThreadA(syscall(SYS_gettid), name_with_prefix.c_str());
roctxMarkA(name_with_prefix.c_str());
#else
static thread_local std::unique_ptr<std::string> thread_name_prefix;
if (!thread_name_prefix) {
thread_name_prefix.reset(
......@@ -40,18 +59,27 @@ void NameThisHostThread(const std::string& name) {
}
const std::string name_with_prefix = *thread_name_prefix + name;
nvtxNameOsThreadA(syscall(SYS_gettid), name_with_prefix.c_str());
#endif
#endif // OF_ENABLE_PROFILER
}
void RangePush(const std::string& name) {
#ifdef OF_ENABLE_PROFILER
#ifdef WITH_ROCM
roctxRangePushA(name.c_str());
#else
nvtxRangePushA(name.c_str());
#endif
#endif // OF_ENABLE_PROFILER
}
void RangePop() {
#ifdef OF_ENABLE_PROFILER
#ifdef WITH_ROCM
roctxRangePop();
#else
nvtxRangePop();
#endif
#endif // OF_ENABLE_PROFILER
}
......@@ -82,13 +110,21 @@ void LogHostMemoryUsage(const std::string& name) {
void ProfilerStart() {
#ifdef OF_ENABLE_PROFILER
#ifdef WITH_ROCM
OF_CUDA_CHECK(hipProfilerStart());
#else
OF_CUDA_CHECK(cudaProfilerStart());
#endif
#endif // OF_ENABLE_PROFILER
}
void ProfilerStop() {
#ifdef OF_ENABLE_PROFILER
#ifdef WITH_ROCM
OF_CUDA_CHECK(hipProfilerStop());
#else
OF_CUDA_CHECK(cudaProfilerStop());
#endif
#endif // OF_ENABLE_PROFILER
}
......@@ -105,6 +141,9 @@ Maybe<std::string> DisableProfilerAndReturnResult() {
#if defined(WITH_CUDA)
OF_CUDA_CHECK(cudaDeviceSynchronize());
#endif // WITH_CUDA
#if defined(WITH_ROCM)
OF_CUDA_CHECK(hipDeviceSynchronize());
#endif // WITH_ROCM
auto* pmgr = JUST(SingletonMaybe<ProfileManager>());
std::string results = pmgr->DumpResultsJson();
Singleton<ProfileManager>::Delete();
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_USER_KERNELS_MATH_UNARY_ELEMENTWISE_FUNC_H_
#define ONEFLOW_USER_KERNELS_MATH_UNARY_ELEMENTWISE_FUNC_H_
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/data_type.h"
#include "oneflow/user/ops/math_unary_elementwise_seq.h"
#include "oneflow/core/device/cuda_pseudo_half.h"
#if defined(__CUDACC__)
#include <cuda_fp16.h>
#define MATH_FUNC_F(name, x) name##f(x)
#define MATH_FUNC_D(name, x) name(x)
#elif defined(__HIPCC__)
#include <cmath>
#include <hip/hip_fp16.h>
#if defined(__HIP_DEVICE_COMPILE__)
#define MATH_FUNC_F(name, x) name##f(x)
#define MATH_FUNC_D(name, x) name(x)
#else
#define MATH_FUNC_F(name, x) std::name(x)
#define MATH_FUNC_D(name, x) std::name(x)
#endif
#else
#include <cmath>
#define MATH_FUNC_F(name, x) std::name(x)
#define MATH_FUNC_D(name, x) std::name(x)
#endif
namespace oneflow {
#define DECLARE_UNARY_FUNCTOR(math_unary_elementwise_type, func_prefix) \
template<typename T> \
struct func_prefix##Functor;
OF_PP_FOR_EACH_TUPLE(DECLARE_UNARY_FUNCTOR, MATH_UNARY_ELEMENTWISE_FUNC_SEQ)
template<typename T>
struct AbsFunctor {
static OF_DEVICE_FUNC T Forward(const T x) {
if (x == T(0))
return T(0);
else
return x < T(0) ? -x : x;
}
static OF_DEVICE_FUNC T Backward(const T x, const T dy) {
if (x == T(0))
return T(0);
else
return x < T(0) ? -dy : dy;
}
};
template<typename T>
struct SignFunctor {
static OF_DEVICE_FUNC T Forward(const T x) { return (T(0) < x) - (x < T(0)); }
static OF_DEVICE_FUNC T Backward(const T x, const T dy) { return T(0); }
};
template<>
struct RsqrtFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) {
#if defined(__CUDACC__)
return rsqrtf(x);
#elif defined(__HIP_DEVICE_COMPILE__)
return rsqrtf(x);
#else
return 1.0f / std::sqrt(x);
#endif
}
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * (-1.0f / (2.0f * MATH_FUNC_F(sqrt, x * x * x)));
}
};
template<>
struct RsqrtFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) {
#if defined(__CUDACC__)
return rsqrt(x);
#elif defined(__HIP_DEVICE_COMPILE__)
return rsqrt(x);
#else
return 1.0 / std::sqrt(x);
#endif
}
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * (-1.0 / (2.0 * MATH_FUNC_D(sqrt, x * x * x)));
}
};
// float version
template<>
struct AcosFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(acos, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * -RsqrtFunctor<float>::Forward(1.0f - x * x);
}
};
template<>
struct AcoshFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(acosh, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * RsqrtFunctor<float>::Forward(x * x - 1.0f);
}
};
template<>
struct AsinFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(asin, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * RsqrtFunctor<float>::Forward(1.0f - x * x);
}
};
template<>
struct AsinhFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(asinh, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * RsqrtFunctor<float>::Forward(1.0f + x * x);
}
};
template<>
struct AtanFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(atan, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * (1.0f / (1.0f + x * x));
}
};
template<>
struct AtanhFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(atanh, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * (1.0f / (1.0f - x * x));
}
};
template<>
struct NotEqualZeroFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return x != 0; }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; }
};
template<>
struct CeilFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(ceil, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; }
};
template<>
struct CosFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(cos, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * (-MATH_FUNC_F(sin, x));
}
};
template<>
struct CoshFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(cosh, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * MATH_FUNC_F(sinh, x);
}
};
template<>
struct ErfFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(erf, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * 2.0f * RsqrtFunctor<float>::Forward(M_PI) * expf(-x * x);
}
};
template<>
struct ErfcFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(erfc, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * -2.0f * RsqrtFunctor<float>::Forward(M_PI) * expf(-x * x);
}
};
template<>
struct ExpFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(exp, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * MATH_FUNC_F(exp, x);
}
};
template<>
struct Expm1Functor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(expm1, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * MATH_FUNC_F(exp, x);
}
};
template<>
struct FloorFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(floor, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; }
};
template<>
struct LgammaFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(lgamma, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
// TODO(chengcheng): return: dy * digamma(x)
assert(false);
return 0.0f;
}
};
template<>
struct LogFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(log, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return dy * (1.0f / x); }
};
template<>
struct Log2Functor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(log2, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * (1.0f / (x * MATH_FUNC_F(log, 2.0f)));
}
};
template<>
struct Log1pFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(log1p, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * (1.0f / (x + 1.0f));
}
};
template<>
struct LogSigmoidFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) {
return -MATH_FUNC_F(log, (1.0f + MATH_FUNC_F(exp, -x)));
}
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * (1.0f / (MATH_FUNC_F(exp, x) + 1.0f));
}
};
template<>
struct NegativeFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return -x; }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return -dy; }
};
template<>
struct ReciprocalFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return 1.0f / x; }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * (-1.0f / (x * x));
}
};
template<>
struct ReciprocalNoNanFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) {
if (fabsf(x) <= 0.0f) { return 0.0f; }
return 1.0f / x;
}
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
if (fabsf(x) <= 0.0f) { return 0.0f; }
return dy * (-1.0f / (x * x));
}
};
template<>
struct RintFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(rint, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; }
};
template<>
struct RoundFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(nearbyint, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; }
};
template<>
struct SigmoidFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) {
return 1.0f / (1.0f + MATH_FUNC_F(exp, -x));
}
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
float y = 1.0f / (1.0f + MATH_FUNC_F(exp, -x));
return dy * (y * (1.0f - y));
}
};
template<>
struct SinFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(sin, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * MATH_FUNC_F(cos, x);
}
};
template<>
struct SinhFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(sinh, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * MATH_FUNC_F(cosh, x);
}
};
template<>
struct SqrtFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(sqrt, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * 0.5f / MATH_FUNC_F(sqrt, x);
}
};
template<>
struct SquareFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return x * x; }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return dy * 2.0f * x; }
};
template<>
struct TanFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(tan, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * (1.0f / (MATH_FUNC_F(cos, x) * MATH_FUNC_F(cos, x)));
}
};
// double version
template<>
struct AcosFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(acos, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * -RsqrtFunctor<double>::Forward(1.0 - x * x);
}
};
template<>
struct AcoshFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(acosh, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * -RsqrtFunctor<double>::Forward(x * x - 1.0);
}
};
template<>
struct AsinFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(asin, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * RsqrtFunctor<double>::Forward(1.0 - x * x);
}
};
template<>
struct AsinhFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(asinh, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * RsqrtFunctor<double>::Forward(1.0 + x * x);
}
};
template<>
struct AtanFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(atan, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * (1.0 / (1.0 + x * x));
}
};
template<>
struct AtanhFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(atanh, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * (1.0 / (1.0 - x * x));
}
};
template<>
struct NotEqualZeroFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return x != 0; }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return 0.0f; }
};
template<>
struct CeilFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(ceil, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return 0.0; }
};
template<>
struct CosFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(cos, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * (-MATH_FUNC_D(sin, x));
}
};
template<>
struct CoshFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(cosh, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * MATH_FUNC_D(sinh, x);
}
};
template<>
struct ErfFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(erf, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * 2.0 * RsqrtFunctor<double>::Forward(M_PI) * expf(-x * x);
}
};
template<>
struct ErfcFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(erfc, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * -2.0 * RsqrtFunctor<double>::Forward(M_PI) * expf(-x * x);
}
};
template<>
struct ExpFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(exp, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * MATH_FUNC_D(exp, x);
}
};
template<>
struct Expm1Functor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(expm1, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * MATH_FUNC_D(exp, x);
}
};
template<>
struct FloorFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(floor, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return 0.0; }
};
template<>
struct LgammaFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(lgamma, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
// TODO(chengcheng): return: dy * digamma(x)
assert(false);
return 0.0;
}
};
template<>
struct LogFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(log, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return dy * (1.0 / x); }
};
template<>
struct Log2Functor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(log2, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * (1.0 / (x * MATH_FUNC_D(log, 2.0)));
}
};
template<>
struct Log1pFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(log1p, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * (1.0 / (x + 1.0));
}
};
template<>
struct LogSigmoidFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) {
return -MATH_FUNC_D(log, (1.0 + MATH_FUNC_D(exp, -x)));
}
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * (1.0 / (MATH_FUNC_D(exp, x) + 1.0));
}
};
template<>
struct NegativeFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return -x; }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return -dy; }
};
template<>
struct ReciprocalFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return 1.0 / x; }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * (-1.0 / (x * x));
}
};
template<>
struct ReciprocalNoNanFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) {
if (fabs(x) <= 0.0) { return 0.0; }
return 1.0 / x;
}
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
if (fabs(x) <= 0.0) { return 0.0; }
return dy * (-1.0 / (x * x));
}
};
template<>
struct RintFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(rint, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return 0.0; }
};
template<>
struct RoundFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(nearbyint, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return 0.0; }
};
template<>
struct SigmoidFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) {
return 1.0 / (1.0 + MATH_FUNC_D(exp, -x));
}
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
double y = 1.0 / (1.0 + MATH_FUNC_D(exp, -x));
return dy * (y * (1.0 - y));
}
};
template<>
struct SinFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(sin, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * MATH_FUNC_D(cos, x);
}
};
template<>
struct SinhFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(sinh, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * MATH_FUNC_D(cosh, x);
}
};
template<>
struct SqrtFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(sqrt, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * (double)0.5 / MATH_FUNC_D(sqrt, x);
}
};
template<>
struct SquareFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return x * x; }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return dy * 2.0 * x; }
};
template<>
struct TanFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(tan, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * (1.0 / (MATH_FUNC_D(cos, x) * MATH_FUNC_D(cos, x)));
}
};
#if defined(__CUDACC__) || defined(__HIPCC__)
// half version
#define OF_HALF_FUNC __device__ __forceinline__
#define MATH_FUNC_H(name, x) __float2half(name##f(__half2float(x)))
#define HALF_VAL_HALF __float2half(0.5f)
#define HALF_VAL_TWO __float2half(2.0f)
#define HALF_VAL_2RSQRT_PI __float2half(1.1283791671f)
template<>
struct AbsFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) {
return __hlt(x, GetZeroVal<half>()) ? __hneg(x) : x;
}
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hlt(x, GetZeroVal<half>()) ? __hneg(dy) : dy;
}
};
template<>
struct AcosFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(acos, x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hmul(dy, __hneg(hrsqrt(__hsub(GetOneVal<half>(), __hmul(x, x)))));
}
};
template<>
struct AcoshFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(acosh, x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hmul(dy, hrsqrt(__hsub(__hmul(x, x), GetOneVal<half>())));
}
};
template<>
struct AsinFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(asin, x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hmul(dy, hrsqrt(__hsub(GetOneVal<half>(), __hmul(x, x))));
}
};
template<>
struct AsinhFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(asinh, x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hmul(dy, hrsqrt(__hadd(GetOneVal<half>(), __hmul(x, x))));
}
};
template<>
struct AtanFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(atan, x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hmul(dy, __hdiv(GetOneVal<half>(), __hadd(GetOneVal<half>(), __hmul(x, x))));
}
};
template<>
struct AtanhFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(atanh, x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hmul(dy, __hdiv(GetOneVal<half>(), __hsub(GetOneVal<half>(), __hmul(x, x))));
}
};
template<>
struct CeilFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return hceil(x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal<half>(); }
};
template<>
struct NotEqualZeroFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return x != static_cast<half>(0.0); }
static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal<half>(); }
};
template<>
struct CosFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return hcos(x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hmul(dy, __hneg(hsin(x)));
}
};
template<>
struct CoshFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(cosh, x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hmul(dy, MATH_FUNC_H(sinh, x));
}
};
template<>
struct ErfFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(erf, x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hmul(dy, __hmul(HALF_VAL_2RSQRT_PI, hexp(__hmul(__hneg(x), x))));
}
};
template<>
struct ErfcFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(erfc, x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hmul(dy, __hneg(__hmul(HALF_VAL_2RSQRT_PI, hexp(__hmul(__hneg(x), x)))));
}
};
template<>
struct ExpFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return hexp(x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, hexp(x)); }
};
template<>
struct Expm1Functor<half> {
static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(expm1, x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, hexp(x)); }
};
template<>
struct FloorFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return hfloor(x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal<half>(); }
};
template<>
struct LgammaFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(lgamma, x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) {
// TODO(chengcheng): return: dy * digamma(x)
assert(false);
return GetZeroVal<half>();
}
};
template<>
struct LogFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return hlog(x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, hrcp(x)); }
};
template<>
struct Log2Functor<half> {
static OF_HALF_FUNC half Forward(const half x) { return hlog2(x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hmul(dy, hrcp(__hmul(x, hlog(HALF_VAL_TWO))));
}
};
template<>
struct Log1pFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(log1p, x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hmul(dy, hrcp(__hadd(x, GetOneVal<half>())));
}
};
template<>
struct LogSigmoidFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) {
return __hneg(hlog(__hadd(GetOneVal<half>(), hexp(__hneg(x)))));
}
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hmul(dy, hrcp(__hadd(hexp(x), GetOneVal<half>())));
}
};
template<>
struct NegativeFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return __hneg(x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hneg(dy); }
};
template<>
struct ReciprocalFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return hrcp(x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hmul(dy, __hneg(hrcp(__hmul(x, x))));
}
};
template<>
struct ReciprocalNoNanFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) {
if (__heq(GetZeroVal<half>(), x)) { return GetZeroVal<half>(); }
return hrcp(x);
}
static OF_HALF_FUNC half Backward(const half x, const half dy) {
if (__heq(GetZeroVal<half>(), x)) { return GetZeroVal<half>(); }
return __hmul(dy, __hneg(hrcp(__hmul(x, x))));
}
};
template<>
struct RintFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return hrint(x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal<half>(); }
};
template<>
struct RoundFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(nearbyint, x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal<half>(); }
};
template<>
struct RsqrtFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return hrsqrt(x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hmul(dy, __hneg(hrcp(__hmul(HALF_VAL_TWO, hsqrt(__hmul(x, __hmul(x, x)))))));
}
};
template<>
struct SigmoidFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) {
return hrcp(__hadd(GetOneVal<half>(), hexp(__hneg(x))));
}
static OF_HALF_FUNC half Backward(const half x, const half dy) {
half y = hrcp(__hadd(GetOneVal<half>(), hexp(__hneg(x))));
return __hmul(dy, __hmul(y, __hsub(GetOneVal<half>(), y)));
}
};
template<>
struct SignFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) {
if (__hgt(x, GetZeroVal<half>())) { return GetOneVal<half>(); }
if (__hlt(x, GetZeroVal<half>())) { return __hneg(GetOneVal<half>()); }
return GetZeroVal<half>();
}
static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal<half>(); }
};
template<>
struct SinFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return hsin(x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, hcos(x)); }
};
template<>
struct SinhFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(sinh, x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hmul(dy, MATH_FUNC_H(cosh, x));
}
};
template<>
struct SqrtFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return hsqrt(x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hmul(dy, __hdiv(HALF_VAL_HALF, hsqrt(x)));
}
};
template<>
struct SquareFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return __hmul(x, x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hmul(dy, __hmul(HALF_VAL_TWO, x));
}
};
template<>
struct TanFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return __hdiv(hsin(x), hcos(x)); }
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hmul(dy, hrcp(__hmul(hcos(x), hcos(x))));
}
};
#endif
} // namespace oneflow
#endif // ONEFLOW_USER_KERNELS_MATH_UNARY_ELEMENTWISE_FUNC_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_USER_KERNELS_MATH_UNARY_ELEMENTWISE_FUNC_H_
#define ONEFLOW_USER_KERNELS_MATH_UNARY_ELEMENTWISE_FUNC_H_
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/data_type.h"
#include "oneflow/user/ops/math_unary_elementwise_seq.h"
#include "oneflow/core/device/cuda_pseudo_half.h"
#if defined(__CUDACC__)
#include <cuda_fp16.h>
#define MATH_FUNC_F(name, x) name##f(x)
#define MATH_FUNC_D(name, x) name(x)
#elif defined(__HIPCC__)
#include <cmath>
#include <hip/hip_fp16.h>
#if defined(__HIP_DEVICE_COMPILE__)
#define MATH_FUNC_F(name, x) name##f(x)
#define MATH_FUNC_D(name, x) name(x)
#else
#define MATH_FUNC_F(name, x) std::name(x)
#define MATH_FUNC_D(name, x) std::name(x)
#endif
#else
#include <cmath>
#define MATH_FUNC_F(name, x) std::name(x)
#define MATH_FUNC_D(name, x) std::name(x)
#endif
namespace oneflow {
#define DECLARE_UNARY_FUNCTOR(math_unary_elementwise_type, func_prefix) \
template<typename T> \
struct func_prefix##Functor;
OF_PP_FOR_EACH_TUPLE(DECLARE_UNARY_FUNCTOR, MATH_UNARY_ELEMENTWISE_FUNC_SEQ)
template<typename T>
struct AbsFunctor {
static OF_DEVICE_FUNC T Forward(const T x) {
if (x == T(0))
return T(0);
else
return x < T(0) ? -x : x;
}
static OF_DEVICE_FUNC T Backward(const T x, const T dy) {
if (x == T(0))
return T(0);
else
return x < T(0) ? -dy : dy;
}
};
template<typename T>
struct SignFunctor {
static OF_DEVICE_FUNC T Forward(const T x) { return (T(0) < x) - (x < T(0)); }
static OF_DEVICE_FUNC T Backward(const T x, const T dy) { return T(0); }
};
template<>
struct RsqrtFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) {
#if defined(__CUDACC__)
return rsqrtf(x);
#elif defined(__HIP_DEVICE_COMPILE__)
return rsqrtf(x);
#else
return 1.0f / std::sqrt(x);
#endif
}
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * (-1.0f / (2.0f * MATH_FUNC_F(sqrt, x * x * x)));
}
};
template<>
struct RsqrtFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) {
#if defined(__CUDACC__)
return rsqrt(x);
#elif defined(__HIP_DEVICE_COMPILE__)
return rsqrt(x);
#else
return 1.0 / std::sqrt(x);
#endif
}
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * (-1.0 / (2.0 * MATH_FUNC_D(sqrt, x * x * x)));
}
};
// float version
template<>
struct AcosFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(acos, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * -RsqrtFunctor<float>::Forward(1.0f - x * x);
}
};
template<>
struct AcoshFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(acosh, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * RsqrtFunctor<float>::Forward(x * x - 1.0f);
}
};
template<>
struct AsinFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(asin, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * RsqrtFunctor<float>::Forward(1.0f - x * x);
}
};
template<>
struct AsinhFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(asinh, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * RsqrtFunctor<float>::Forward(1.0f + x * x);
}
};
template<>
struct AtanFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(atan, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * (1.0f / (1.0f + x * x));
}
};
template<>
struct AtanhFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(atanh, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * (1.0f / (1.0f - x * x));
}
};
template<>
struct NotEqualZeroFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return x != 0; }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; }
};
template<>
struct CeilFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(ceil, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; }
};
template<>
struct CosFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(cos, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * (-MATH_FUNC_F(sin, x));
}
};
template<>
struct CoshFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(cosh, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * MATH_FUNC_F(sinh, x);
}
};
template<>
struct ErfFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(erf, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * 2.0f * RsqrtFunctor<float>::Forward(M_PI) * expf(-x * x);
}
};
template<>
struct ErfcFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(erfc, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * -2.0f * RsqrtFunctor<float>::Forward(M_PI) * expf(-x * x);
}
};
template<>
struct ExpFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(exp, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * MATH_FUNC_F(exp, x);
}
};
template<>
struct Expm1Functor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(expm1, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * MATH_FUNC_F(exp, x);
}
};
template<>
struct FloorFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(floor, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; }
};
template<>
struct LgammaFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(lgamma, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
// TODO(chengcheng): return: dy * digamma(x)
// assert(false);
return 0.0f;
}
};
template<>
struct LogFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(log, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return dy * (1.0f / x); }
};
template<>
struct Log2Functor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(log2, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * (1.0f / (x * MATH_FUNC_F(log, 2.0f)));
}
};
template<>
struct Log1pFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(log1p, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * (1.0f / (x + 1.0f));
}
};
template<>
struct LogSigmoidFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) {
return -MATH_FUNC_F(log, (1.0f + MATH_FUNC_F(exp, -x)));
}
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * (1.0f / (MATH_FUNC_F(exp, x) + 1.0f));
}
};
template<>
struct NegativeFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return -x; }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return -dy; }
};
template<>
struct ReciprocalFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return 1.0f / x; }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * (-1.0f / (x * x));
}
};
template<>
struct ReciprocalNoNanFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) {
if (fabsf(x) <= 0.0f) { return 0.0f; }
return 1.0f / x;
}
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
if (fabsf(x) <= 0.0f) { return 0.0f; }
return dy * (-1.0f / (x * x));
}
};
template<>
struct RintFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(rint, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; }
};
template<>
struct RoundFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(nearbyint, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return 0.0f; }
};
template<>
struct SigmoidFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) {
return 1.0f / (1.0f + MATH_FUNC_F(exp, -x));
}
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
float y = 1.0f / (1.0f + MATH_FUNC_F(exp, -x));
return dy * (y * (1.0f - y));
}
};
template<>
struct SinFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(sin, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * MATH_FUNC_F(cos, x);
}
};
template<>
struct SinhFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(sinh, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * MATH_FUNC_F(cosh, x);
}
};
template<>
struct SqrtFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(sqrt, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * 0.5f / MATH_FUNC_F(sqrt, x);
}
};
template<>
struct SquareFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return x * x; }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return dy * 2.0f * x; }
};
template<>
struct TanFunctor<float> {
static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(tan, x); }
static OF_DEVICE_FUNC float Backward(const float x, const float dy) {
return dy * (1.0f / (MATH_FUNC_F(cos, x) * MATH_FUNC_F(cos, x)));
}
};
// double version
template<>
struct AcosFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(acos, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * -RsqrtFunctor<double>::Forward(1.0 - x * x);
}
};
template<>
struct AcoshFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(acosh, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * -RsqrtFunctor<double>::Forward(x * x - 1.0);
}
};
template<>
struct AsinFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(asin, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * RsqrtFunctor<double>::Forward(1.0 - x * x);
}
};
template<>
struct AsinhFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(asinh, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * RsqrtFunctor<double>::Forward(1.0 + x * x);
}
};
template<>
struct AtanFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(atan, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * (1.0 / (1.0 + x * x));
}
};
template<>
struct AtanhFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(atanh, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * (1.0 / (1.0 - x * x));
}
};
template<>
struct NotEqualZeroFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return x != 0; }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return 0.0f; }
};
template<>
struct CeilFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(ceil, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return 0.0; }
};
template<>
struct CosFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(cos, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * (-MATH_FUNC_D(sin, x));
}
};
template<>
struct CoshFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(cosh, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * MATH_FUNC_D(sinh, x);
}
};
template<>
struct ErfFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(erf, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * 2.0 * RsqrtFunctor<double>::Forward(M_PI) * expf(-x * x);
}
};
template<>
struct ErfcFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(erfc, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * -2.0 * RsqrtFunctor<double>::Forward(M_PI) * expf(-x * x);
}
};
template<>
struct ExpFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(exp, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * MATH_FUNC_D(exp, x);
}
};
template<>
struct Expm1Functor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(expm1, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * MATH_FUNC_D(exp, x);
}
};
template<>
struct FloorFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(floor, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return 0.0; }
};
template<>
struct LgammaFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(lgamma, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
// TODO(chengcheng): return: dy * digamma(x)
// assert(false);
return 0.0;
}
};
template<>
struct LogFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(log, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return dy * (1.0 / x); }
};
template<>
struct Log2Functor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(log2, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * (1.0 / (x * MATH_FUNC_D(log, 2.0)));
}
};
template<>
struct Log1pFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(log1p, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * (1.0 / (x + 1.0));
}
};
template<>
struct LogSigmoidFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) {
return -MATH_FUNC_D(log, (1.0 + MATH_FUNC_D(exp, -x)));
}
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * (1.0 / (MATH_FUNC_D(exp, x) + 1.0));
}
};
template<>
struct NegativeFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return -x; }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return -dy; }
};
template<>
struct ReciprocalFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return 1.0 / x; }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * (-1.0 / (x * x));
}
};
template<>
struct ReciprocalNoNanFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) {
if (fabs(x) <= 0.0) { return 0.0; }
return 1.0 / x;
}
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
if (fabs(x) <= 0.0) { return 0.0; }
return dy * (-1.0 / (x * x));
}
};
template<>
struct RintFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(rint, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return 0.0; }
};
template<>
struct RoundFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(nearbyint, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return 0.0; }
};
template<>
struct SigmoidFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) {
return 1.0 / (1.0 + MATH_FUNC_D(exp, -x));
}
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
double y = 1.0 / (1.0 + MATH_FUNC_D(exp, -x));
return dy * (y * (1.0 - y));
}
};
template<>
struct SinFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(sin, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * MATH_FUNC_D(cos, x);
}
};
template<>
struct SinhFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(sinh, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * MATH_FUNC_D(cosh, x);
}
};
template<>
struct SqrtFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(sqrt, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * (double)0.5 / MATH_FUNC_D(sqrt, x);
}
};
template<>
struct SquareFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return x * x; }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return dy * 2.0 * x; }
};
template<>
struct TanFunctor<double> {
static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(tan, x); }
static OF_DEVICE_FUNC double Backward(const double x, const double dy) {
return dy * (1.0 / (MATH_FUNC_D(cos, x) * MATH_FUNC_D(cos, x)));
}
};
#if defined(__CUDACC__) || defined(__HIPCC__)
// half version
#define OF_HALF_FUNC __device__ __forceinline__
#define MATH_FUNC_H(name, x) __float2half(name##f(__half2float(x)))
#define HALF_VAL_HALF __float2half(0.5f)
#define HALF_VAL_TWO __float2half(2.0f)
#define HALF_VAL_2RSQRT_PI __float2half(1.1283791671f)
template<>
struct AbsFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) {
return __hlt(x, GetZeroVal<half>()) ? __hneg(x) : x;
}
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hlt(x, GetZeroVal<half>()) ? __hneg(dy) : dy;
}
};
template<>
struct AcosFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(acos, x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hmul(dy, __hneg(hrsqrt(__hsub(GetOneVal<half>(), __hmul(x, x)))));
}
};
template<>
struct AcoshFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(acosh, x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hmul(dy, hrsqrt(__hsub(__hmul(x, x), GetOneVal<half>())));
}
};
template<>
struct AsinFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(asin, x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hmul(dy, hrsqrt(__hsub(GetOneVal<half>(), __hmul(x, x))));
}
};
template<>
struct AsinhFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(asinh, x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hmul(dy, hrsqrt(__hadd(GetOneVal<half>(), __hmul(x, x))));
}
};
template<>
struct AtanFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(atan, x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hmul(dy, __hdiv(GetOneVal<half>(), __hadd(GetOneVal<half>(), __hmul(x, x))));
}
};
template<>
struct AtanhFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(atanh, x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hmul(dy, __hdiv(GetOneVal<half>(), __hsub(GetOneVal<half>(), __hmul(x, x))));
}
};
template<>
struct CeilFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return hceil(x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal<half>(); }
};
template<>
struct NotEqualZeroFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return x != static_cast<half>(0.0); }
static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal<half>(); }
};
template<>
struct CosFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return hcos(x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hmul(dy, __hneg(hsin(x)));
}
};
template<>
struct CoshFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(cosh, x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hmul(dy, MATH_FUNC_H(sinh, x));
}
};
template<>
struct ErfFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(erf, x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hmul(dy, __hmul(HALF_VAL_2RSQRT_PI, hexp(__hmul(__hneg(x), x))));
}
};
template<>
struct ErfcFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(erfc, x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hmul(dy, __hneg(__hmul(HALF_VAL_2RSQRT_PI, hexp(__hmul(__hneg(x), x)))));
}
};
template<>
struct ExpFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return hexp(x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, hexp(x)); }
};
template<>
struct Expm1Functor<half> {
static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(expm1, x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, hexp(x)); }
};
template<>
struct FloorFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return hfloor(x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal<half>(); }
};
template<>
struct LgammaFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(lgamma, x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) {
// TODO(chengcheng): return: dy * digamma(x)
// assert(false);
return GetZeroVal<half>();
}
};
template<>
struct LogFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return hlog(x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, hrcp(x)); }
};
template<>
struct Log2Functor<half> {
static OF_HALF_FUNC half Forward(const half x) { return hlog2(x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hmul(dy, hrcp(__hmul(x, hlog(HALF_VAL_TWO))));
}
};
template<>
struct Log1pFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(log1p, x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hmul(dy, hrcp(__hadd(x, GetOneVal<half>())));
}
};
template<>
struct LogSigmoidFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) {
return __hneg(hlog(__hadd(GetOneVal<half>(), hexp(__hneg(x)))));
}
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hmul(dy, hrcp(__hadd(hexp(x), GetOneVal<half>())));
}
};
template<>
struct NegativeFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return __hneg(x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hneg(dy); }
};
template<>
struct ReciprocalFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return hrcp(x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hmul(dy, __hneg(hrcp(__hmul(x, x))));
}
};
template<>
struct ReciprocalNoNanFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) {
if (__heq(GetZeroVal<half>(), x)) { return GetZeroVal<half>(); }
return hrcp(x);
}
static OF_HALF_FUNC half Backward(const half x, const half dy) {
if (__heq(GetZeroVal<half>(), x)) { return GetZeroVal<half>(); }
return __hmul(dy, __hneg(hrcp(__hmul(x, x))));
}
};
template<>
struct RintFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return hrint(x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal<half>(); }
};
template<>
struct RoundFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(nearbyint, x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal<half>(); }
};
template<>
struct RsqrtFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return hrsqrt(x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hmul(dy, __hneg(hrcp(__hmul(HALF_VAL_TWO, hsqrt(__hmul(x, __hmul(x, x)))))));
}
};
template<>
struct SigmoidFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) {
return hrcp(__hadd(GetOneVal<half>(), hexp(__hneg(x))));
}
static OF_HALF_FUNC half Backward(const half x, const half dy) {
half y = hrcp(__hadd(GetOneVal<half>(), hexp(__hneg(x))));
return __hmul(dy, __hmul(y, __hsub(GetOneVal<half>(), y)));
}
};
template<>
struct SignFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) {
if (__hgt(x, GetZeroVal<half>())) { return GetOneVal<half>(); }
if (__hlt(x, GetZeroVal<half>())) { return __hneg(GetOneVal<half>()); }
return GetZeroVal<half>();
}
static OF_HALF_FUNC half Backward(const half x, const half dy) { return GetZeroVal<half>(); }
};
template<>
struct SinFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return hsin(x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, hcos(x)); }
};
template<>
struct SinhFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(sinh, x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hmul(dy, MATH_FUNC_H(cosh, x));
}
};
template<>
struct SqrtFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return hsqrt(x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hmul(dy, __hdiv(HALF_VAL_HALF, hsqrt(x)));
}
};
template<>
struct SquareFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return __hmul(x, x); }
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hmul(dy, __hmul(HALF_VAL_TWO, x));
}
};
template<>
struct TanFunctor<half> {
static OF_HALF_FUNC half Forward(const half x) { return __hdiv(hsin(x), hcos(x)); }
static OF_HALF_FUNC half Backward(const half x, const half dy) {
return __hmul(dy, hrcp(__hmul(hcos(x), hcos(x))));
}
};
#endif
} // namespace oneflow
#endif // ONEFLOW_USER_KERNELS_MATH_UNARY_ELEMENTWISE_FUNC_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/kernel/new_kernel_util.h"
#ifdef OF_ENABLE_PROFILER
#include <roctracer_roctx.h>
#endif // OF_ENABLE_PROFILER
namespace oneflow {
namespace {
#ifdef OF_ENABLE_PROFILER
static thread_local HashMap<std::string, roctx_range_id_t> mark2range_id;
#endif
} // namespace
class NvtxOpKernelState final : public user_op::OpKernelState {
public:
NvtxOpKernelState() : counter_(0) {
#ifndef OF_ENABLE_PROFILER
LOG(WARNING) << "To use NVTX, run cmake with -DBUILD_PROFILER=ON";
#endif
}
~NvtxOpKernelState() override = default;
int64_t counter() const { return counter_; }
void IncreaseCount() { counter_ += 1; }
private:
int64_t counter_;
};
class NvtxStartKernel final : public user_op::OpKernel {
public:
NvtxStartKernel() = default;
~NvtxStartKernel() override = default;
std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(
user_op::KernelInitContext* ctx) const override {
return std::make_shared<NvtxOpKernelState>();
}
private:
using user_op::OpKernel::Compute;
void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,
const user_op::OpKernelCache*) const override {
const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0);
user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0);
const ShapeView& in_shape = in->shape_view();
CHECK_EQ(out->shape_view(), in_shape);
const DataType in_data_type = in->data_type();
CHECK_EQ(out->data_type(), in_data_type);
Memcpy<DeviceType::kCUDA>(ctx->stream(), out->mut_dptr<void>(), in->dptr<void>(),
in_shape.elem_cnt() * GetSizeOfDataType(in_data_type));
#ifdef OF_ENABLE_PROFILER
auto* kernel_state = dynamic_cast<NvtxOpKernelState*>(state);
const std::string mark_prefix = ctx->Attr<std::string>("mark_prefix");
const std::string mark = mark_prefix + "-" + std::to_string(kernel_state->counter());
roctx_range_id_t range_id = roctxRangeStartA(mark.c_str());
CHECK(mark2range_id.emplace(mark, range_id).second);
kernel_state->IncreaseCount();
#endif // OF_ENABLE_PROFILER
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
REGISTER_USER_KERNEL("nvtx_start")
.SetCreateFn<NvtxStartKernel>()
.SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCUDA)
.SetInplaceProposalFn([](const user_op::InferContext&,
user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> {
OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "in", 0, false));
return Maybe<void>::Ok();
});
class NvtxEndKernel final : public user_op::OpKernel {
public:
NvtxEndKernel() = default;
~NvtxEndKernel() override = default;
std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(
user_op::KernelInitContext* ctx) const override {
return std::make_shared<NvtxOpKernelState>();
}
private:
using user_op::OpKernel::Compute;
void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,
const user_op::OpKernelCache*) const override {
const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0);
user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0);
const ShapeView& in_shape = in->shape_view();
CHECK_EQ(out->shape_view(), in_shape);
const DataType in_data_type = in->data_type();
CHECK_EQ(out->data_type(), in_data_type);
#ifdef OF_ENABLE_PROFILER
auto* kernel_state = dynamic_cast<NvtxOpKernelState*>(state);
const std::string mark_prefix = ctx->Attr<std::string>("mark_prefix");
const std::string mark = mark_prefix + "-" + std::to_string(kernel_state->counter());
auto it = mark2range_id.find(mark.c_str());
CHECK(it != mark2range_id.end());
roctx_range_id_t range_id = it->second;
mark2range_id.erase(it);
roctxRangeStop(range_id);
Memcpy<DeviceType::kCUDA>(ctx->stream(), out->mut_dptr<void>(), in->dptr<void>(),
in_shape.elem_cnt() * GetSizeOfDataType(in_data_type));
kernel_state->IncreaseCount();
#endif
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
REGISTER_USER_KERNEL("nvtx_end")
.SetCreateFn<NvtxEndKernel>()
.SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCUDA)
.SetInplaceProposalFn([](const user_op::InferContext&,
user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> {
OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "in", 0, false));
return Maybe<void>::Ok();
});
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/user/kernels/stateful_opkernel.h"
#include "oneflow/core/framework/attr_value_accessor.h"
#include "oneflow/core/framework/user_op_conf.h"
#include "oneflow/core/framework/user_op_registry_manager.h"
#include "oneflow/core/eager/eager_blob_object.h"
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/rpc/include/global_process_ctx.h"
#include "oneflow/core/framework/consistent_tensor_infer_cache.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/profiler/profiler.h"
#include "oneflow/core/profiler/profile_manager.h"
#include "oneflow/core/profiler/event_recorder.h"
#include "oneflow/core/eager/call_context.h"
namespace oneflow {
namespace one {
class ConsistentTensorInferResult;
using ArgVec = std::vector<std::pair<std::string, int32_t>>;
using EagerBlobObjectListRawPtr = const std::vector<std::shared_ptr<vm::EagerBlobObject>>*;
using ConsistentTensorInferResultRawPtr = const ConsistentTensorInferResult*;
class ZeroCopyBaseContextHelper {
public:
ZeroCopyBaseContextHelper(const std::shared_ptr<const ArgTuple>& input_arg_tuple,
const std::shared_ptr<const ArgTuple>& output_arg_tuple)
: input_arg_tuple_(input_arg_tuple), output_arg_tuple_(output_arg_tuple) {}
#define RETURN_IF_FOUND(inputs, outputs, post_action) \
int32_t i = TryGetTensorTupleIndex(input_arg_tuple_->arg_name2bn_index2tensor_tuple_index(), \
arg_name, index); \
if (i >= 0) { return (inputs).at(i) post_action; } \
i = TryGetTensorTupleIndex(output_arg_tuple_->arg_name2bn_index2tensor_tuple_index(), arg_name, \
index); \
if (i >= 0) { return (outputs).at(i) post_action; }
user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
const std::string& arg_name,
const int32_t index) const {
RETURN_IF_FOUND(*call_ctx->inputs(), *call_ctx->outputs(), .get());
return nullptr;
}
user_op::Tensor* Tensor4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
const int32_t index) const {
RETURN_IF_FOUND(*call_ctx->inputs(), *call_ctx->outputs(), .get());
if (arg_name == "tmp_buffer" && index == 0) { return call_ctx->mut_tmp_tensor(); }
return nullptr;
}
const ConsistentTensorMeta* ConsistentTensorMeta4ArgNameAndIndex(eager::CallContext* call_ctx,
const std::string& arg_name,
const int32_t index) const {
const auto& consistent_tensor_infer_result = call_ctx->consistent_tensor_infer_result();
RETURN_IF_FOUND(consistent_tensor_infer_result->input_tensor_metas(),
consistent_tensor_infer_result->output_tensor_metas(),
.shared_from_symbol().get());
return nullptr;
}
Optional<Symbol<ParallelDesc>> parallel_desc(eager::CallContext* call_ctx) const {
const auto& consistent_tensor_infer_result = call_ctx->consistent_tensor_infer_result();
if (!consistent_tensor_infer_result) { return Optional<Symbol<ParallelDesc>>(); }
if (!consistent_tensor_infer_result->input_tensor_metas().empty()) {
return consistent_tensor_infer_result->input_tensor_metas().at(0)->parallel_desc();
} else if (!consistent_tensor_infer_result->output_tensor_metas().empty()) {
return consistent_tensor_infer_result->output_tensor_metas().at(0)->parallel_desc();
} else {
UNIMPLEMENTED();
return Optional<Symbol<ParallelDesc>>();
}
}
const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const {
const auto& parallel_desc = this->parallel_desc(call_ctx);
if (parallel_desc.has_value()) {
const auto& parallel_desc_symbol = CHECK_JUST(parallel_desc);
return *CHECK_JUST(GetParallelContext4CurrentProcessCtx(parallel_desc_symbol));
} else {
static ParallelContext single_device_parallel_ctx(MakeSingleDeviceParallelCtx());
return single_device_parallel_ctx;
}
}
const ArgVec& inputs() const { return input_arg_tuple_->indexed_arg_name_and_index(); }
const ArgVec& outputs() const { return output_arg_tuple_->indexed_arg_name_and_index(); }
private:
static int32_t TryGetTensorTupleIndex(const std::unordered_map<std::string, std::vector<int32_t>>&
arg_name2bn_index2tensor_tuple_index,
const std::string& arg_name, const int32_t arg_index) {
auto it = arg_name2bn_index2tensor_tuple_index.find(arg_name);
if (it != arg_name2bn_index2tensor_tuple_index.end()) { return it->second.at(arg_index); }
return -1;
}
static ParallelContext MakeSingleDeviceParallelCtx() {
ParallelContext single_device_parallel_ctx;
single_device_parallel_ctx.set_parallel_id(0);
single_device_parallel_ctx.set_parallel_num(1);
return single_device_parallel_ctx;
}
std::shared_ptr<const ArgTuple> input_arg_tuple_;
std::shared_ptr<const ArgTuple> output_arg_tuple_;
};
class UserKernelBaseContextHelper final : public ZeroCopyBaseContextHelper {
public:
UserKernelBaseContextHelper(DeviceType device_type,
const std::shared_ptr<const ArgTuple>& input_arg_tuple,
const std::shared_ptr<const ArgTuple>& output_arg_tuple)
: ZeroCopyBaseContextHelper(input_arg_tuple, output_arg_tuple), device_type_(device_type) {}
~UserKernelBaseContextHelper() = default;
DeviceType device_type() const { return device_type_; }
const JobDesc& job_desc() const {
UNIMPLEMENTED();
return *(const JobDesc*)nullptr;
}
private:
const DeviceType device_type_;
};
class UserOpInferContextHelper final {
public:
UserOpInferContextHelper(const user_op::UserOpConfWrapper* user_op_conf,
const std::shared_ptr<const ArgTuple>& input_arg_tuple,
const std::shared_ptr<const ArgTuple>& output_arg_tuple)
: user_op_conf_(user_op_conf),
zero_copy_base_ctx_helper_(input_arg_tuple, output_arg_tuple) {}
~UserOpInferContextHelper() = default;
const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
const std::string& arg_name,
int32_t index) const {
UNIMPLEMENTED();
return nullptr;
}
const user_op::TensorDesc& InputTensorDesc(eager::CallContext* call_ctx,
const std::string& arg_name, int32_t index) const {
return *CHECK_NOTNULL(TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index));
}
user_op::TensorDesc* OutputTensorDesc(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const {
return TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);
}
user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
const std::string& arg_name,
int32_t index) const {
return zero_copy_base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);
}
const Shape& InputShape(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const {
return *Shape4ArgNameAndIndex(call_ctx, arg_name, index);
}
Shape* OutputShape(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const {
return Shape4ArgNameAndIndex(call_ctx, arg_name, index);
}
Shape* Shape4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const {
return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_shape();
}
const Stride& InputStride(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const {
return *Stride4ArgNameAndIndex(call_ctx, arg_name, index);
}
Stride* OutputStride(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const {
return Stride4ArgNameAndIndex(call_ctx, arg_name, index);
}
Stride* Stride4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const {
return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_stride();
}
const DataType& InputDType(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const {
return *Dtype4ArgNameAndIndex(call_ctx, arg_name, index);
}
DataType* OutputDType(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const {
return Dtype4ArgNameAndIndex(call_ctx, arg_name, index);
}
DataType* Dtype4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const {
return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_data_type();
}
bool InputIsDynamic(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const {
return *IsDynamic4ArgNameAndIndex(call_ctx, arg_name, index);
}
bool* OutputIsDynamic(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const {
return IsDynamic4ArgNameAndIndex(call_ctx, arg_name, index);
}
bool* IsDynamic4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const {
return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_is_dynamic();
}
const ArgVec& inputs() const { return zero_copy_base_ctx_helper_.inputs(); }
const ArgVec& outputs() const { return zero_copy_base_ctx_helper_.outputs(); }
const JobDesc* job_desc() const {
UNIMPLEMENTED();
return nullptr;
}
const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const {
return zero_copy_base_ctx_helper_.parallel_ctx(call_ctx);
}
const ParallelDesc& parallel_desc(eager::CallContext* call_ctx) const {
return *CHECK_JUST(zero_copy_base_ctx_helper_.parallel_desc(call_ctx));
}
const SbpParallel& SbpParallel4ArgNameAndIndex(eager::CallContext* call_ctx,
const std::string& arg_name, int32_t index) const {
const auto& nd_sbp = NdSbp4ArgNameAndIndex(call_ctx, arg_name, index);
CHECK_EQ(nd_sbp.sbp_parallel_size(), 1);
return nd_sbp.sbp_parallel(0);
}
const NdSbp& NdSbp4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const {
return *CHECK_NOTNULL(zero_copy_base_ctx_helper_.ConsistentTensorMeta4ArgNameAndIndex(
call_ctx, arg_name, index))
->nd_sbp();
}
int64_t parallel_num(eager::CallContext* call_ctx) const {
return parallel_ctx(call_ctx).parallel_num();
}
const std::string& input(const std::string& arg_name, int32_t index) const {
return user_op_conf().input(arg_name, index);
}
const std::string& output(const std::string& arg_name, int32_t index) const {
return user_op_conf().output(arg_name, index);
}
bool has_input(const std::string& arg_name, int32_t index) const {
return user_op_conf().has_input(arg_name, index);
}
bool has_output(const std::string& arg_name, int32_t index) const {
return user_op_conf().has_output(arg_name, index);
}
int32_t input_size(const std::string& arg_name) const {
return user_op_conf().input_size(arg_name);
}
int32_t output_size(const std::string& arg_name) const {
return user_op_conf().output_size(arg_name);
}
const std::string& op_name() const { return user_op_conf().op_name(); }
const std::string& op_type_name() const { return user_op_conf().op_type_name(); }
const std::string& op_loc() const { return user_op_conf_->op_conf().loc(); }
const user_op::UserOpConfWrapper& user_op_conf() const { return *user_op_conf_; }
const std::shared_ptr<const user_op::AttrVal>& Attr4Name(eager::CallContext* call_ctx,
const std::string& attr_name) const {
return call_ctx->composed_attrs().Attr4Name(attr_name);
}
private:
user_op::TensorDesc* NonNullTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
const std::string& arg_name,
int32_t index) const {
user_op::TensorDesc* tensor_desc = TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);
if (!tensor_desc) { LOG(FATAL) << "Arg (" << arg_name << "," << index << ") is not found"; }
return tensor_desc;
}
const user_op::UserOpConfWrapper* user_op_conf_;
ZeroCopyBaseContextHelper zero_copy_base_ctx_helper_;
};
class UserOpInferContext : public user_op::InferContext {
public:
UserOpInferContext(const UserOpInferContextHelper* helper, eager::CallContext* call_ctx)
: helper_(helper), call_ctx_(call_ctx) {}
~UserOpInferContext() override = default;
const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& arg_name,
int32_t index) const override {
return helper_->LogicalTensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);
}
const user_op::TensorDesc& InputTensorDesc(const std::string& arg_name,
int32_t index) const override {
return helper_->InputTensorDesc(call_ctx_, arg_name, index);
}
user_op::TensorDesc* OutputTensorDesc(const std::string& arg_name, int32_t index) override {
return helper_->OutputTensorDesc(call_ctx_, arg_name, index);
}
user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) {
return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);
}
const Shape& InputShape(const std::string& arg_name, int32_t index) const override {
return helper_->InputShape(call_ctx_, arg_name, index);
}
Shape* OutputShape(const std::string& arg_name, int32_t index) override {
return helper_->OutputShape(call_ctx_, arg_name, index);
}
Shape* Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) override {
return helper_->Shape4ArgNameAndIndex(call_ctx_, arg_name, index);
}
const Stride& InputStride(const std::string& arg_name, int32_t index) const override {
return helper_->InputStride(call_ctx_, arg_name, index);
}
Stride* OutputStride(const std::string& arg_name, int32_t index) override {
return helper_->OutputStride(call_ctx_, arg_name, index);
}
Stride* Stride4ArgNameAndIndex(const std::string& arg_name, int32_t index) override {
return helper_->Stride4ArgNameAndIndex(call_ctx_, arg_name, index);
}
const DataType& InputDType(const std::string& arg_name, int32_t index) const override {
return helper_->InputDType(call_ctx_, arg_name, index);
}
DataType* OutputDType(const std::string& arg_name, int32_t index) override {
return helper_->OutputDType(call_ctx_, arg_name, index);
}
DataType* Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) override {
return helper_->Dtype4ArgNameAndIndex(call_ctx_, arg_name, index);
}
bool InputIsDynamic(const std::string& arg_name, int32_t index) const override {
return helper_->InputIsDynamic(call_ctx_, arg_name, index);
}
bool* OutputIsDynamic(const std::string& arg_name, int32_t index) override {
return helper_->OutputIsDynamic(call_ctx_, arg_name, index);
}
bool* IsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) override {
return helper_->IsDynamic4ArgNameAndIndex(call_ctx_, arg_name, index);
}
const ArgVec& inputs() const override { return helper_->inputs(); }
const ArgVec& outputs() const override { return helper_->outputs(); }
const JobDesc* job_desc() const override { return helper_->job_desc(); }
const ParallelContext& parallel_ctx() const override { return helper_->parallel_ctx(call_ctx_); }
const ParallelDesc& parallel_desc() const override { return helper_->parallel_desc(call_ctx_); }
const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string& arg_name,
int32_t index) const override {
return helper_->SbpParallel4ArgNameAndIndex(call_ctx_, arg_name, index);
}
const NdSbp& NdSbp4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {
return helper_->NdSbp4ArgNameAndIndex(call_ctx_, arg_name, index);
}
int64_t parallel_num() const override { return helper_->parallel_num(call_ctx_); }
const std::string& input(const std::string& arg_name, int32_t index) const override {
return helper_->input(arg_name, index);
}
const std::string& output(const std::string& arg_name, int32_t index) const override {
return helper_->output(arg_name, index);
}
bool has_input(const std::string& arg_name, int32_t index) const override {
return helper_->has_input(arg_name, index);
}
bool has_output(const std::string& arg_name, int32_t index) const override {
return helper_->has_output(arg_name, index);
}
int32_t input_size(const std::string& arg_name) const override {
return helper_->input_size(arg_name);
}
int32_t output_size(const std::string& arg_name) const override {
return helper_->output_size(arg_name);
}
const std::string& op_name() const override { return helper_->op_name(); }
const std::string& op_type_name() const override { return helper_->op_type_name(); }
const std::string& op_loc() const override { return helper_->op_loc(); }
private:
const std::shared_ptr<const user_op::AttrVal>& Attr4Name(
const std::string& attr_name) const override {
return helper_->Attr4Name(call_ctx_, attr_name);
}
const UserOpInferContextHelper* helper_;
eager::CallContext* call_ctx_;
};
class UserKernelComputeContextHelper final {
public:
UserKernelComputeContextHelper(DeviceType device_type,
const user_op::UserOpConfWrapper* user_op_conf,
const std::shared_ptr<const ArgTuple>& input_arg_tuple,
const std::shared_ptr<const ArgTuple>& output_arg_tuple)
: user_op_conf_(user_op_conf),
base_ctx_helper_(device_type, input_arg_tuple, output_arg_tuple) {}
~UserKernelComputeContextHelper() = default;
const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
const std::string& arg_name,
int32_t index) const {
return base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);
}
user_op::Tensor* Tensor4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const {
return base_ctx_helper_.Tensor4ArgNameAndIndex(call_ctx, arg_name, index);
}
ep::Stream* stream(DeviceCtx* device_ctx) const {
CHECK(device_ctx);
return device_ctx->stream();
}
DeviceType device_type() const { return base_ctx_helper_.device_type(); }
const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const {
return base_ctx_helper_.parallel_ctx(call_ctx);
}
const ArgVec& inputs() const { return base_ctx_helper_.inputs(); }
const ArgVec& outputs() const { return base_ctx_helper_.outputs(); }
const user_op::UserOpConfWrapper& user_op_conf() const { return *user_op_conf_; }
const std::shared_ptr<const user_op::AttrVal>& Attr4Name(eager::CallContext* call_ctx,
const std::string& attr_name) const {
return call_ctx->composed_attrs().Attr4Name(attr_name);
}
private:
const user_op::UserOpConfWrapper* user_op_conf_;
UserKernelBaseContextHelper base_ctx_helper_;
};
class UserKernelComputeContext final : public user_op::KernelComputeContext {
public:
UserKernelComputeContext(const UserKernelComputeContextHelper* helper,
eager::CallContext* call_ctx, DeviceCtx* device_ctx)
: helper_(helper), call_ctx_(call_ctx), device_ctx_(device_ctx) {}
~UserKernelComputeContext() = default;
const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,
int32_t index) const override {
return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);
}
user_op::Tensor* Tensor4ArgNameAndIndex(const std::string& arg_name, int32_t index) override {
return helper_->Tensor4ArgNameAndIndex(call_ctx_, arg_name, index);
}
ep::Stream* stream() override { return helper_->stream(device_ctx_); }
DeviceType device_type() const override { return helper_->device_type(); }
const ParallelContext& parallel_ctx() const override { return helper_->parallel_ctx(call_ctx_); }
const ArgVec& inputs() const override { return helper_->inputs(); }
const ArgVec& outputs() const override { return helper_->outputs(); }
private:
const user_op::UserOpConfWrapper& user_op_conf() const override {
return helper_->user_op_conf();
}
const std::shared_ptr<const user_op::AttrVal>& Attr4Name(
const std::string& attr_name) const override {
return helper_->Attr4Name(call_ctx_, attr_name);
}
const UserKernelComputeContextHelper* helper_;
eager::CallContext* call_ctx_;
DeviceCtx* device_ctx_;
};
class UserKernelRegContextHelper final {
public:
UserKernelRegContextHelper(DeviceType device_type, const user_op::UserOpConfWrapper* user_op_conf,
const std::shared_ptr<const ArgTuple>& input_arg_tuple,
const std::shared_ptr<const ArgTuple>& output_arg_tuple)
: user_op_conf_(user_op_conf),
base_ctx_helper_(device_type, input_arg_tuple, output_arg_tuple) {}
~UserKernelRegContextHelper() = default;
DeviceType device_type() const { return base_ctx_helper_.device_type(); }
const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const {
return base_ctx_helper_.parallel_ctx(call_ctx);
}
const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
const std::string& arg_name,
int32_t index) const {
return base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);
}
const ArgVec& inputs() const { return base_ctx_helper_.inputs(); }
const ArgVec& outputs() const { return base_ctx_helper_.outputs(); }
const user_op::UserOpConfWrapper& user_op_conf() const { return *user_op_conf_; }
const std::shared_ptr<const user_op::AttrVal>& Attr4Name(eager::CallContext* call_ctx,
const std::string& attr_name) const {
return call_ctx->composed_attrs().Attr4Name(attr_name);
}
private:
const user_op::UserOpConfWrapper* user_op_conf_;
UserKernelBaseContextHelper base_ctx_helper_;
};
class UserKernelRegContext final : public user_op::KernelRegContext {
public:
UserKernelRegContext(const UserKernelRegContextHelper* helper, eager::CallContext* call_ctx)
: helper_(helper), call_ctx_(call_ctx) {}
~UserKernelRegContext() = default;
DeviceType device_type() const override { return helper_->device_type(); }
const ParallelContext& parallel_ctx() const override { return helper_->parallel_ctx(call_ctx_); }
const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,
int32_t index) const override {
return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);
}
const ArgVec& inputs() const override { return helper_->inputs(); }
const ArgVec& outputs() const override { return helper_->outputs(); }
const user_op::UserOpConfWrapper& user_op_conf() const override {
return helper_->user_op_conf();
}
private:
const std::shared_ptr<const user_op::AttrVal>& Attr4Name(
const std::string& attr_name) const override {
return helper_->Attr4Name(call_ctx_, attr_name);
}
const UserKernelRegContextHelper* helper_;
eager::CallContext* call_ctx_;
};
class UserKernelInitAndCacheContextHelper final {
public:
UserKernelInitAndCacheContextHelper(DeviceType device_type,
const user_op::UserOpConfWrapper* user_op_conf,
const std::shared_ptr<const ArgTuple>& input_arg_tuple,
const std::shared_ptr<const ArgTuple>& output_arg_tuple)
: user_op_conf_(user_op_conf),
base_ctx_helper_(device_type, input_arg_tuple, output_arg_tuple) {}
~UserKernelInitAndCacheContextHelper() = default;
ep::Stream* stream(DeviceCtx* device_ctx) const {
CHECK(device_ctx);
return device_ctx->stream();
}
DeviceType device_type() const { return base_ctx_helper_.device_type(); }
const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const {
return base_ctx_helper_.parallel_ctx(call_ctx);
}
const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
const std::string& arg_name,
int32_t index) const {
return base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);
}
const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
const std::string& arg_name,
int32_t index) const {
return base_ctx_helper_.ConsistentTensorMeta4ArgNameAndIndex(call_ctx, arg_name, index);
}
const SbpParallel& SbpParallel4ArgNameAndIndex(eager::CallContext* call_ctx,
const std::string& arg_name, int32_t index) const {
const auto& nd_sbp = NdSbp4ArgNameAndIndex(call_ctx, arg_name, index);
CHECK_EQ(nd_sbp.sbp_parallel_size(), 1);
return nd_sbp.sbp_parallel(0);
}
const NdSbp& NdSbp4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const {
return *CHECK_NOTNULL(
base_ctx_helper_.ConsistentTensorMeta4ArgNameAndIndex(call_ctx, arg_name, index))
->nd_sbp();
}
const ArgVec& inputs() const { return base_ctx_helper_.inputs(); }
const ArgVec& outputs() const { return base_ctx_helper_.outputs(); }
const ParallelDesc& parallel_desc(eager::CallContext* call_ctx) const {
return *CHECK_JUST(base_ctx_helper_.parallel_desc(call_ctx));
}
const std::shared_ptr<const user_op::AttrVal>& Attr4Name(eager::CallContext* call_ctx,
const std::string& attr_name) const {
return call_ctx->composed_attrs().Attr4Name(attr_name);
}
const user_op::UserOpConfWrapper& user_op_conf() const { return *user_op_conf_; }
private:
const user_op::UserOpConfWrapper* user_op_conf_;
UserKernelBaseContextHelper base_ctx_helper_;
};
class UserKernelInitAndCacheContext final : public user_op::KernelInitContext,
public user_op::KernelCacheContext {
public:
UserKernelInitAndCacheContext(const UserKernelInitAndCacheContextHelper* helper,
eager::CallContext* call_ctx, DeviceCtx* device_ctx)
: helper_(helper), call_ctx_(call_ctx), device_ctx_(device_ctx) {}
~UserKernelInitAndCacheContext() override = default;
ep::Stream* stream() override { return helper_->stream(device_ctx_); }
DeviceType device_type() const override { return helper_->device_type(); }
const ParallelContext& parallel_ctx() const override { return helper_->parallel_ctx(call_ctx_); }
const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,
int32_t index) const override {
return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);
}
const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& arg_name,
int32_t index) const override {
return helper_->LogicalTensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);
}
const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string& arg_name,
int32_t index) const override {
return helper_->SbpParallel4ArgNameAndIndex(call_ctx_, arg_name, index);
}
const NdSbp& NdSbp4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {
return helper_->NdSbp4ArgNameAndIndex(call_ctx_, arg_name, index);
}
const ArgVec& inputs() const override { return helper_->inputs(); }
const ArgVec& outputs() const override { return helper_->outputs(); }
const ParallelDesc& parallel_desc() const override { return helper_->parallel_desc(call_ctx_); }
private:
const std::shared_ptr<const user_op::AttrVal>& Attr4Name(
const std::string& attr_name) const override {
return helper_->Attr4Name(call_ctx_, attr_name);
}
const user_op::UserOpConfWrapper& user_op_conf() const override {
return helper_->user_op_conf();
}
const UserKernelInitAndCacheContextHelper* helper_;
eager::CallContext* call_ctx_;
DeviceCtx* device_ctx_;
};
namespace {
Maybe<void> InitTensorTupleIndexes4Bns(const std::shared_ptr<const OperatorConf>& op_conf,
const ArgVec& indexed_input_pairs,
const ArgVec& indexed_output_pairs,
std::vector<int64_t>* input_tuple_indexes4const_ibns,
std::vector<int64_t>* input_tuple_indexes4mut_ibns,
std::vector<int64_t>* output_tuple_indexes4mut_obns,
std::vector<int64_t>* output_tuple_indexes4mut2_obns) {
const auto* op_reg_val =
user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_conf->user_conf().op_type_name());
CHECK_NOTNULL_OR_RETURN(op_reg_val);
ArgModifierSignature arg_modifier_signature;
for (const auto& pair : indexed_input_pairs) {
const std::string ibn = GenRepeatedBn(pair.first, pair.second);
arg_modifier_signature.mutable_ibn2input_blob_modifier()->insert(
{ibn, user_op::InputArgModifier()});
}
for (const auto& pair : indexed_output_pairs) {
const std::string obn = GenRepeatedBn(pair.first, pair.second);
arg_modifier_signature.mutable_obn2output_blob_modifier()->insert(
{obn, user_op::OutputArgModifier()});
}
user_op::UserOpConfWrapper op_conf_wrapper(op_conf);
if (op_reg_val->input_arg_modify_fn) {
user_op::GetInputArgModifier GetInputArgModifierFn =
[&arg_modifier_signature](const std::string& in_arg_name,
int32_t in_arg_index) -> user_op::InputArgModifier* {
const std::string ibn = GenRepeatedBn(in_arg_name, in_arg_index);
auto* map = arg_modifier_signature.mutable_ibn2input_blob_modifier();
return &map->at(ibn);
};
JUST(op_reg_val->input_arg_modify_fn(GetInputArgModifierFn, op_conf_wrapper));
}
if (op_reg_val->output_arg_modify_fn) {
user_op::GetOutputArgModifier GetOutputArgModifierFn =
[&arg_modifier_signature](const std::string& in_arg_name,
int32_t in_arg_index) -> user_op::OutputArgModifier* {
const std::string obn = GenRepeatedBn(in_arg_name, in_arg_index);
auto* map = arg_modifier_signature.mutable_obn2output_blob_modifier();
return &map->at(obn);
};
JUST(op_reg_val->output_arg_modify_fn(GetOutputArgModifierFn, op_conf_wrapper));
}
for (int i = 0; i < indexed_input_pairs.size(); i++) {
const auto& pair = indexed_input_pairs.at(i);
const std::string ibn = GenRepeatedBn(pair.first, pair.second);
if (arg_modifier_signature.ibn2input_blob_modifier().at(ibn).is_mutable()) {
input_tuple_indexes4mut_ibns->emplace_back(i);
} else {
input_tuple_indexes4const_ibns->emplace_back(i);
}
}
for (int i = 0; i < indexed_output_pairs.size(); i++) {
const auto& pair = indexed_output_pairs.at(i);
const std::string obn = GenRepeatedBn(pair.first, pair.second);
if (arg_modifier_signature.obn2output_blob_modifier().at(obn).header_infered_before_compute()) {
output_tuple_indexes4mut_obns->emplace_back(i);
} else {
output_tuple_indexes4mut2_obns->emplace_back(i);
}
}
return Maybe<void>::Ok();
}
} // namespace
/* static */ Maybe<StatefulOpKernel> StatefulOpKernel::New(
const std::shared_ptr<OperatorConf>& op_conf, const Symbol<Stream>& stream,
const AttrMap& base_attrs, const std::shared_ptr<const ParallelDesc>& parallel_desc,
const std::shared_ptr<const ArgTuple>& input_arg_tuple,
const std::shared_ptr<const ArgTuple>& output_arg_tuple) {
auto opkernel = std::shared_ptr<StatefulOpKernel>(new StatefulOpKernel());
opkernel->base_attrs_ = base_attrs;
opkernel->op_conf_ = op_conf;
opkernel->user_op_conf_.reset(new user_op::UserOpConfWrapper(op_conf));
opkernel->stream_ = stream;
opkernel->input_arg_tuple_ = input_arg_tuple;
opkernel->output_arg_tuple_ = output_arg_tuple;
opkernel->need_check_mem_case_ = true;
const DeviceType device_type = CHECK_JUST(DeviceType4DeviceTag(op_conf->device_tag()));
const user_op::UserOpConfWrapper* user_op_conf = opkernel->user_op_conf_.get();
opkernel->op_infer_ctx_helper_.reset(
new UserOpInferContextHelper(user_op_conf, input_arg_tuple, output_arg_tuple));
opkernel->init_and_cache_ctx_helper_.reset(new UserKernelInitAndCacheContextHelper(
device_type, opkernel->user_op_conf_.get(), opkernel->input_arg_tuple_,
opkernel->output_arg_tuple_));
opkernel->compute_ctx_helper_.reset(new UserKernelComputeContextHelper(
device_type, user_op_conf, input_arg_tuple, output_arg_tuple));
opkernel->reg_ctx_helper_.reset(
new UserKernelRegContextHelper(device_type, user_op_conf, input_arg_tuple, output_arg_tuple));
const auto* op_reg_val =
user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(user_op_conf->op_type_name());
CHECK_NOTNULL_OR_RETURN(op_reg_val);
if (op_reg_val->logical_tensor_desc_infer_fn) {
opkernel->tensor_desc_infer_fn_ = op_reg_val->logical_tensor_desc_infer_fn;
} else {
return Error::UnimplementedError();
}
opkernel->data_type_infer_fn_ = op_reg_val->data_type_infer_fn;
JUST(InitTensorTupleIndexes4Bns(
op_conf, input_arg_tuple->indexed_arg_name_and_index(),
output_arg_tuple->indexed_arg_name_and_index(), &opkernel->input_tuple_indexes4const_ibns_,
&opkernel->input_tuple_indexes4mut_ibns_, &opkernel->output_tuple_indexes4mut_obns_,
&opkernel->output_tuple_indexes4mut2_obns_));
return opkernel;
}
StatefulOpKernel::~StatefulOpKernel() = default;
size_t StatefulOpKernel::InferTmpSize(eager::CallContext* call_ctx,
const user_op::OpKernel* user_opkernel) const {
UserOpInferContext op_infer_ctx(op_infer_ctx_helper_.get(), call_ctx);
const auto& InferTmpSizeFn = GetInferTmpSizeFn(user_opkernel);
return InferTmpSizeFn(&op_infer_ctx);
}
Maybe<void> StatefulOpKernel::ChooseOpKernel(eager::CallContext* call_ctx,
const user_op::OpKernel** user_opkernel,
bool* need_temp_storage) {
OF_PROFILER_RANGE_GUARD("ChooseOpKernel");
DataType primary_dtype = kInvalidDataType;
const auto& inputs = call_ctx->inputs();
const auto& outputs = call_ctx->outputs();
if (likely(!inputs->empty())) {
primary_dtype = (*inputs)[0]->data_type();
} else if (likely(!outputs->empty())) {
primary_dtype = (*outputs)[0]->data_type();
} else {
// do nothing
}
UserKernelRegContext reg_ctx(reg_ctx_helper_.get(), call_ctx);
for (const auto& pair : dtype2cached_kernels_[primary_dtype]) {
if (likely(pair.first->is_matched_hob->get(reg_ctx))) {
*need_temp_storage = pair.first->need_temp_storage;
*user_opkernel = pair.second.get();
return Maybe<void>::Ok();
}
}
OF_PROFILER_RANGE_GUARD("fallback");
const auto& op_type_name = user_op_conf_->op_type_name();
const auto* kernel_reg_val =
JUST(user_op::UserOpRegistryMgr::Get().GetOpKernelRegistryResult(op_type_name, reg_ctx));
CHECK_NOTNULL(kernel_reg_val);
auto* kernel = kernel_reg_val->create_fn();
dtype2cached_kernels_[primary_dtype].push_back(
{kernel_reg_val, std::shared_ptr<const user_op::OpKernel>(kernel)});
infer_tmp_size_fn_map_.emplace(kernel, &kernel_reg_val->infer_tmp_size_fn);
*need_temp_storage = kernel_reg_val->need_temp_storage;
*user_opkernel = kernel;
return Maybe<void>::Ok();
}
void StatefulOpKernel::TryInitOpKernelStateAndCache(eager::CallContext* call_ctx,
DeviceCtx* device_ctx,
const user_op::OpKernel* op_kernel,
user_op::OpKernelState** state,
user_op::OpKernelCache** cache) {
UserKernelInitAndCacheContext init_and_cache_ctx(init_and_cache_ctx_helper_.get(), call_ctx,
device_ctx);
if (state != nullptr) {
auto it = op_kernel_state_map_.find(op_kernel);
if (it != op_kernel_state_map_.end()) {
*state = it->second.get();
} else {
auto created_state = op_kernel->CreateOpKernelState(&init_and_cache_ctx);
op_kernel_state_map_.emplace(op_kernel, created_state);
*state = created_state.get();
}
}
{
auto& cache_in_map = op_kernel_cache_map_[op_kernel];
op_kernel->InitOpKernelCacheWithFlags(&init_and_cache_ctx,
user_op::OpKernelCache::kAllMayChanged, &cache_in_map);
*cache = cache_in_map.get();
}
}
const user_op::InferTmpSizeFn& StatefulOpKernel::GetInferTmpSizeFn(
const user_op::OpKernel* op_kernel) const {
return *infer_tmp_size_fn_map_.at(op_kernel);
}
user_op::TensorDescInferFn StatefulOpKernel::TensorDescInferFn() const {
return tensor_desc_infer_fn_;
}
user_op::DataTypeInferFn StatefulOpKernel::DataTypeInferFn() const { return data_type_infer_fn_; }
void StatefulOpKernel::Compute(eager::CallContext* call_ctx, DeviceCtx* device_ctx,
const user_op::OpKernel* user_opkernel,
user_op::OpKernelState* state,
const user_op::OpKernelCache* cache) const {
UserKernelComputeContext compute_context(compute_ctx_helper_.get(), call_ctx, device_ctx);
auto* compute_ctx = &compute_context;
OF_PROFILER_RANGE_GUARD("Compute");
if (Singleton<profiler::ProfileManager>::Get()) {
#if defined(WITH_CUDA)
const auto CalMemorySize = [compute_ctx](const one::ArgVec& args) -> int64_t {
const auto Func = [compute_ctx](int64_t mem_size, const auto& pair) {
const auto tensor = compute_ctx->Tensor4ArgNameAndIndex(pair.first, pair.second);
return mem_size + tensor->shape_view().elem_cnt() * GetSizeOfDataType(tensor->data_type());
};
return std::accumulate(args.begin(), args.end(), static_cast<int64_t>(0), Func);
};
#endif
auto er_guard = CHECK_JUST(profiler::EventRecorder::CreateKernelEventRecorder(
op_type_name(),
#if defined(WITH_CUDA)
[compute_ctx, CalMemorySize]() -> int64_t {
return CalMemorySize(compute_ctx->inputs()) + CalMemorySize(compute_ctx->outputs());
},
#endif
[compute_ctx]() -> std::vector<ShapeView> {
std::vector<ShapeView> shapes;
for (const auto& pair : compute_ctx->inputs()) {
shapes.emplace_back(
compute_ctx->TensorDesc4ArgNameAndIndex(pair.first, pair.second)->shape());
}
return shapes;
}));
user_opkernel->Compute(compute_ctx, state, cache);
} else {
user_opkernel->Compute(compute_ctx, state, cache);
}
}
} // namespace one
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/user/kernels/stateful_opkernel.h"
#include "oneflow/core/framework/attr_value_accessor.h"
#include "oneflow/core/framework/user_op_conf.h"
#include "oneflow/core/framework/user_op_registry_manager.h"
#include "oneflow/core/eager/eager_blob_object.h"
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/rpc/include/global_process_ctx.h"
#include "oneflow/core/framework/consistent_tensor_infer_cache.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/profiler/profiler.h"
#include "oneflow/core/profiler/profile_manager.h"
#include "oneflow/core/profiler/event_recorder.h"
#include "oneflow/core/eager/call_context.h"
namespace oneflow {
namespace one {
class ConsistentTensorInferResult;
using ArgVec = std::vector<std::pair<std::string, int32_t>>;
using EagerBlobObjectListRawPtr = const std::vector<std::shared_ptr<vm::EagerBlobObject>>*;
using ConsistentTensorInferResultRawPtr = const ConsistentTensorInferResult*;
class ZeroCopyBaseContextHelper {
public:
ZeroCopyBaseContextHelper(const std::shared_ptr<const ArgTuple>& input_arg_tuple,
const std::shared_ptr<const ArgTuple>& output_arg_tuple)
: input_arg_tuple_(input_arg_tuple), output_arg_tuple_(output_arg_tuple) {}
#define RETURN_IF_FOUND(inputs, outputs, post_action) \
int32_t i = TryGetTensorTupleIndex(input_arg_tuple_->arg_name2bn_index2tensor_tuple_index(), \
arg_name, index); \
if (i >= 0) { return (inputs).at(i) post_action; } \
i = TryGetTensorTupleIndex(output_arg_tuple_->arg_name2bn_index2tensor_tuple_index(), arg_name, \
index); \
if (i >= 0) { return (outputs).at(i) post_action; }
user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
const std::string& arg_name,
const int32_t index) const {
RETURN_IF_FOUND(*call_ctx->inputs(), *call_ctx->outputs(), .get());
return nullptr;
}
user_op::Tensor* Tensor4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
const int32_t index) const {
RETURN_IF_FOUND(*call_ctx->inputs(), *call_ctx->outputs(), .get());
if (arg_name == "tmp_buffer" && index == 0) { return call_ctx->mut_tmp_tensor(); }
return nullptr;
}
const ConsistentTensorMeta* ConsistentTensorMeta4ArgNameAndIndex(eager::CallContext* call_ctx,
const std::string& arg_name,
const int32_t index) const {
const auto& consistent_tensor_infer_result = call_ctx->consistent_tensor_infer_result();
RETURN_IF_FOUND(consistent_tensor_infer_result->input_tensor_metas(),
consistent_tensor_infer_result->output_tensor_metas(),
.shared_from_symbol().get());
return nullptr;
}
Optional<Symbol<ParallelDesc>> parallel_desc(eager::CallContext* call_ctx) const {
const auto& consistent_tensor_infer_result = call_ctx->consistent_tensor_infer_result();
if (!consistent_tensor_infer_result) { return Optional<Symbol<ParallelDesc>>(); }
if (!consistent_tensor_infer_result->input_tensor_metas().empty()) {
return consistent_tensor_infer_result->input_tensor_metas().at(0)->parallel_desc();
} else if (!consistent_tensor_infer_result->output_tensor_metas().empty()) {
return consistent_tensor_infer_result->output_tensor_metas().at(0)->parallel_desc();
} else {
UNIMPLEMENTED();
return Optional<Symbol<ParallelDesc>>();
}
}
const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const {
const auto& parallel_desc = this->parallel_desc(call_ctx);
if (parallel_desc.has_value()) {
const auto& parallel_desc_symbol = CHECK_JUST(parallel_desc);
return *CHECK_JUST(GetParallelContext4CurrentProcessCtx(parallel_desc_symbol));
} else {
static ParallelContext single_device_parallel_ctx(MakeSingleDeviceParallelCtx());
return single_device_parallel_ctx;
}
}
const ArgVec& inputs() const { return input_arg_tuple_->indexed_arg_name_and_index(); }
const ArgVec& outputs() const { return output_arg_tuple_->indexed_arg_name_and_index(); }
private:
static int32_t TryGetTensorTupleIndex(const std::unordered_map<std::string, std::vector<int32_t>>&
arg_name2bn_index2tensor_tuple_index,
const std::string& arg_name, const int32_t arg_index) {
auto it = arg_name2bn_index2tensor_tuple_index.find(arg_name);
if (it != arg_name2bn_index2tensor_tuple_index.end()) { return it->second.at(arg_index); }
return -1;
}
static ParallelContext MakeSingleDeviceParallelCtx() {
ParallelContext single_device_parallel_ctx;
single_device_parallel_ctx.set_parallel_id(0);
single_device_parallel_ctx.set_parallel_num(1);
return single_device_parallel_ctx;
}
std::shared_ptr<const ArgTuple> input_arg_tuple_;
std::shared_ptr<const ArgTuple> output_arg_tuple_;
};
class UserKernelBaseContextHelper final : public ZeroCopyBaseContextHelper {
public:
UserKernelBaseContextHelper(DeviceType device_type,
const std::shared_ptr<const ArgTuple>& input_arg_tuple,
const std::shared_ptr<const ArgTuple>& output_arg_tuple)
: ZeroCopyBaseContextHelper(input_arg_tuple, output_arg_tuple), device_type_(device_type) {}
~UserKernelBaseContextHelper() = default;
DeviceType device_type() const { return device_type_; }
const JobDesc& job_desc() const {
UNIMPLEMENTED();
return *(const JobDesc*)nullptr;
}
private:
const DeviceType device_type_;
};
class UserOpInferContextHelper final {
public:
UserOpInferContextHelper(const user_op::UserOpConfWrapper* user_op_conf,
const std::shared_ptr<const ArgTuple>& input_arg_tuple,
const std::shared_ptr<const ArgTuple>& output_arg_tuple)
: user_op_conf_(user_op_conf),
zero_copy_base_ctx_helper_(input_arg_tuple, output_arg_tuple) {}
~UserOpInferContextHelper() = default;
const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
const std::string& arg_name,
int32_t index) const {
UNIMPLEMENTED();
return nullptr;
}
const user_op::TensorDesc& InputTensorDesc(eager::CallContext* call_ctx,
const std::string& arg_name, int32_t index) const {
return *CHECK_NOTNULL(TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index));
}
user_op::TensorDesc* OutputTensorDesc(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const {
return TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);
}
user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
const std::string& arg_name,
int32_t index) const {
return zero_copy_base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);
}
const Shape& InputShape(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const {
return *Shape4ArgNameAndIndex(call_ctx, arg_name, index);
}
Shape* OutputShape(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const {
return Shape4ArgNameAndIndex(call_ctx, arg_name, index);
}
Shape* Shape4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const {
return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_shape();
}
const Stride& InputStride(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const {
return *Stride4ArgNameAndIndex(call_ctx, arg_name, index);
}
Stride* OutputStride(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const {
return Stride4ArgNameAndIndex(call_ctx, arg_name, index);
}
Stride* Stride4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const {
return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_stride();
}
const DataType& InputDType(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const {
return *Dtype4ArgNameAndIndex(call_ctx, arg_name, index);
}
DataType* OutputDType(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const {
return Dtype4ArgNameAndIndex(call_ctx, arg_name, index);
}
DataType* Dtype4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const {
return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_data_type();
}
bool InputIsDynamic(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const {
return *IsDynamic4ArgNameAndIndex(call_ctx, arg_name, index);
}
bool* OutputIsDynamic(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const {
return IsDynamic4ArgNameAndIndex(call_ctx, arg_name, index);
}
bool* IsDynamic4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const {
return NonNullTensorDesc4ArgNameAndIndex(call_ctx, arg_name, index)->mut_is_dynamic();
}
const ArgVec& inputs() const { return zero_copy_base_ctx_helper_.inputs(); }
const ArgVec& outputs() const { return zero_copy_base_ctx_helper_.outputs(); }
const JobDesc* job_desc() const {
UNIMPLEMENTED();
return nullptr;
}
const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const {
return zero_copy_base_ctx_helper_.parallel_ctx(call_ctx);
}
const ParallelDesc& parallel_desc(eager::CallContext* call_ctx) const {
return *CHECK_JUST(zero_copy_base_ctx_helper_.parallel_desc(call_ctx));
}
const SbpParallel& SbpParallel4ArgNameAndIndex(eager::CallContext* call_ctx,
const std::string& arg_name, int32_t index) const {
const auto& nd_sbp = NdSbp4ArgNameAndIndex(call_ctx, arg_name, index);
CHECK_EQ(nd_sbp.sbp_parallel_size(), 1);
return nd_sbp.sbp_parallel(0);
}
const NdSbp& NdSbp4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const {
return *CHECK_NOTNULL(zero_copy_base_ctx_helper_.ConsistentTensorMeta4ArgNameAndIndex(
call_ctx, arg_name, index))
->nd_sbp();
}
int64_t parallel_num(eager::CallContext* call_ctx) const {
return parallel_ctx(call_ctx).parallel_num();
}
const std::string& input(const std::string& arg_name, int32_t index) const {
return user_op_conf().input(arg_name, index);
}
const std::string& output(const std::string& arg_name, int32_t index) const {
return user_op_conf().output(arg_name, index);
}
bool has_input(const std::string& arg_name, int32_t index) const {
return user_op_conf().has_input(arg_name, index);
}
bool has_output(const std::string& arg_name, int32_t index) const {
return user_op_conf().has_output(arg_name, index);
}
int32_t input_size(const std::string& arg_name) const {
return user_op_conf().input_size(arg_name);
}
int32_t output_size(const std::string& arg_name) const {
return user_op_conf().output_size(arg_name);
}
const std::string& op_name() const { return user_op_conf().op_name(); }
const std::string& op_type_name() const { return user_op_conf().op_type_name(); }
const std::string& op_loc() const { return user_op_conf_->op_conf().loc(); }
const user_op::UserOpConfWrapper& user_op_conf() const { return *user_op_conf_; }
const std::shared_ptr<const user_op::AttrVal>& Attr4Name(eager::CallContext* call_ctx,
const std::string& attr_name) const {
return call_ctx->composed_attrs().Attr4Name(attr_name);
}
private:
user_op::TensorDesc* NonNullTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
const std::string& arg_name,
int32_t index) const {
user_op::TensorDesc* tensor_desc = TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);
if (!tensor_desc) { LOG(FATAL) << "Arg (" << arg_name << "," << index << ") is not found"; }
return tensor_desc;
}
const user_op::UserOpConfWrapper* user_op_conf_;
ZeroCopyBaseContextHelper zero_copy_base_ctx_helper_;
};
class UserOpInferContext : public user_op::InferContext {
public:
UserOpInferContext(const UserOpInferContextHelper* helper, eager::CallContext* call_ctx)
: helper_(helper), call_ctx_(call_ctx) {}
~UserOpInferContext() override = default;
const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& arg_name,
int32_t index) const override {
return helper_->LogicalTensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);
}
const user_op::TensorDesc& InputTensorDesc(const std::string& arg_name,
int32_t index) const override {
return helper_->InputTensorDesc(call_ctx_, arg_name, index);
}
user_op::TensorDesc* OutputTensorDesc(const std::string& arg_name, int32_t index) override {
return helper_->OutputTensorDesc(call_ctx_, arg_name, index);
}
user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) {
return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);
}
const Shape& InputShape(const std::string& arg_name, int32_t index) const override {
return helper_->InputShape(call_ctx_, arg_name, index);
}
Shape* OutputShape(const std::string& arg_name, int32_t index) override {
return helper_->OutputShape(call_ctx_, arg_name, index);
}
Shape* Shape4ArgNameAndIndex(const std::string& arg_name, int32_t index) override {
return helper_->Shape4ArgNameAndIndex(call_ctx_, arg_name, index);
}
const Stride& InputStride(const std::string& arg_name, int32_t index) const override {
return helper_->InputStride(call_ctx_, arg_name, index);
}
Stride* OutputStride(const std::string& arg_name, int32_t index) override {
return helper_->OutputStride(call_ctx_, arg_name, index);
}
Stride* Stride4ArgNameAndIndex(const std::string& arg_name, int32_t index) override {
return helper_->Stride4ArgNameAndIndex(call_ctx_, arg_name, index);
}
const DataType& InputDType(const std::string& arg_name, int32_t index) const override {
return helper_->InputDType(call_ctx_, arg_name, index);
}
DataType* OutputDType(const std::string& arg_name, int32_t index) override {
return helper_->OutputDType(call_ctx_, arg_name, index);
}
DataType* Dtype4ArgNameAndIndex(const std::string& arg_name, int32_t index) override {
return helper_->Dtype4ArgNameAndIndex(call_ctx_, arg_name, index);
}
bool InputIsDynamic(const std::string& arg_name, int32_t index) const override {
return helper_->InputIsDynamic(call_ctx_, arg_name, index);
}
bool* OutputIsDynamic(const std::string& arg_name, int32_t index) override {
return helper_->OutputIsDynamic(call_ctx_, arg_name, index);
}
bool* IsDynamic4ArgNameAndIndex(const std::string& arg_name, int32_t index) override {
return helper_->IsDynamic4ArgNameAndIndex(call_ctx_, arg_name, index);
}
const ArgVec& inputs() const override { return helper_->inputs(); }
const ArgVec& outputs() const override { return helper_->outputs(); }
const JobDesc* job_desc() const override { return helper_->job_desc(); }
const ParallelContext& parallel_ctx() const override { return helper_->parallel_ctx(call_ctx_); }
const ParallelDesc& parallel_desc() const override { return helper_->parallel_desc(call_ctx_); }
const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string& arg_name,
int32_t index) const override {
return helper_->SbpParallel4ArgNameAndIndex(call_ctx_, arg_name, index);
}
const NdSbp& NdSbp4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {
return helper_->NdSbp4ArgNameAndIndex(call_ctx_, arg_name, index);
}
int64_t parallel_num() const override { return helper_->parallel_num(call_ctx_); }
const std::string& input(const std::string& arg_name, int32_t index) const override {
return helper_->input(arg_name, index);
}
const std::string& output(const std::string& arg_name, int32_t index) const override {
return helper_->output(arg_name, index);
}
bool has_input(const std::string& arg_name, int32_t index) const override {
return helper_->has_input(arg_name, index);
}
bool has_output(const std::string& arg_name, int32_t index) const override {
return helper_->has_output(arg_name, index);
}
int32_t input_size(const std::string& arg_name) const override {
return helper_->input_size(arg_name);
}
int32_t output_size(const std::string& arg_name) const override {
return helper_->output_size(arg_name);
}
const std::string& op_name() const override { return helper_->op_name(); }
const std::string& op_type_name() const override { return helper_->op_type_name(); }
const std::string& op_loc() const override { return helper_->op_loc(); }
private:
const std::shared_ptr<const user_op::AttrVal>& Attr4Name(
const std::string& attr_name) const override {
return helper_->Attr4Name(call_ctx_, attr_name);
}
const UserOpInferContextHelper* helper_;
eager::CallContext* call_ctx_;
};
class UserKernelComputeContextHelper final {
public:
UserKernelComputeContextHelper(DeviceType device_type,
const user_op::UserOpConfWrapper* user_op_conf,
const std::shared_ptr<const ArgTuple>& input_arg_tuple,
const std::shared_ptr<const ArgTuple>& output_arg_tuple)
: user_op_conf_(user_op_conf),
base_ctx_helper_(device_type, input_arg_tuple, output_arg_tuple) {}
~UserKernelComputeContextHelper() = default;
const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
const std::string& arg_name,
int32_t index) const {
return base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);
}
user_op::Tensor* Tensor4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const {
return base_ctx_helper_.Tensor4ArgNameAndIndex(call_ctx, arg_name, index);
}
ep::Stream* stream(DeviceCtx* device_ctx) const {
CHECK(device_ctx);
return device_ctx->stream();
}
DeviceType device_type() const { return base_ctx_helper_.device_type(); }
const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const {
return base_ctx_helper_.parallel_ctx(call_ctx);
}
const ArgVec& inputs() const { return base_ctx_helper_.inputs(); }
const ArgVec& outputs() const { return base_ctx_helper_.outputs(); }
const user_op::UserOpConfWrapper& user_op_conf() const { return *user_op_conf_; }
const std::shared_ptr<const user_op::AttrVal>& Attr4Name(eager::CallContext* call_ctx,
const std::string& attr_name) const {
return call_ctx->composed_attrs().Attr4Name(attr_name);
}
private:
const user_op::UserOpConfWrapper* user_op_conf_;
UserKernelBaseContextHelper base_ctx_helper_;
};
class UserKernelComputeContext final : public user_op::KernelComputeContext {
public:
UserKernelComputeContext(const UserKernelComputeContextHelper* helper,
eager::CallContext* call_ctx, DeviceCtx* device_ctx)
: helper_(helper), call_ctx_(call_ctx), device_ctx_(device_ctx) {}
~UserKernelComputeContext() = default;
const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,
int32_t index) const override {
return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);
}
user_op::Tensor* Tensor4ArgNameAndIndex(const std::string& arg_name, int32_t index) override {
return helper_->Tensor4ArgNameAndIndex(call_ctx_, arg_name, index);
}
ep::Stream* stream() override { return helper_->stream(device_ctx_); }
DeviceType device_type() const override { return helper_->device_type(); }
const ParallelContext& parallel_ctx() const override { return helper_->parallel_ctx(call_ctx_); }
const ArgVec& inputs() const override { return helper_->inputs(); }
const ArgVec& outputs() const override { return helper_->outputs(); }
private:
const user_op::UserOpConfWrapper& user_op_conf() const override {
return helper_->user_op_conf();
}
const std::shared_ptr<const user_op::AttrVal>& Attr4Name(
const std::string& attr_name) const override {
return helper_->Attr4Name(call_ctx_, attr_name);
}
const UserKernelComputeContextHelper* helper_;
eager::CallContext* call_ctx_;
DeviceCtx* device_ctx_;
};
class UserKernelRegContextHelper final {
public:
UserKernelRegContextHelper(DeviceType device_type, const user_op::UserOpConfWrapper* user_op_conf,
const std::shared_ptr<const ArgTuple>& input_arg_tuple,
const std::shared_ptr<const ArgTuple>& output_arg_tuple)
: user_op_conf_(user_op_conf),
base_ctx_helper_(device_type, input_arg_tuple, output_arg_tuple) {}
~UserKernelRegContextHelper() = default;
DeviceType device_type() const { return base_ctx_helper_.device_type(); }
const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const {
return base_ctx_helper_.parallel_ctx(call_ctx);
}
const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
const std::string& arg_name,
int32_t index) const {
return base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);
}
const ArgVec& inputs() const { return base_ctx_helper_.inputs(); }
const ArgVec& outputs() const { return base_ctx_helper_.outputs(); }
const user_op::UserOpConfWrapper& user_op_conf() const { return *user_op_conf_; }
const std::shared_ptr<const user_op::AttrVal>& Attr4Name(eager::CallContext* call_ctx,
const std::string& attr_name) const {
return call_ctx->composed_attrs().Attr4Name(attr_name);
}
private:
const user_op::UserOpConfWrapper* user_op_conf_;
UserKernelBaseContextHelper base_ctx_helper_;
};
class UserKernelRegContext final : public user_op::KernelRegContext {
public:
UserKernelRegContext(const UserKernelRegContextHelper* helper, eager::CallContext* call_ctx)
: helper_(helper), call_ctx_(call_ctx) {}
~UserKernelRegContext() = default;
DeviceType device_type() const override { return helper_->device_type(); }
const ParallelContext& parallel_ctx() const override { return helper_->parallel_ctx(call_ctx_); }
const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,
int32_t index) const override {
return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);
}
const ArgVec& inputs() const override { return helper_->inputs(); }
const ArgVec& outputs() const override { return helper_->outputs(); }
const user_op::UserOpConfWrapper& user_op_conf() const override {
return helper_->user_op_conf();
}
private:
const std::shared_ptr<const user_op::AttrVal>& Attr4Name(
const std::string& attr_name) const override {
return helper_->Attr4Name(call_ctx_, attr_name);
}
const UserKernelRegContextHelper* helper_;
eager::CallContext* call_ctx_;
};
class UserKernelInitAndCacheContextHelper final {
public:
UserKernelInitAndCacheContextHelper(DeviceType device_type,
const user_op::UserOpConfWrapper* user_op_conf,
const std::shared_ptr<const ArgTuple>& input_arg_tuple,
const std::shared_ptr<const ArgTuple>& output_arg_tuple)
: user_op_conf_(user_op_conf),
base_ctx_helper_(device_type, input_arg_tuple, output_arg_tuple) {}
~UserKernelInitAndCacheContextHelper() = default;
ep::Stream* stream(DeviceCtx* device_ctx) const {
CHECK(device_ctx);
return device_ctx->stream();
}
DeviceType device_type() const { return base_ctx_helper_.device_type(); }
const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const {
return base_ctx_helper_.parallel_ctx(call_ctx);
}
const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
const std::string& arg_name,
int32_t index) const {
return base_ctx_helper_.TensorDesc4ArgNameAndIndex(call_ctx, arg_name, index);
}
const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(eager::CallContext* call_ctx,
const std::string& arg_name,
int32_t index) const {
return base_ctx_helper_.ConsistentTensorMeta4ArgNameAndIndex(call_ctx, arg_name, index);
}
const SbpParallel& SbpParallel4ArgNameAndIndex(eager::CallContext* call_ctx,
const std::string& arg_name, int32_t index) const {
const auto& nd_sbp = NdSbp4ArgNameAndIndex(call_ctx, arg_name, index);
CHECK_EQ(nd_sbp.sbp_parallel_size(), 1);
return nd_sbp.sbp_parallel(0);
}
const NdSbp& NdSbp4ArgNameAndIndex(eager::CallContext* call_ctx, const std::string& arg_name,
int32_t index) const {
return *CHECK_NOTNULL(
base_ctx_helper_.ConsistentTensorMeta4ArgNameAndIndex(call_ctx, arg_name, index))
->nd_sbp();
}
const ArgVec& inputs() const { return base_ctx_helper_.inputs(); }
const ArgVec& outputs() const { return base_ctx_helper_.outputs(); }
const ParallelDesc& parallel_desc(eager::CallContext* call_ctx) const {
return *CHECK_JUST(base_ctx_helper_.parallel_desc(call_ctx));
}
const std::shared_ptr<const user_op::AttrVal>& Attr4Name(eager::CallContext* call_ctx,
const std::string& attr_name) const {
return call_ctx->composed_attrs().Attr4Name(attr_name);
}
const user_op::UserOpConfWrapper& user_op_conf() const { return *user_op_conf_; }
private:
const user_op::UserOpConfWrapper* user_op_conf_;
UserKernelBaseContextHelper base_ctx_helper_;
};
class UserKernelInitAndCacheContext final : public user_op::KernelInitContext,
public user_op::KernelCacheContext {
public:
UserKernelInitAndCacheContext(const UserKernelInitAndCacheContextHelper* helper,
eager::CallContext* call_ctx, DeviceCtx* device_ctx)
: helper_(helper), call_ctx_(call_ctx), device_ctx_(device_ctx) {}
~UserKernelInitAndCacheContext() override = default;
ep::Stream* stream() override { return helper_->stream(device_ctx_); }
DeviceType device_type() const override { return helper_->device_type(); }
const ParallelContext& parallel_ctx() const override { return helper_->parallel_ctx(call_ctx_); }
const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,
int32_t index) const override {
return helper_->TensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);
}
const user_op::TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string& arg_name,
int32_t index) const override {
return helper_->LogicalTensorDesc4ArgNameAndIndex(call_ctx_, arg_name, index);
}
const SbpParallel& SbpParallel4ArgNameAndIndex(const std::string& arg_name,
int32_t index) const override {
return helper_->SbpParallel4ArgNameAndIndex(call_ctx_, arg_name, index);
}
const NdSbp& NdSbp4ArgNameAndIndex(const std::string& arg_name, int32_t index) const override {
return helper_->NdSbp4ArgNameAndIndex(call_ctx_, arg_name, index);
}
const ArgVec& inputs() const override { return helper_->inputs(); }
const ArgVec& outputs() const override { return helper_->outputs(); }
const ParallelDesc& parallel_desc() const override { return helper_->parallel_desc(call_ctx_); }
private:
const std::shared_ptr<const user_op::AttrVal>& Attr4Name(
const std::string& attr_name) const override {
return helper_->Attr4Name(call_ctx_, attr_name);
}
const user_op::UserOpConfWrapper& user_op_conf() const override {
return helper_->user_op_conf();
}
const UserKernelInitAndCacheContextHelper* helper_;
eager::CallContext* call_ctx_;
DeviceCtx* device_ctx_;
};
namespace {
Maybe<void> InitTensorTupleIndexes4Bns(const std::shared_ptr<const OperatorConf>& op_conf,
const ArgVec& indexed_input_pairs,
const ArgVec& indexed_output_pairs,
std::vector<int64_t>* input_tuple_indexes4const_ibns,
std::vector<int64_t>* input_tuple_indexes4mut_ibns,
std::vector<int64_t>* output_tuple_indexes4mut_obns,
std::vector<int64_t>* output_tuple_indexes4mut2_obns) {
const auto* op_reg_val =
user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_conf->user_conf().op_type_name());
CHECK_NOTNULL_OR_RETURN(op_reg_val);
ArgModifierSignature arg_modifier_signature;
for (const auto& pair : indexed_input_pairs) {
const std::string ibn = GenRepeatedBn(pair.first, pair.second);
arg_modifier_signature.mutable_ibn2input_blob_modifier()->insert(
{ibn, user_op::InputArgModifier()});
}
for (const auto& pair : indexed_output_pairs) {
const std::string obn = GenRepeatedBn(pair.first, pair.second);
arg_modifier_signature.mutable_obn2output_blob_modifier()->insert(
{obn, user_op::OutputArgModifier()});
}
user_op::UserOpConfWrapper op_conf_wrapper(op_conf);
if (op_reg_val->input_arg_modify_fn) {
user_op::GetInputArgModifier GetInputArgModifierFn =
[&arg_modifier_signature](const std::string& in_arg_name,
int32_t in_arg_index) -> user_op::InputArgModifier* {
const std::string ibn = GenRepeatedBn(in_arg_name, in_arg_index);
auto* map = arg_modifier_signature.mutable_ibn2input_blob_modifier();
return &map->at(ibn);
};
JUST(op_reg_val->input_arg_modify_fn(GetInputArgModifierFn, op_conf_wrapper));
}
if (op_reg_val->output_arg_modify_fn) {
user_op::GetOutputArgModifier GetOutputArgModifierFn =
[&arg_modifier_signature](const std::string& in_arg_name,
int32_t in_arg_index) -> user_op::OutputArgModifier* {
const std::string obn = GenRepeatedBn(in_arg_name, in_arg_index);
auto* map = arg_modifier_signature.mutable_obn2output_blob_modifier();
return &map->at(obn);
};
JUST(op_reg_val->output_arg_modify_fn(GetOutputArgModifierFn, op_conf_wrapper));
}
for (int i = 0; i < indexed_input_pairs.size(); i++) {
const auto& pair = indexed_input_pairs.at(i);
const std::string ibn = GenRepeatedBn(pair.first, pair.second);
if (arg_modifier_signature.ibn2input_blob_modifier().at(ibn).is_mutable()) {
input_tuple_indexes4mut_ibns->emplace_back(i);
} else {
input_tuple_indexes4const_ibns->emplace_back(i);
}
}
for (int i = 0; i < indexed_output_pairs.size(); i++) {
const auto& pair = indexed_output_pairs.at(i);
const std::string obn = GenRepeatedBn(pair.first, pair.second);
if (arg_modifier_signature.obn2output_blob_modifier().at(obn).header_infered_before_compute()) {
output_tuple_indexes4mut_obns->emplace_back(i);
} else {
output_tuple_indexes4mut2_obns->emplace_back(i);
}
}
return Maybe<void>::Ok();
}
} // namespace
/* static */ Maybe<StatefulOpKernel> StatefulOpKernel::New(
const std::shared_ptr<OperatorConf>& op_conf, const Symbol<Stream>& stream,
const AttrMap& base_attrs, const std::shared_ptr<const ParallelDesc>& parallel_desc,
const std::shared_ptr<const ArgTuple>& input_arg_tuple,
const std::shared_ptr<const ArgTuple>& output_arg_tuple) {
auto opkernel = std::shared_ptr<StatefulOpKernel>(new StatefulOpKernel());
opkernel->base_attrs_ = base_attrs;
opkernel->op_conf_ = op_conf;
opkernel->user_op_conf_.reset(new user_op::UserOpConfWrapper(op_conf));
opkernel->stream_ = stream;
opkernel->input_arg_tuple_ = input_arg_tuple;
opkernel->output_arg_tuple_ = output_arg_tuple;
opkernel->need_check_mem_case_ = true;
const DeviceType device_type = CHECK_JUST(DeviceType4DeviceTag(op_conf->device_tag()));
const user_op::UserOpConfWrapper* user_op_conf = opkernel->user_op_conf_.get();
opkernel->op_infer_ctx_helper_.reset(
new UserOpInferContextHelper(user_op_conf, input_arg_tuple, output_arg_tuple));
opkernel->init_and_cache_ctx_helper_.reset(new UserKernelInitAndCacheContextHelper(
device_type, opkernel->user_op_conf_.get(), opkernel->input_arg_tuple_,
opkernel->output_arg_tuple_));
opkernel->compute_ctx_helper_.reset(new UserKernelComputeContextHelper(
device_type, user_op_conf, input_arg_tuple, output_arg_tuple));
opkernel->reg_ctx_helper_.reset(
new UserKernelRegContextHelper(device_type, user_op_conf, input_arg_tuple, output_arg_tuple));
const auto* op_reg_val =
user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(user_op_conf->op_type_name());
CHECK_NOTNULL_OR_RETURN(op_reg_val);
if (op_reg_val->logical_tensor_desc_infer_fn) {
opkernel->tensor_desc_infer_fn_ = op_reg_val->logical_tensor_desc_infer_fn;
} else {
return Error::UnimplementedError();
}
opkernel->data_type_infer_fn_ = op_reg_val->data_type_infer_fn;
JUST(InitTensorTupleIndexes4Bns(
op_conf, input_arg_tuple->indexed_arg_name_and_index(),
output_arg_tuple->indexed_arg_name_and_index(), &opkernel->input_tuple_indexes4const_ibns_,
&opkernel->input_tuple_indexes4mut_ibns_, &opkernel->output_tuple_indexes4mut_obns_,
&opkernel->output_tuple_indexes4mut2_obns_));
return opkernel;
}
StatefulOpKernel::~StatefulOpKernel() = default;
size_t StatefulOpKernel::InferTmpSize(eager::CallContext* call_ctx,
const user_op::OpKernel* user_opkernel) const {
UserOpInferContext op_infer_ctx(op_infer_ctx_helper_.get(), call_ctx);
const auto& InferTmpSizeFn = GetInferTmpSizeFn(user_opkernel);
return InferTmpSizeFn(&op_infer_ctx);
}
Maybe<void> StatefulOpKernel::ChooseOpKernel(eager::CallContext* call_ctx,
const user_op::OpKernel** user_opkernel,
bool* need_temp_storage) {
OF_PROFILER_RANGE_GUARD("ChooseOpKernel");
DataType primary_dtype = kInvalidDataType;
const auto& inputs = call_ctx->inputs();
const auto& outputs = call_ctx->outputs();
if (likely(!inputs->empty())) {
primary_dtype = (*inputs)[0]->data_type();
} else if (likely(!outputs->empty())) {
primary_dtype = (*outputs)[0]->data_type();
} else {
// do nothing
}
UserKernelRegContext reg_ctx(reg_ctx_helper_.get(), call_ctx);
for (const auto& pair : dtype2cached_kernels_[primary_dtype]) {
if (likely(pair.first->is_matched_hob->get(reg_ctx))) {
*need_temp_storage = pair.first->need_temp_storage;
*user_opkernel = pair.second.get();
return Maybe<void>::Ok();
}
}
OF_PROFILER_RANGE_GUARD("fallback");
const auto& op_type_name = user_op_conf_->op_type_name();
const auto* kernel_reg_val =
JUST(user_op::UserOpRegistryMgr::Get().GetOpKernelRegistryResult(op_type_name, reg_ctx));
CHECK_NOTNULL(kernel_reg_val);
auto* kernel = kernel_reg_val->create_fn();
dtype2cached_kernels_[primary_dtype].push_back(
{kernel_reg_val, std::shared_ptr<const user_op::OpKernel>(kernel)});
infer_tmp_size_fn_map_.emplace(kernel, &kernel_reg_val->infer_tmp_size_fn);
*need_temp_storage = kernel_reg_val->need_temp_storage;
*user_opkernel = kernel;
return Maybe<void>::Ok();
}
void StatefulOpKernel::TryInitOpKernelStateAndCache(eager::CallContext* call_ctx,
DeviceCtx* device_ctx,
const user_op::OpKernel* op_kernel,
user_op::OpKernelState** state,
user_op::OpKernelCache** cache) {
UserKernelInitAndCacheContext init_and_cache_ctx(init_and_cache_ctx_helper_.get(), call_ctx,
device_ctx);
if (state != nullptr) {
auto it = op_kernel_state_map_.find(op_kernel);
if (it != op_kernel_state_map_.end()) {
*state = it->second.get();
} else {
auto created_state = op_kernel->CreateOpKernelState(&init_and_cache_ctx);
op_kernel_state_map_.emplace(op_kernel, created_state);
*state = created_state.get();
}
}
{
auto& cache_in_map = op_kernel_cache_map_[op_kernel];
op_kernel->InitOpKernelCacheWithFlags(&init_and_cache_ctx,
user_op::OpKernelCache::kAllMayChanged, &cache_in_map);
*cache = cache_in_map.get();
}
}
const user_op::InferTmpSizeFn& StatefulOpKernel::GetInferTmpSizeFn(
const user_op::OpKernel* op_kernel) const {
return *infer_tmp_size_fn_map_.at(op_kernel);
}
user_op::TensorDescInferFn StatefulOpKernel::TensorDescInferFn() const {
return tensor_desc_infer_fn_;
}
user_op::DataTypeInferFn StatefulOpKernel::DataTypeInferFn() const { return data_type_infer_fn_; }
void StatefulOpKernel::Compute(eager::CallContext* call_ctx, DeviceCtx* device_ctx,
const user_op::OpKernel* user_opkernel,
user_op::OpKernelState* state,
const user_op::OpKernelCache* cache) const {
UserKernelComputeContext compute_context(compute_ctx_helper_.get(), call_ctx, device_ctx);
auto* compute_ctx = &compute_context;
OF_PROFILER_RANGE_GUARD("Compute");
if (Singleton<profiler::ProfileManager>::Get()) {
#if defined(WITH_CUDA) || defined(WITH_ROCM)
const auto CalMemorySize = [compute_ctx](const one::ArgVec& args) -> int64_t {
const auto Func = [compute_ctx](int64_t mem_size, const auto& pair) {
const auto tensor = compute_ctx->Tensor4ArgNameAndIndex(pair.first, pair.second);
return mem_size + tensor->shape_view().elem_cnt() * GetSizeOfDataType(tensor->data_type());
};
return std::accumulate(args.begin(), args.end(), static_cast<int64_t>(0), Func);
};
#endif
auto er_guard = CHECK_JUST(profiler::EventRecorder::CreateKernelEventRecorder(
op_type_name(),
#if defined(WITH_CUDA) || defined(WITH_ROCM)
[compute_ctx, CalMemorySize]() -> int64_t {
return CalMemorySize(compute_ctx->inputs()) + CalMemorySize(compute_ctx->outputs());
},
#endif
[compute_ctx]() -> std::vector<Shape> {
std::vector<Shape> shapes;
for (const auto& pair : compute_ctx->inputs()) {
shapes.emplace_back(
compute_ctx->TensorDesc4ArgNameAndIndex(pair.first, pair.second)->shape());
}
return shapes;
}));
user_opkernel->Compute(compute_ctx, state, cache);
} else {
user_opkernel->Compute(compute_ctx, state, cache);
}
}
} // namespace one
} // namespace oneflow
import numpy as np
import oneflow as flow
def fused_dot_feature_interaction(x,
y,
self_interaction=False,
output_padding=0,
output_concat=None,
dtype=flow.float32
):
# (bs, es) = x.shape
(bs, dims, es) = y.shape
if self_interaction:
offset = 1
else:
offset = 0
li = flow.tensor([i for i in range(dims + 1) for j in range(i + offset)])
lj = flow.tensor([j for i in range(dims + 1) for j in range(i + offset)])
T = flow.cat(
[
flow.reshape(x, (bs, 1, es)),
y,
],
dim=1,
)
Z = flow.matmul(T, T, transpose_b=True)
# gather_nd not support half, so cast to float32
Z = flow.cast(Z, flow.float32)
Zflat = Z[:, li, lj]
Zflat = flow.cast(Zflat, dtype)
if output_concat is not None:
R = flow.cat([output_concat, Zflat], dim=1)
else:
R = Zflat
if output_padding != 0:
padding_tensor = flow.tensor(
np.zeros((bs, output_padding)).astype(np.float32),
device="cuda",
requires_grad=False,
)
R = flow.cat([R, padding_tensor], dim=1)
return R
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import unittest
import oneflow.unittest
import oneflow as flow
import oneflow.nn as nn
import oneflow.nn.functional as F
import oneflow.profiler
from oneflow.profiler.events import CustomEvent, KernelEvent
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
out = F.relu(self.conv1(x))
out = F.max_pool2d(out, 2)
out = F.relu(self.conv2(out))
out = F.max_pool2d(out, 2)
out = out.view(out.size(0), -1)
out = F.relu(self.fc1(out))
out = F.relu(self.fc2(out))
out = self.fc3(out)
return out
def get_event(events, name: str, input_shapes: str = "-"):
for item in events:
if isinstance(item, CustomEvent):
if item.name == name:
return item
if isinstance(item, KernelEvent):
if item.name == name and item.input_shapes == input_shapes:
return item
return None
def _test_lenet(
test_case,
on_cuda: bool,
record_shapes: bool,
record_bandwidth_for_cuda: bool = False,
):
x = flow.randn(2, 3, 32, 32)
lenet = LeNet()
if on_cuda:
x = x.to("cuda")
lenet.to("cuda")
activities = [oneflow.profiler.ProfilerActivity.CPU]
if on_cuda:
activities.append(oneflow.profiler.ProfilerActivity.CUDA)
with oneflow.profiler.profile(
activities=activities,
record_shapes=record_shapes,
record_bandwidth_for_cuda=record_bandwidth_for_cuda,
) as prof:
with oneflow.profiler.record_function("lenet_forward_total_time") as f:
for _ in range(2):
eager_res = lenet(x)
with oneflow.profiler.record_function("lenet_backward_total_time") as f:
eager_res.sum().backward()
events = prof.key_averages(group_by_input_shape=True)
conv_event = get_event(
events, "conv2d", "[(2,3,32,32), (6,3,5,5)]" if record_shapes else "-"
)
test_case.assertIsNotNone(conv_event)
if on_cuda:
test_case.assertGreater(conv_event.cpu_time, 0.0)
test_case.assertGreater(conv_event.cpu_time_total, 0.0)
test_case.assertGreater(conv_event.cuda_time, 0.0)
test_case.assertGreater(conv_event.cuda_time_total, 0.0)
else:
test_case.assertGreater(conv_event.cpu_time, 0.0)
test_case.assertGreater(conv_event.cpu_time_total, 0.0)
test_case.assertEqual(conv_event.count, 2 if record_shapes else 4)
if record_bandwidth_for_cuda and on_cuda:
test_case.assertNotEqual(conv_event.bandwidth, -1)
relu_grad_event = get_event(
events, "relu_grad", "[(2,6,28,28), (2,6,28,28)]" if record_shapes else "-"
)
test_case.assertIsNotNone(relu_grad_event)
if on_cuda:
test_case.assertGreater(relu_grad_event.cpu_time, 0.0)
test_case.assertGreater(relu_grad_event.cpu_time_total, 0.0)
test_case.assertGreater(relu_grad_event.cuda_time, 0.0)
test_case.assertGreater(relu_grad_event.cuda_time_total, 0.0)
else:
test_case.assertGreater(relu_grad_event.cpu_time, 0.0)
test_case.assertGreater(relu_grad_event.cpu_time_total, 0.0)
test_case.assertEqual(relu_grad_event.count, 1 if record_shapes else 4)
if record_bandwidth_for_cuda and on_cuda:
test_case.assertNotEqual(relu_grad_event.bandwidth, -1)
test_case.assertIsNotNone(get_event(events, "lenet_forward_total_time"))
test_case.assertIsNotNone(get_event(events, "lenet_backward_total_time"))
class TestProfileLenet(flow.unittest.TestCase):
def test_lenet_cpu(test_case):
_test_lenet(test_case, on_cuda=False, record_shapes=True)
_test_lenet(test_case, on_cuda=False, record_shapes=False)
@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
def test_lenet_cuda(test_case):
_test_lenet(
test_case, on_cuda=True, record_shapes=True, record_bandwidth_for_cuda=False
)
_test_lenet(
test_case,
on_cuda=True,
record_shapes=False,
record_bandwidth_for_cuda=False,
)
_test_lenet(
test_case, on_cuda=True, record_shapes=True, record_bandwidth_for_cuda=True
)
_test_lenet(
test_case, on_cuda=True, record_shapes=False, record_bandwidth_for_cuda=True
)
if __name__ == "__main__":
unittest.main()
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import unittest
import oneflow.unittest
import oneflow as flow
import oneflow.nn as nn
import oneflow.nn.functional as F
import oneflow.profiler
from oneflow.profiler.events import CustomEvent, KernelEvent
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
out = F.relu(self.conv1(x))
out = F.max_pool2d(out, 2)
out = F.relu(self.conv2(out))
out = F.max_pool2d(out, 2)
out = out.view(out.size(0), -1)
out = F.relu(self.fc1(out))
out = F.relu(self.fc2(out))
out = self.fc3(out)
return out
def get_event(events, name: str, input_shapes: str = "-"):
for item in events:
if isinstance(item, CustomEvent):
if item.name == name:
return item
if isinstance(item, KernelEvent):
if item.name == name and item.input_shapes == input_shapes:
return item
return None
def _test_lenet(
test_case,
on_cuda: bool,
record_shapes: bool,
record_bandwidth_for_cuda: bool = False,
):
x = flow.randn(2, 3, 32, 32)
lenet = LeNet()
if on_cuda:
x = x.to("cuda")
lenet.to("cuda")
activities = [oneflow.profiler.ProfilerActivity.CPU]
if on_cuda:
activities.append(oneflow.profiler.ProfilerActivity.CUDA)
with oneflow.profiler.profile(
activities=activities,
record_shapes=record_shapes,
record_bandwidth_for_cuda=record_bandwidth_for_cuda,
) as prof:
with oneflow.profiler.record_function("lenet_forward_total_time") as f:
for _ in range(2):
eager_res = lenet(x)
with oneflow.profiler.record_function("lenet_backward_total_time") as f:
eager_res.sum().backward()
events = prof.key_averages(group_by_input_shape=True)
print(events)
conv_event = get_event(
events, "conv2d", "[(2,3,32,32), (6,3,5,5)]" if record_shapes else "-"
)
test_case.assertIsNotNone(conv_event)
if on_cuda:
test_case.assertGreater(conv_event.cpu_time, 0.0)
test_case.assertGreater(conv_event.cpu_time_total, 0.0)
test_case.assertGreater(conv_event.cuda_time, 0.0)
test_case.assertGreater(conv_event.cuda_time_total, 0.0)
else:
test_case.assertGreater(conv_event.cpu_time, 0.0)
test_case.assertGreater(conv_event.cpu_time_total, 0.0)
test_case.assertEqual(conv_event.count, 2 if record_shapes else 4)
if record_bandwidth_for_cuda and on_cuda:
test_case.assertNotEqual(conv_event.bandwidth, -1)
relu_grad_event = get_event(
events, "relu_grad", "[(2,6,28,28), (2,6,28,28)]" if record_shapes else "-"
)
test_case.assertIsNotNone(relu_grad_event)
if on_cuda:
test_case.assertGreater(relu_grad_event.cpu_time, 0.0)
test_case.assertGreater(relu_grad_event.cpu_time_total, 0.0)
test_case.assertGreater(relu_grad_event.cuda_time, 0.0)
test_case.assertGreater(relu_grad_event.cuda_time_total, 0.0)
else:
test_case.assertGreater(relu_grad_event.cpu_time, 0.0)
test_case.assertGreater(relu_grad_event.cpu_time_total, 0.0)
test_case.assertEqual(relu_grad_event.count, 1 if record_shapes else 4)
if record_bandwidth_for_cuda and on_cuda:
test_case.assertNotEqual(relu_grad_event.bandwidth, -1)
test_case.assertIsNotNone(get_event(events, "lenet_forward_total_time"))
test_case.assertIsNotNone(get_event(events, "lenet_backward_total_time"))
class TestProfileLenet(flow.unittest.TestCase):
def test_lenet_cpu(test_case):
_test_lenet(test_case, on_cuda=False, record_shapes=True)
_test_lenet(test_case, on_cuda=False, record_shapes=False)
@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
def test_lenet_cuda(test_case):
_test_lenet(
test_case, on_cuda=True, record_shapes=True, record_bandwidth_for_cuda=False
)
_test_lenet(
test_case,
on_cuda=True,
record_shapes=False,
record_bandwidth_for_cuda=False,
)
_test_lenet(
test_case, on_cuda=True, record_shapes=True, record_bandwidth_for_cuda=True
)
_test_lenet(
test_case, on_cuda=True, record_shapes=False, record_bandwidth_for_cuda=True
)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment