Commit a715222c authored by yuguo's avatar yuguo
Browse files

0.9.1-rocm

parent f262efc9
...@@ -20,7 +20,12 @@ limitations under the License. ...@@ -20,7 +20,12 @@ limitations under the License.
namespace oneflow { namespace oneflow {
DEFINE_THREAD_LOCAL_ENV_BOOL(ONEFLOW_VM_WORKLOAD_ON_SCHEDULER_THREAD, false); DEFINE_THREAD_LOCAL_ENV_BOOL(ONEFLOW_VM_COMPUTE_ON_WORKER_THREAD, true);
// Thread-local environment switches for the VM. The second macro argument is
// the default value. NOTE(review): the DEFINE_THREAD_LOCAL_ENV_* macros are
// defined elsewhere; every invocation is terminated with ';' so the list is
// uniform and valid whether or not the macro expansion is self-terminating.
DEFINE_THREAD_LOCAL_ENV_BOOL(ONEFLOW_VM_ENABLE_STREAM_WAIT, true);
DEFINE_THREAD_LOCAL_ENV_INTEGER(ONEFLOW_VM_PENDING_HANDLE_WINDOW_SIZE, 10);
DEFINE_THREAD_LOCAL_ENV_BOOL(ONEFLOW_VM_ENABLE_SCHEDULE_YIELD, true);
DEFINE_THREAD_LOCAL_ENV_INTEGER(ONEFLOW_VM_WORKER_THREAD_LIMIT, 16);
DEFINE_THREAD_LOCAL_ENV_BOOL(ONEFLOW_VM_MULTI_THREAD, true);
} } // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_ENV_VAR_VM_H_ #endif // ONEFLOW_CORE_COMMON_ENV_VAR_VM_H_
...@@ -14,15 +14,22 @@ See the License for the specific language governing permissions and ...@@ -14,15 +14,22 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
#include <stdexcept> #include <stdexcept>
#include "fmt/core.h"
#include "fmt/color.h"
#include "fmt/ostream.h"
#include "oneflow/core/common/error.h" #include "oneflow/core/common/error.h"
#include "oneflow/core/common/exception.h" #include "oneflow/core/common/exception.h"
#include "oneflow/core/common/protobuf.h" #include "oneflow/core/common/protobuf.h"
#include "oneflow/core/common/util.h" #include "oneflow/core/common/util.h"
#include "oneflow/core/common/error_util.h" #include "oneflow/core/common/error_util.h"
#include "oneflow/core/common/env_var/debug_mode.h" #include "oneflow/core/common/env_var/debug_mode.h"
#include "oneflow/extension/stack/foreign_stack_getter.h"
#include "oneflow/core/thread/thread_manager.h"
namespace oneflow { namespace oneflow {
// Builds an error with an empty frame stack and a freshly allocated,
// default-constructed ErrorProto. std::make_shared is used instead of a raw
// `new` so the proto and its control block come from a single allocation,
// matching how the rest of this file creates shared objects.
StackedError::StackedError() : stack_frame_(), error_proto_(std::make_shared<ErrorProto>()) {}
namespace { namespace {
void LogError(const Error& error) { void LogError(const Error& error) {
...@@ -30,234 +37,220 @@ void LogError(const Error& error) { ...@@ -30,234 +37,220 @@ void LogError(const Error& error) {
LOG(ERROR) << error->msg(); LOG(ERROR) << error->msg();
} }
std::shared_ptr<ErrorProto>* MutThreadLocalError() { std::shared_ptr<StackedError>* MutThreadLocalError() {
thread_local std::shared_ptr<ErrorProto> error; thread_local std::shared_ptr<StackedError> error;
return &error; return &error;
} }
} // namespace } // namespace
Error&& Error::AddStackFrame(const std::string& file, const int64_t& line, Error&& Error::AddStackFrame(Symbol<ErrorStackFrame> error_stack_frame) {
const std::string& function) { stacked_error_->add_stack_frame(error_stack_frame);
auto* stack_frame = error_proto_->add_stack_frame();
stack_frame->set_file(file);
stack_frame->set_line(line);
stack_frame->set_function(function);
return std::move(*this); return std::move(*this);
} }
void Error::Merge(const Error& other) { void Error::Merge(const Error& other) {
std::string error_summary{error_proto_->error_summary()}; auto* error_proto = stacked_error_->mut_error_proto();
std::string msg{error_proto_->msg()}; error_proto->MergeFrom(*other.stacked_error_->error_proto());
error_proto_->MergeFrom(*other.error_proto_);
// MergeFrom will overwrite singular field, so restore it.
if (!error_summary.empty()) {
error_proto_->set_error_summary(error_summary + " " + error_proto_->error_summary());
}
if (!msg.empty()) { error_proto_->set_msg(msg + " " + error_proto_->msg()); }
} }
Error::operator std::string() const { return error_proto_->DebugString(); } Error::operator std::string() const { return stacked_error_->DebugString(); }
Error Error::Ok() { return std::make_shared<ErrorProto>(); } Error Error::Ok() { return std::make_shared<StackedError>(); }
Error Error::ProtoParseFailedError() { Error Error::ProtoParseFailedError() {
auto error = std::make_shared<ErrorProto>(); auto error = std::make_shared<StackedError>();
error->mutable_proto_parse_failed_error(); error->mut_error_proto()->mutable_proto_parse_failed_error();
return error; return error;
} }
Error Error::JobSetEmptyError() { Error Error::JobSetEmptyError() {
auto error = std::make_shared<ErrorProto>(); auto error = std::make_shared<StackedError>();
error->mutable_job_set_empty_error(); error->mut_error_proto()->mutable_job_set_empty_error();
return error; return error;
} }
Error Error::DeviceTagNotFoundError() { Error Error::DeviceTagNotFoundError() {
auto error = std::make_shared<ErrorProto>(); auto error = std::make_shared<StackedError>();
error->mutable_device_tag_not_found_error(); error->mut_error_proto()->mutable_device_tag_not_found_error();
return error; return error;
} }
Error Error::InvalidValueError(const std::string& error_summary) { Error Error::InvalidValueError() {
auto error = std::make_shared<ErrorProto>(); auto error = std::make_shared<StackedError>();
error->set_error_summary(error_summary); error->mut_error_proto()->mutable_invalid_value_error();
error->mutable_invalid_value_error();
return error; return error;
} }
Error Error::IndexError() { Error Error::IndexError() {
auto error = std::make_shared<ErrorProto>(); auto error = std::make_shared<StackedError>();
error->mutable_index_error(); error->mut_error_proto()->mutable_index_error();
return error; return error;
} }
Error Error::TypeError() { Error Error::TypeError() {
auto error = std::make_shared<ErrorProto>(); auto error = std::make_shared<StackedError>();
error->mutable_type_error(); error->mut_error_proto()->mutable_type_error();
return error; return error;
} }
Error Error::TimeoutError() { Error Error::TimeoutError() {
auto error = std::make_shared<ErrorProto>(); auto error = std::make_shared<StackedError>();
error->mutable_timeout_error(); error->mut_error_proto()->mutable_timeout_error();
return error; return error;
} }
Error Error::JobNameExistError() { Error Error::JobNameExistError() {
auto error = std::make_shared<ErrorProto>(); auto error = std::make_shared<StackedError>();
error->mutable_job_name_exist_error(); error->mut_error_proto()->mutable_job_name_exist_error();
return error; return error;
} }
Error Error::JobNameEmptyError() { Error Error::JobNameEmptyError() {
auto error = std::make_shared<ErrorProto>(); auto error = std::make_shared<StackedError>();
error->mutable_job_name_empty_error(); error->mut_error_proto()->mutable_job_name_empty_error();
return error; return error;
} }
Error Error::JobNameNotEqualError() { Error Error::JobNameNotEqualError() {
auto error = std::make_shared<ErrorProto>(); auto error = std::make_shared<StackedError>();
error->mutable_job_name_not_equal_error(); error->mut_error_proto()->mutable_job_name_not_equal_error();
return error; return error;
} }
Error Error::NoJobBuildAndInferCtxError() { Error Error::NoJobBuildAndInferCtxError() {
auto error = std::make_shared<ErrorProto>(); auto error = std::make_shared<StackedError>();
error->mutable_no_job_build_and_infer_ctx_error(); error->mut_error_proto()->mutable_no_job_build_and_infer_ctx_error();
return error; return error;
} }
Error Error::JobConfFrozenError() { Error Error::JobConfFrozenError() {
auto error = std::make_shared<ErrorProto>(); auto error = std::make_shared<StackedError>();
error->mutable_job_conf_frozen_error(); error->mut_error_proto()->mutable_job_conf_frozen_error();
return error; return error;
} }
Error Error::JobConfNotSetError() { Error Error::JobConfNotSetError() {
auto error = std::make_shared<ErrorProto>(); auto error = std::make_shared<StackedError>();
error->mutable_job_conf_not_set_error(); error->mut_error_proto()->mutable_job_conf_not_set_error();
return error; return error;
} }
Error Error::JobConfRepeatedSetError() { Error Error::JobConfRepeatedSetError() {
auto error = std::make_shared<ErrorProto>(); auto error = std::make_shared<StackedError>();
error->mutable_job_conf_repeated_set_error(); error->mut_error_proto()->mutable_job_conf_repeated_set_error();
return error; return error;
} }
Error Error::JobTypeNotSetError() { Error Error::JobTypeNotSetError() {
auto error = std::make_shared<ErrorProto>(); auto error = std::make_shared<StackedError>();
error->mutable_job_type_not_set_error(); error->mut_error_proto()->mutable_job_type_not_set_error();
return error; return error;
} }
Error Error::LogicalBlobNameNotExistError() { Error Error::LogicalBlobNameNotExistError() {
auto error = std::make_shared<ErrorProto>(); auto error = std::make_shared<StackedError>();
error->mutable_logical_blob_name_not_exist_error(); error->mut_error_proto()->mutable_logical_blob_name_not_exist_error();
return error; return error;
} }
Error Error::LogicalBlobNameExistError() { Error Error::LogicalBlobNameExistError() {
auto error = std::make_shared<ErrorProto>(); auto error = std::make_shared<StackedError>();
error->mutable_logical_blob_name_exist_error(); error->mut_error_proto()->mutable_logical_blob_name_exist_error();
return error; return error;
} }
Error Error::LogicalBlobNameInvalidError() { Error Error::LogicalBlobNameInvalidError() {
auto error = std::make_shared<ErrorProto>(); auto error = std::make_shared<StackedError>();
error->mutable_logical_blob_name_invalid_error(); error->mut_error_proto()->mutable_logical_blob_name_invalid_error();
return error; return error;
} }
Error Error::OpNameExistError() { Error Error::OpNameExistError() {
auto error = std::make_shared<ErrorProto>(); auto error = std::make_shared<StackedError>();
error->mutable_op_name_exist_error(); error->mut_error_proto()->mutable_op_name_exist_error();
return error; return error;
} }
Error Error::OpConfDeviceTagNoSetError() { Error Error::OpConfDeviceTagNoSetError() {
auto error = std::make_shared<ErrorProto>(); auto error = std::make_shared<StackedError>();
error->mutable_op_conf_device_tag_no_set_error(); error->mut_error_proto()->mutable_op_conf_device_tag_no_set_error();
return error; return error;
} }
Error Error::PlacementError() { Error Error::PlacementError() {
auto error = std::make_shared<ErrorProto>(); auto error = std::make_shared<StackedError>();
error->mutable_placement_error(); error->mut_error_proto()->mutable_placement_error();
return error; return error;
} }
Error Error::BlobSplitAxisInferError() { Error Error::BlobSplitAxisInferError() {
auto error = std::make_shared<ErrorProto>(); auto error = std::make_shared<StackedError>();
error->mutable_blob_split_axis_infer_error(); error->mut_error_proto()->mutable_blob_split_axis_infer_error();
return error; return error;
} }
Error Error::UnknownJobBuildAndInferError() { Error Error::UnknownJobBuildAndInferError() {
auto error = std::make_shared<ErrorProto>(); auto error = std::make_shared<StackedError>();
error->mutable_unknown_job_build_and_infer_error(); error->mut_error_proto()->mutable_unknown_job_build_and_infer_error();
return error; return error;
} }
Error Error::CheckFailedError() { Error Error::CheckFailedError() {
auto error = std::make_shared<ErrorProto>(); auto error = std::make_shared<StackedError>();
error->mutable_check_failed_error(); error->mut_error_proto()->mutable_check_failed_error();
return error; return error;
} }
Error Error::ValueNotFoundError() { Error Error::ValueNotFoundError() {
auto error = std::make_shared<ErrorProto>(); auto error = std::make_shared<StackedError>();
error->mutable_value_not_found_error(); error->mut_error_proto()->mutable_value_not_found_error();
return error; return error;
} }
Error Error::TodoError() { Error Error::TodoError() {
auto error = std::make_shared<ErrorProto>(); auto error = std::make_shared<StackedError>();
error->mutable_todo_error(); error->mut_error_proto()->mutable_todo_error();
return error; return error;
} }
Error Error::UnimplementedError() { Error Error::UnimplementedError() {
auto error = std::make_shared<ErrorProto>(); auto error = std::make_shared<StackedError>();
error->mutable_unimplemented_error(); error->mut_error_proto()->mutable_unimplemented_error();
return error; return error;
} }
Error Error::RuntimeError() { Error Error::RuntimeError() {
auto error = std::make_shared<ErrorProto>(); auto error = std::make_shared<StackedError>();
error->mutable_runtime_error(); error->mut_error_proto()->mutable_runtime_error();
return error; return error;
} }
Error Error::OutOfMemoryError() { Error Error::OutOfMemoryError() {
auto error = std::make_shared<ErrorProto>(); auto error = std::make_shared<StackedError>();
error->mutable_out_of_memory_error(); error->mut_error_proto()->mutable_out_of_memory_error();
return error; return error;
} }
Error Error::BoxingNotSupportedError() { Error Error::BoxingNotSupportedError() {
auto error = std::make_shared<ErrorProto>(); auto error = std::make_shared<StackedError>();
error->mutable_boxing_not_supported_error(); error->mut_error_proto()->mutable_boxing_not_supported_error();
return error; return error;
} }
Error Error::OpKernelNotFoundError(const std::string& error_summary, Error Error::OpKernelNotFoundError(const std::vector<std::string>& error_msgs) {
const std::vector<std::string>& error_msgs) { auto error = std::make_shared<StackedError>();
auto error = std::make_shared<ErrorProto>(); auto* op_kernel_not_found_error = error->mut_error_proto()->mutable_op_kernel_not_found_error();
error->set_error_summary(error_summary);
auto* op_kernel_not_found_error = error->mutable_op_kernel_not_found_error();
for (const auto& msg : error_msgs) { for (const auto& msg : error_msgs) {
op_kernel_not_found_error->add_op_kernels_not_found_debug_str(msg); op_kernel_not_found_error->add_op_kernels_not_found_debug_str(msg);
} }
return error; return error;
} }
Error Error::MultipleOpKernelsMatchedError(const std::string& error_summary, Error Error::MultipleOpKernelsMatchedError(const std::vector<std::string>& error_msgs) {
const std::vector<std::string>& error_msgs) { auto error = std::make_shared<StackedError>();
auto error = std::make_shared<ErrorProto>(); auto* multiple_op_kernels_matched_error =
error->set_error_summary(error_summary); error->mut_error_proto()->mutable_multiple_op_kernels_matched_error();
auto* multiple_op_kernels_matched_error = error->mutable_multiple_op_kernels_matched_error();
for (const auto& msg : error_msgs) { for (const auto& msg : error_msgs) {
multiple_op_kernels_matched_error->add_matched_op_kernels_debug_str(msg); multiple_op_kernels_matched_error->add_matched_op_kernels_debug_str(msg);
} }
...@@ -266,8 +259,9 @@ Error Error::MultipleOpKernelsMatchedError(const std::string& error_summary, ...@@ -266,8 +259,9 @@ Error Error::MultipleOpKernelsMatchedError(const std::string& error_summary,
Error Error::MemoryZoneOutOfMemoryError(int64_t machine_id, int64_t mem_zone_id, uint64_t calc, Error Error::MemoryZoneOutOfMemoryError(int64_t machine_id, int64_t mem_zone_id, uint64_t calc,
uint64_t available, const std::string& device_tag) { uint64_t available, const std::string& device_tag) {
auto error = std::make_shared<ErrorProto>(); auto error = std::make_shared<StackedError>();
auto* memory_zone_out_of_memory_error = error->mutable_memory_zone_out_of_memory_error(); auto* memory_zone_out_of_memory_error =
error->mut_error_proto()->mutable_memory_zone_out_of_memory_error();
memory_zone_out_of_memory_error->add_machine_id(std::to_string(machine_id)); memory_zone_out_of_memory_error->add_machine_id(std::to_string(machine_id));
memory_zone_out_of_memory_error->add_mem_zone_id(std::to_string(mem_zone_id)); memory_zone_out_of_memory_error->add_mem_zone_id(std::to_string(mem_zone_id));
memory_zone_out_of_memory_error->add_device_tag(device_tag); memory_zone_out_of_memory_error->add_device_tag(device_tag);
...@@ -276,79 +270,100 @@ Error Error::MemoryZoneOutOfMemoryError(int64_t machine_id, int64_t mem_zone_id, ...@@ -276,79 +270,100 @@ Error Error::MemoryZoneOutOfMemoryError(int64_t machine_id, int64_t mem_zone_id,
return error; return error;
} }
Error Error::LossBlobNotFoundError(const std::string& error_summary) { Error Error::LossBlobNotFoundError() {
auto error = std::make_shared<ErrorProto>(); auto error = std::make_shared<StackedError>();
error->mutable_loss_blob_not_found_error(); error->mut_error_proto()->mutable_loss_blob_not_found_error();
error->set_error_summary(error_summary);
return error; return error;
} }
Error Error::RwMutexedObjectNotFoundError() { Error Error::RwMutexedObjectNotFoundError() {
auto error = std::make_shared<ErrorProto>(); auto error = std::make_shared<StackedError>();
error->mutable_rw_mutexed_object_not_found_error(); error->mut_error_proto()->mutable_rw_mutexed_object_not_found_error();
return error; return error;
} }
Error Error::GradientFunctionNotFoundError() { Error Error::GradientFunctionNotFoundError() {
auto error = std::make_shared<ErrorProto>(); auto error = std::make_shared<StackedError>();
error->mutable_gradient_function_not_found_error(); error->mut_error_proto()->mutable_gradient_function_not_found_error();
return error; return error;
} }
Error Error::SymbolIdUninitializedError() { Error Error::SymbolIdUninitializedError() {
auto error = std::make_shared<ErrorProto>(); auto error = std::make_shared<StackedError>();
error->mutable_symbol_id_uninitialized_error(); error->mut_error_proto()->mutable_symbol_id_uninitialized_error();
return error; return error;
} }
Error Error::CompileOptionWrongError() { Error Error::CompileOptionWrongError() {
auto error = std::make_shared<ErrorProto>(); auto error = std::make_shared<StackedError>();
error->mutable_compile_option_wrong_error(); error->mut_error_proto()->mutable_compile_option_wrong_error();
return error; return error;
} }
Error Error::InputDeviceNotMatchError() { Error Error::InputDeviceNotMatchError() {
auto error = std::make_shared<ErrorProto>(); auto error = std::make_shared<StackedError>();
auto* input_device_not_match_error = error->mutable_input_device_not_match_error(); auto* input_device_not_match_error =
error->mut_error_proto()->mutable_input_device_not_match_error();
input_device_not_match_error->add_info( input_device_not_match_error->add_info(
std::string("Input tensors are at different devices, please try to use tensor.to or " std::string("Input tensors are at different devices, please try to use tensor.to or "
"module.to to correct it.")); "module.to to correct it."));
return error; return error;
} }
std::string GetStackedErrorString(const std::shared_ptr<ErrorProto>& error) { std::string GetStackedErrorString(const std::shared_ptr<StackedError>& error) {
const auto& maybe_error = TRY(FormatErrorStr(error)); const auto& maybe_error = TRY(FormatErrorStr(error));
const auto& error_str = maybe_error.GetDataAndErrorProto(error->DebugString()); const auto& error_str = maybe_error.GetDataAndStackedError(error->DebugString());
CHECK_NE(error->error_type_case(), ErrorProto::ERROR_TYPE_NOT_SET); CHECK_NE(error->error_proto()->error_type_case(), ErrorProto::ERROR_TYPE_NOT_SET);
return error_str.first; return error_str.first;
} }
std::string GetErrorString(const std::shared_ptr<ErrorProto>& error) { std::string GetErrorString(const std::shared_ptr<StackedError>& error) {
std::string error_str;
if (IsInDebugMode()) { if (IsInDebugMode()) {
return GetStackedErrorString(error); error_str = GetStackedErrorString(error);
} else { } else {
if (error->msg().empty() && error->stack_frame().size() > 0) { error_str = error->error_proto()->msg();
return error->stack_frame(0).error_msg(); }
} else { if (error_str.empty()) { error_str = "<No error message>"; }
return error->msg(); return error_str;
}
void ThrowError(const std::shared_ptr<StackedError>& error) {
std::string error_str;
fmt::format_to(std::back_inserter(error_str), "{}: {}\n",
fmt::styled("Error", fmt::emphasis::bold | fmt::fg(fmt::color::red)),
GetErrorString(error));
// Append foreign stack trace (e.g. Python stack trace) when it is available.
if (ForeignFrameThreadLocalGuard::Current().has_value()) {
auto frame = *CHECK_JUST(ForeignFrameThreadLocalGuard::Current());
if (!IsMainThread()) {
if (auto* stack_getter = Singleton<ForeignStackGetter>::Get()) {
fmt::format_to(std::back_inserter(error_str),
fmt::emphasis::bold | fmt::fg(fmt::color::dark_orange),
"Related Python stack trace:\n");
fmt::format_to(std::back_inserter(error_str), "{}", stack_getter->GetFormattedStack(frame));
} else {
fmt::format_to(
std::back_inserter(error_str),
"You can set {} or {} to 1 to get the Python stack of the error.",
fmt::styled("ONEFLOW_DEBUG", fmt::emphasis::bold | fmt::fg(fmt::color::dark_orange)),
fmt::styled("ONEFLOW_PYTHON_STACK_GETTER",
fmt::emphasis::bold | fmt::fg(fmt::color::dark_orange)));
}
} }
} }
}
void ThrowError(const std::shared_ptr<ErrorProto>& error) {
*MutThreadLocalError() = error; *MutThreadLocalError() = error;
if (error->has_runtime_error()) { throw RuntimeException(GetErrorString(error)); } if ((*error)->has_runtime_error()) { throw RuntimeException(error_str); }
if (error->has_type_error()) { throw TypeException(GetErrorString(error)); } if ((*error)->has_type_error()) { throw TypeException(error_str); }
if (error->has_index_error()) { throw IndexException(GetErrorString(error)); } if ((*error)->has_index_error()) { throw IndexException(error_str); }
if (error->has_unimplemented_error()) { throw NotImplementedException(GetErrorString(error)); } if ((*error)->has_unimplemented_error()) { throw NotImplementedException(error_str); }
throw Exception(GetStackedErrorString(error)); throw Exception(GetStackedErrorString(error));
} }
const std::shared_ptr<ErrorProto>& ThreadLocalError() { return *MutThreadLocalError(); } const std::shared_ptr<StackedError>& ThreadLocalError() { return *MutThreadLocalError(); }
const char* kOfBugIssueUploadPrompt =
"This is a oneflow bug, please submit issues in "
"'https://github.com/Oneflow-Inc/oneflow/issues' include the log information of the error, the "
"minimum reproduction code, and the system information.";
const char* kOfBugIssueUploadPrompt = "This is a oneflow bug, please submit an issue at "
"'https://github.com/Oneflow-Inc/oneflow/issues' including "
"the log information of the error, the "
"minimum reproduction code, and the system information.";
} // namespace oneflow } // namespace oneflow
...@@ -18,32 +18,114 @@ limitations under the License. ...@@ -18,32 +18,114 @@ limitations under the License.
#include <sstream> #include <sstream>
#include <vector> #include <vector>
#include <functional>
#include <glog/logging.h>
#include "oneflow/core/common/error.pb.h" #include "oneflow/core/common/error.pb.h"
#include "oneflow/core/common/symbol.h"
#include "oneflow/core/common/small_vector.h"
#include "oneflow/core/common/hash.h"
namespace oneflow { namespace oneflow {
// Value type describing one frame of an error's stack: a file name, a line
// number, a function name and an optional code-text string. All fields are
// exposed through read-only accessors.
class ErrorStackFrame final {
 public:
  ErrorStackFrame(const ErrorStackFrame&) = default;

  // Frame without code text; forwards to the full constructor with an empty string.
  ErrorStackFrame(const std::string& file, int64_t line, const std::string& function)
      : ErrorStackFrame(file, line, function, std::string()) {}

  ErrorStackFrame(const std::string& file, int64_t line, const std::string& function,
                  const std::string& code_text)
      : file_(file), line_(line), function_(function), code_text_(code_text) {}

  // Two frames compare equal only when all four fields match.
  bool operator==(const ErrorStackFrame& other) const {
    if (!(file_ == other.file_)) { return false; }
    if (!(line_ == other.line_)) { return false; }
    if (!(function_ == other.function_)) { return false; }
    return code_text_ == other.code_text_;
  }

  const std::string& file() const { return file_; }
  int64_t line() const { return line_; }
  const std::string& function() const { return function_; }
  const std::string& code_text() const { return code_text_; }

  // Renders the frame as "<file>:<line> <function>\n\t<code_text>\n".
  std::string DebugString() const {
    std::string text(file_);
    text += ":";
    text += std::to_string(line_);
    text += " ";
    text += function_;
    text += "\n\t";
    text += code_text_;
    text += "\n";
    return text;
  }

 private:
  std::string file_;
  int64_t line_;
  std::string function_;
  std::string code_text_;
};
} // namespace oneflow
namespace std {

// std::hash specialization for ErrorStackFrame. The hash is produced by
// feeding all four fields (file, line, function, code text) to oneflow's
// Hash helper, i.e. the same fields that operator== compares.
template<>
struct hash<::oneflow::ErrorStackFrame> final {
  size_t operator()(const ::oneflow::ErrorStackFrame& frame) const {
    return ::oneflow::Hash(frame.file(), frame.line(), frame.function(), frame.code_text());
  }
};

}  // namespace std
namespace oneflow {
// An ErrorProto paired with the list of stack frames collected for it.
// The defaulted copy constructor copies the frame list and the shared_ptr,
// so a copy refers to the same underlying ErrorProto as its source.
class StackedError final {
 public:
  StackedError();
  StackedError(const StackedError&) = default;

  constexpr static int kStackReservedSize = 16;
  using FrameVector = small_vector<Symbol<ErrorStackFrame>, kStackReservedSize>;

  // Arrow access goes straight to the wrapped ErrorProto.
  const ErrorProto* operator->() const { return error_proto_.get(); }
  ErrorProto* operator->() { return mut_error_proto(); }

  // Getters
  const FrameVector& stack_frame() const { return stack_frame_; }
  const std::shared_ptr<const ErrorProto>& error_proto() const { return error_proto_; }

  // One line block per frame (each followed by an extra newline), in the
  // order the frames were added, followed by the proto's own DebugString().
  std::string DebugString() const {
    std::string text;
    for (const auto& frame : stack_frame_) {
      text += frame->DebugString();
      text += "\n";
    }
    text += error_proto_->DebugString();
    return text;
  }

  // Setters
  void add_stack_frame(Symbol<ErrorStackFrame> error_frame) { stack_frame_.push_back(error_frame); }
  // The proto is held through a pointer-to-const; write access is obtained by
  // casting the constness away from the stored pointer.
  ErrorProto* mut_error_proto() { return const_cast<ErrorProto*>(error_proto_.get()); }

 private:
  FrameVector stack_frame_;
  std::shared_ptr<const ErrorProto> error_proto_;
};
std::string GetErrorString(const std::shared_ptr<StackedError>& error);
class Error final { class Error final {
public: public:
Error(const std::shared_ptr<ErrorProto>& error_proto) : error_proto_(error_proto) {} Error(const std::shared_ptr<StackedError>& stacked_error)
: stacked_error_(stacked_error), msg_collecting_mode_(kMergeMessage) {}
Error(const Error&) = default; Error(const Error&) = default;
~Error() = default; ~Error() = default;
std::shared_ptr<ErrorProto> error_proto() const { return error_proto_; } std::shared_ptr<StackedError> stacked_error() const { return stacked_error_; }
const ErrorProto* operator->() const { return error_proto_.get(); } const ErrorProto* operator->() const { return stacked_error_->error_proto().get(); }
ErrorProto* operator->() { return error_proto_.get(); } ErrorProto* operator->() { return stacked_error_->mut_error_proto(); }
operator std::string() const; operator std::string() const;
void Assign(const Error& other) { error_proto_ = other.error_proto_; } void Assign(const Error& other) { stacked_error_ = other.stacked_error_; }
void Merge(const Error& other); void Merge(const Error& other);
// r-value reference is used to supporting expressions like `Error().AddStackFrame("foo.cpp", Error&& AddStackFrame(Symbol<ErrorStackFrame> error_stack_frame);
// ,"line", "Bar") << "invalid value"` because operator<<() need r-value reference
Error&& AddStackFrame(const std::string& file, const int64_t& line, const std::string& function);
static Error Ok(); static Error Ok();
static Error ProtoParseFailedError(); static Error ProtoParseFailedError();
static Error JobSetEmptyError(); static Error JobSetEmptyError();
static Error DeviceTagNotFoundError(); static Error DeviceTagNotFoundError();
static Error InvalidValueError(const std::string& error_summary); static Error InvalidValueError();
static Error IndexError(); static Error IndexError();
static Error TypeError(); static Error TypeError();
static Error TimeoutError(); static Error TimeoutError();
...@@ -72,11 +154,9 @@ class Error final { ...@@ -72,11 +154,9 @@ class Error final {
static Error BoxingNotSupportedError(); static Error BoxingNotSupportedError();
static Error MemoryZoneOutOfMemoryError(int64_t machine_id, int64_t mem_zone_id, uint64_t calc, static Error MemoryZoneOutOfMemoryError(int64_t machine_id, int64_t mem_zone_id, uint64_t calc,
uint64_t available, const std::string& device_type); uint64_t available, const std::string& device_type);
static Error OpKernelNotFoundError(const std::string& error_summary, static Error OpKernelNotFoundError(const std::vector<std::string>& error_msgs);
const std::vector<std::string>& error_msgs); static Error MultipleOpKernelsMatchedError(const std::vector<std::string>& error_msgs);
static Error MultipleOpKernelsMatchedError(const std::string& error_summary, static Error LossBlobNotFoundError();
const std::vector<std::string>& error_msgs);
static Error LossBlobNotFoundError(const std::string& error_summary);
static Error RwMutexedObjectNotFoundError(); static Error RwMutexedObjectNotFoundError();
...@@ -90,22 +170,39 @@ class Error final { ...@@ -90,22 +170,39 @@ class Error final {
static Error InputDeviceNotMatchError(); static Error InputDeviceNotMatchError();
enum MsgCollectingMode {
kInvalidMsgCollectingMode = 0,
kMergeMessage,
kOverrideThenMergeMessage,
};
MsgCollectingMode msg_collecting_mode() const { return msg_collecting_mode_; }
void set_msg_collecting_mode(MsgCollectingMode val) { msg_collecting_mode_ = val; }
private: private:
std::shared_ptr<ErrorProto> error_proto_; std::shared_ptr<StackedError> stacked_error_;
MsgCollectingMode msg_collecting_mode_;
}; };
void ThrowError(const std::shared_ptr<ErrorProto>& error); void ThrowError(const std::shared_ptr<StackedError>& error);
const std::shared_ptr<ErrorProto>& ThreadLocalError(); const std::shared_ptr<StackedError>& ThreadLocalError();
// Stream-style manipulator: streaming a MsgCollectingMode into an Error does
// not add any text; it only stores the mode on the Error and returns the same
// Error so further `<<` calls can be chained. The stored mode is what the
// generic `operator<<(Error&, const T&)` consults when it handles the next
// streamed value.
inline Error& operator<<(Error& error, Error::MsgCollectingMode mode) {
error.set_msg_collecting_mode(mode);
return error;
}
template<typename T> template<typename T>
Error& operator<<(Error& error, const T& x) { Error& operator<<(Error& error, const T& x) {
std::ostringstream ss; std::ostringstream ss;
ss << x; ss << x;
if (error->stack_frame().empty()) { if (error.msg_collecting_mode() == Error::kMergeMessage) {
error->set_msg(error->msg() + ss.str()); error->set_msg(error->msg() + ss.str());
} else if (error.msg_collecting_mode() == Error::kOverrideThenMergeMessage) {
error->set_msg(ss.str());
error.set_msg_collecting_mode(Error::kMergeMessage);
} else { } else {
auto* stack_frame_top = error->mutable_stack_frame(error->stack_frame_size() - 1); LOG(FATAL) << "UNIMPLEMENTED";
stack_frame_top->set_error_msg(stack_frame_top->error_msg() + ss.str());
} }
return error; return error;
} }
......
...@@ -119,13 +119,6 @@ message InputDeviceNotMatchError { ...@@ -119,13 +119,6 @@ message InputDeviceNotMatchError {
repeated string info = 1; repeated string info = 1;
} }
message ErrorStackFrame {
required string file = 1;
required int64 line = 2;
required string function = 3;
required string error_msg = 4;
}
message SymbolIdUninitializedError {} message SymbolIdUninitializedError {}
message InvalidValueError {} message InvalidValueError {}
...@@ -138,9 +131,8 @@ message TimeoutError {} ...@@ -138,9 +131,8 @@ message TimeoutError {}
message ValueNotFoundError {} message ValueNotFoundError {}
message ErrorProto { message ErrorProto {
optional string error_summary = 1 [default = ""]; optional string msg = 1 [default = ""];
optional string msg = 2 [default = ""]; optional string frame_msg = 2 [default = ""];
repeated ErrorStackFrame stack_frame = 3;
oneof error_type { oneof error_type {
ConfigAssertFailedError config_assert_failed_error = 12; ConfigAssertFailedError config_assert_failed_error = 12;
ConfigResourceUnavailableError config_resource_unavailable_error = 13; ConfigResourceUnavailableError config_resource_unavailable_error = 13;
......
...@@ -108,25 +108,18 @@ Maybe<std::string> FormatMsgOfStackFrame(std::string error_msg, bool is_last_sta ...@@ -108,25 +108,18 @@ Maybe<std::string> FormatMsgOfStackFrame(std::string error_msg, bool is_last_sta
return ss.str(); return ss.str();
} }
// the error_summary and msg in error proto
std::string FormatErrorSummaryAndMsgOfErrorProto(const std::shared_ptr<ErrorProto>& error) {
std::stringstream ss;
if (error->has_error_summary()) { ss << error->error_summary(); }
if (error->has_msg()) { ss << (ss.str().size() != 0 ? "\n" + error->msg() : error->msg()); }
return ss.str();
}
// the msg in error type instance. // the msg in error type instance.
Maybe<std::string> FormatMsgOfErrorType(const std::shared_ptr<ErrorProto>& error) { Maybe<std::string> FormatMsgOfErrorType(const std::shared_ptr<StackedError>& error) {
CHECK_NE_OR_RETURN(error->error_type_case(), ErrorProto::ERROR_TYPE_NOT_SET) const auto& error_proto = error->error_proto();
CHECK_NE_OR_RETURN(error_proto->error_type_case(), ErrorProto::ERROR_TYPE_NOT_SET)
<< Error::RuntimeError() << "Parse error failed, unknown error type"; << Error::RuntimeError() << "Parse error failed, unknown error type";
std::stringstream ss; std::stringstream ss;
const google::protobuf::Descriptor* error_des = error->GetDescriptor(); const google::protobuf::Descriptor* error_des = error_proto->GetDescriptor();
const google::protobuf::OneofDescriptor* oneof_field_des = const google::protobuf::OneofDescriptor* oneof_field_des =
error_des->FindOneofByName("error_type"); error_des->FindOneofByName("error_type");
const google::protobuf::Reflection* error_ref = error->GetReflection(); const google::protobuf::Reflection* error_ref = error_proto->GetReflection();
const google::protobuf::FieldDescriptor* field_des = const google::protobuf::FieldDescriptor* field_des =
error_ref->GetOneofFieldDescriptor(*error, oneof_field_des); error_ref->GetOneofFieldDescriptor(*error_proto, oneof_field_des);
CHECK_OR_RETURN(field_des != nullptr); CHECK_OR_RETURN(field_des != nullptr);
ss << "Error Type: " << field_des->full_name(); ss << "Error Type: " << field_des->full_name();
return ss.str(); return ss.str();
...@@ -134,20 +127,17 @@ Maybe<std::string> FormatMsgOfErrorType(const std::shared_ptr<ErrorProto>& error ...@@ -134,20 +127,17 @@ Maybe<std::string> FormatMsgOfErrorType(const std::shared_ptr<ErrorProto>& error
} // namespace } // namespace
Maybe<std::string> FormatErrorStr(const std::shared_ptr<ErrorProto>& error) { Maybe<std::string> FormatErrorStr(const std::shared_ptr<StackedError>& error) {
std::stringstream ss; std::stringstream ss;
ss << error->error_proto()->msg();
ss << error->error_proto()->frame_msg();
// Get msg from stack frame of error proto // Get msg from stack frame of error proto
for (auto stack_frame = error->mutable_stack_frame()->rbegin(); for (auto iter = error->stack_frame().rbegin(); iter < error->stack_frame().rend(); iter++) {
stack_frame < error->mutable_stack_frame()->rend(); stack_frame++) { auto stack_frame = *iter;
ss << FormatFileOfStackFrame(stack_frame->file()) << FormatLineOfStackFrame(stack_frame->line()) ss << FormatFileOfStackFrame(stack_frame->file()) << FormatLineOfStackFrame(stack_frame->line())
<< FormatFunctionOfStackFrame(stack_frame->function()) << FormatFunctionOfStackFrame(stack_frame->function())
<< *JUST(FormatMsgOfStackFrame(stack_frame->error_msg(), << *JUST(FormatMsgOfStackFrame(stack_frame->code_text(),
stack_frame == error->mutable_stack_frame()->rend() - 1)); iter == error->stack_frame().rend() - 1));
}
// Get msg from error summary and msg of error proto
std::string error_summary_and_msg_of_error_proto = FormatErrorSummaryAndMsgOfErrorProto(error);
if (error_summary_and_msg_of_error_proto.size() != 0) {
ss << "\n" << error_summary_and_msg_of_error_proto;
} }
// Get msg from error type of error proto // Get msg from error type of error proto
std::string msg_of_error_type = *JUST(FormatMsgOfErrorType(error)); std::string msg_of_error_type = *JUST(FormatMsgOfErrorType(error));
......
...@@ -22,7 +22,7 @@ limitations under the License. ...@@ -22,7 +22,7 @@ limitations under the License.
namespace oneflow { namespace oneflow {
Maybe<std::string> FormatErrorStr(const std::shared_ptr<ErrorProto>& error); Maybe<std::string> FormatErrorStr(const std::shared_ptr<StackedError>& error);
} // namespace oneflow } // namespace oneflow
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_HASH_H_
#define ONEFLOW_CORE_COMMON_HASH_H_
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>
namespace oneflow {

// Mixes `rhs` into `lhs` and returns the combined hash value.
// The result depends on argument order: HashCombine(a, b) != HashCombine(b, a) in general.
inline size_t HashCombine(size_t lhs, size_t rhs) {
  return lhs ^ (rhs + 0x9e3779b9 + (lhs << 6U) + (lhs >> 2U));
}

// In-place variant: folds `hash` into the running value stored at `*seed`.
// `seed` must point to a valid size_t.
inline void HashCombine(size_t* seed, size_t hash) { *seed = HashCombine(*seed, hash); }

// Folds std::hash<T>()(v) of every argument into `*seed`, from left to right.
// With no values, `*seed` is left unchanged.
template<typename... T>
inline void AddHash(size_t* seed, const T&... v) {
  // Braced-init-list elements are evaluated in order, which gives the left-to-right fold.
  // The leading 0 keeps the array non-empty when the pack is empty (e.g. Hash(x) with a
  // single argument); a zero-length array is not valid ISO C++.
  const int dummy[] = {0, (HashCombine(seed, std::hash<T>()(v)), 0)...};
  (void)dummy;  // the array only exists to drive the pack expansion
}

// Returns one hash for all arguments: std::hash of the first value is the seed and the
// hashes of the remaining values are folded into it in order.
template<typename T, typename... Ts>
inline size_t Hash(const T& v1, const Ts&... vn) {
  size_t seed = std::hash<T>()(v1);
  AddHash<Ts...>(&seed, vn...);
  return seed;
}

}  // namespace oneflow
namespace std {

// std::hash support for std::pair: the hash of `first` is used as the seed and the hash
// of `second` is mixed into it.
template<typename T0, typename T1>
struct hash<std::pair<T0, T1>> {
  std::size_t operator()(const std::pair<T0, T1>& p) const {
    const std::size_t first_hash = std::hash<T0>()(p.first);
    const std::size_t second_hash = std::hash<T1>()(p.second);
    return oneflow::HashCombine(first_hash, second_hash);
  }
};

// std::hash support for std::vector: the seed is the element count and every element's
// hash is mixed in, in iteration order.
template<typename T>
struct hash<std::vector<T>> {
  std::size_t operator()(const std::vector<T>& vec) const {
    std::size_t seed = vec.size();
    for (const auto& item : vec) { oneflow::HashCombine(&seed, std::hash<T>()(item)); }
    return seed;
  }
};

}  // namespace std
#endif // ONEFLOW_CORE_COMMON_HASH_H_
...@@ -17,9 +17,12 @@ limitations under the License. ...@@ -17,9 +17,12 @@ limitations under the License.
#ifndef ONEFLOW_CORE_COMMON_JUST_H_ #ifndef ONEFLOW_CORE_COMMON_JUST_H_
#define ONEFLOW_CORE_COMMON_JUST_H_ #define ONEFLOW_CORE_COMMON_JUST_H_
#include <sstream>
#include <glog/logging.h> #include <glog/logging.h>
#include <type_traits> #include <type_traits>
#include "oneflow/core/common/error.h" #include "oneflow/core/common/error.h"
#include "oneflow/core/common/throw.h"
#include "oneflow/core/common/symbol.h"
#include "oneflow/core/common/preprocessor.h" #include "oneflow/core/common/preprocessor.h"
namespace oneflow { namespace oneflow {
...@@ -30,29 +33,43 @@ class Maybe; ...@@ -30,29 +33,43 @@ class Maybe;
template<typename T> template<typename T>
class Optional; class Optional;
Maybe<std::string> FormatErrorStr(const std::shared_ptr<ErrorProto>&); Maybe<std::string> FormatErrorStr(const std::shared_ptr<StackedError>&);
namespace { namespace {
std::string GetFormatedSerializedError(const std::shared_ptr<ErrorProto>&); std::string GetFormatedSerializedError(const std::shared_ptr<StackedError>&);
} }
namespace private_details { namespace private_details {
inline std::shared_ptr<ErrorProto>&& JustErrorAddStackFrame(std::shared_ptr<ErrorProto>&& err, inline std::shared_ptr<StackedError>&& JustErrorAddStackFrame(
const std::string& file, int64_t line, std::shared_ptr<StackedError>&& err, Symbol<ErrorStackFrame> error_stack_frame) {
const std::string& func, err->add_stack_frame(error_stack_frame);
const std::string& message) {
auto* stack_frame = err->add_stack_frame();
stack_frame->set_file(file);
stack_frame->set_line(line);
stack_frame->set_function(func);
stack_frame->set_error_msg(message);
return std::move(err); return std::move(err);
} }
// Appends the streamed text of `x` to the error's frame message and hands the same
// error back as an rvalue so calls can be chained.
template<typename T>
Error&& AddFrameMessage(Error&& error, const T& x) {
  std::ostringstream formatted;
  formatted << x;
  auto combined = error->frame_msg() + formatted.str();
  error->set_frame_msg(combined);
  return std::move(error);
}
// std::stringstream specialization: appends the stream's current contents (x.str())
// by delegating to the generic overload.
template<>
inline Error&& AddFrameMessage(Error&& error, const std::stringstream& x) {
  return AddFrameMessage(std::move(error), x.str());
}
// std::ostream specialization: passes the stream's buffer (x.rdbuf()) to the generic
// overload, which streams it into the frame message.
template<>
inline Error&& AddFrameMessage(Error&& error, const std::ostream& x) {
  return AddFrameMessage(std::move(error), x.rdbuf());
}
template<typename... T> template<typename... T>
Error&& JustErrorAddMessage(Error&& err, T&&... msg) { Error&& JustErrorAddFrameMessage(Error&& err, T&&... msg) {
__attribute__((unused)) int dummy[] = {((void)(std::move(err) << std::forward<T>(msg)), 0)...}; __attribute__((unused)) int dummy[] = {
((void)(AddFrameMessage(std::move(err), std::forward<T>(msg))), 0)...};
return std::move(err); return std::move(err);
} }
...@@ -67,13 +84,13 @@ bool JustIsOk(const Optional<T>& val) { ...@@ -67,13 +84,13 @@ bool JustIsOk(const Optional<T>& val) {
} }
template<typename T> template<typename T>
std::shared_ptr<ErrorProto> JustGetError(const Maybe<T>& val) { std::shared_ptr<StackedError> JustGetError(const Maybe<T>& val) {
return val.error(); return val.stacked_error();
} }
template<typename T> template<typename T>
std::shared_ptr<ErrorProto> JustGetError(const Optional<T>&) { std::shared_ptr<StackedError> JustGetError(const Optional<T>&) {
return Error::ValueNotFoundError().error_proto(); return Error::ValueNotFoundError().stacked_error();
} }
template<typename T> template<typename T>
...@@ -91,55 +108,68 @@ typename std::remove_const<typename std::remove_reference<T>::type>::type&& Remo ...@@ -91,55 +108,68 @@ typename std::remove_const<typename std::remove_reference<T>::type>::type&& Remo
#if defined(__GNUC__) || defined(__CUDACC__) || defined(__clang__) #if defined(__GNUC__) || defined(__CUDACC__) || defined(__clang__)
#define JUST(...) \ #define JUST(...) \
::oneflow::private_details::RemoveRValConst(({ \ ::oneflow::private_details::RemoveRValConst(({ \
auto&& _just_value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \ auto&& _just_value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \
if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) { \ if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) { \
return ::oneflow::private_details::JustErrorAddStackFrame( \ return ::oneflow::private_details::JustErrorAddStackFrame( \
::oneflow::private_details::JustGetError(_just_value_to_check_), __FILE__, __LINE__, \ ::oneflow::private_details::JustGetError(_just_value_to_check_), \
__FUNCTION__, OF_PP_STRINGIZE(__VA_ARGS__)); \ [](const char* function) { \
} \ thread_local static auto frame = ::oneflow::SymbolOf( \
std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_); \ ::oneflow::ErrorStackFrame(__FILE__, __LINE__, function, #__VA_ARGS__)); \
return frame; \
}(__FUNCTION__)); \
} \
std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_); \
})).Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() })).Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()
#define CHECK_JUST(...) \ #define CHECK_JUST(...) \
([&](const char* _just_closure_func_name_) { \ ([&](const char* _just_closure_func_name_) { \
auto&& _just_value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \ auto&& _just_value_to_check_ = __JustStackCheckWrapper__(__VA_ARGS__); \
if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) { \ if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) { \
LOG(FATAL) << ::oneflow::GetFormatedSerializedError( \ thread_local static auto frame = ::oneflow::SymbolOf( \
::oneflow::ErrorStackFrame(__FILE__, __LINE__, _just_closure_func_name_, #__VA_ARGS__)); \
THROW(RuntimeError) << ::oneflow::GetErrorString( \
::oneflow::private_details::JustErrorAddStackFrame( \ ::oneflow::private_details::JustErrorAddStackFrame( \
::oneflow::private_details::JustGetError(_just_value_to_check_), __FILE__, __LINE__, \ ::oneflow::private_details::JustGetError(_just_value_to_check_), frame)); \
_just_closure_func_name_, OF_PP_STRINGIZE(__VA_ARGS__))); \
} \ } \
return std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_); \ return std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_); \
})(__FUNCTION__) \ })(__FUNCTION__) \
.Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() .Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()
#define JUST_MSG(value, ...) \ #define JUST_MSG(value, ...) \
::oneflow::private_details::RemoveRValConst(({ \ ::oneflow::private_details::RemoveRValConst(({ \
auto&& _just_value_to_check_ = (value); \ auto&& _just_value_to_check_ = (value); \
if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) { \ if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) { \
return ::oneflow::private_details::JustErrorAddMessage( \ return ::oneflow::private_details::JustErrorAddFrameMessage( \
::oneflow::Error(::oneflow::private_details::JustGetError(_just_value_to_check_)) \ ::oneflow::Error(::oneflow::private_details::JustGetError(_just_value_to_check_)) \
.AddStackFrame(__FILE__, __LINE__, __FUNCTION__), \ .AddStackFrame([](const char* function) { \
OF_PP_STRINGIZE(value), ": ", __VA_ARGS__); \ thread_local static auto frame = ::oneflow::SymbolOf( \
} \ ::oneflow::ErrorStackFrame(__FILE__, __LINE__, function, #value)); \
std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_); \ return frame; \
}(__FUNCTION__)), \
"\nError message from " __FILE__, ":", __LINE__, "\n\t", #value, ": ", __VA_ARGS__, \
"\n"); \
} \
std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_); \
})).Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() })).Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()
#define CHECK_JUST_MSG(value, ...) \ #define CHECK_JUST_MSG(value, ...) \
([&](const char* _just_closure_func_name_) { \ ([&](const char* _just_closure_func_name_) { \
auto&& _just_value_to_check_ = (value); \ auto&& _just_value_to_check_ = (value); \
if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) { \ if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) { \
LOG(FATAL) << ::oneflow::GetFormatedSerializedError( \ thread_local static auto frame = ::oneflow::SymbolOf( \
::oneflow::private_details::JustErrorAddMessage( \ ::oneflow::ErrorStackFrame(__FILE__, __LINE__, _just_closure_func_name_, #value)); \
::oneflow::Error(::oneflow::private_details::JustGetError(_just_value_to_check_)) \ THROW(RuntimeError) << ::oneflow::GetErrorString( \
.AddStackFrame(__FILE__, __LINE__, _just_closure_func_name_), \ ::oneflow::private_details::JustErrorAddFrameMessage( \
OF_PP_STRINGIZE(value), ": ", __VA_ARGS__) \ ::oneflow::Error(::oneflow::private_details::JustGetError(_just_value_to_check_)) \
.error_proto()); \ .AddStackFrame(frame), \
} \ "\nError message from " __FILE__, ":", __LINE__, "\n\t", #value, ": ", __VA_ARGS__, \
return std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_); \ "\n") \
})(__FUNCTION__) \ .stacked_error()); \
} \
return std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_); \
})(__FUNCTION__) \
.Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() .Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()
#define JUST_OPT(...) \ #define JUST_OPT(...) \
......
...@@ -16,6 +16,8 @@ limitations under the License. ...@@ -16,6 +16,8 @@ limitations under the License.
#ifndef ONEFLOW_CORE_COMMON_MATH_UTIL_H_ #ifndef ONEFLOW_CORE_COMMON_MATH_UTIL_H_
#define ONEFLOW_CORE_COMMON_MATH_UTIL_H_ #define ONEFLOW_CORE_COMMON_MATH_UTIL_H_
#include <stdint.h> #include <stdint.h>
#include "data_type.h"
#include "oneflow/core/common/util.h"
namespace oneflow { namespace oneflow {
...@@ -23,6 +25,24 @@ int64_t Gcd(int64_t m, int64_t n); ...@@ -23,6 +25,24 @@ int64_t Gcd(int64_t m, int64_t n);
int64_t Lcm(int64_t m, int64_t n); int64_t Lcm(int64_t m, int64_t n);
// Returns the smaller of `a` and `b`.
// When __CUDA_ARCH__ is defined a plain comparison is used; otherwise the call is
// forwarded to std::min.
template<typename T>
OF_DEVICE_FUNC T DeviceMin(T a, T b) {
#if !defined(__CUDA_ARCH__)
  return std::min(a, b);
#else
  if (a < b) { return a; }
  return b;
#endif
}
// Returns the larger of `a` and `b`.
// When __CUDA_ARCH__ is defined a plain comparison is used; otherwise the call is
// forwarded to std::max.
template<typename T>
OF_DEVICE_FUNC T DeviceMax(T a, T b) {
#if !defined(__CUDA_ARCH__)
  return std::max(a, b);
#else
  if (a > b) { return a; }
  return b;
#endif
}
} // namespace oneflow } // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_MATH_UTIL_H_ #endif // ONEFLOW_CORE_COMMON_MATH_UTIL_H_
...@@ -44,10 +44,10 @@ class Maybe<T, typename std::enable_if<!(std::is_same<T, void>::value || IsScala ...@@ -44,10 +44,10 @@ class Maybe<T, typename std::enable_if<!(std::is_same<T, void>::value || IsScala
public: public:
Maybe(const T& data) : data_or_error_(std::make_shared<T>(data)) {} Maybe(const T& data) : data_or_error_(std::make_shared<T>(data)) {}
Maybe(T&& data) : data_or_error_(std::make_shared<T>(std::move(data))) {} Maybe(T&& data) : data_or_error_(std::make_shared<T>(std::move(data))) {}
Maybe(const Error& error) : data_or_error_(error.error_proto()) {} Maybe(const Error& error) : data_or_error_(error.stacked_error()) {}
Maybe(const std::shared_ptr<T>& data) : data_or_error_(data) {} Maybe(const std::shared_ptr<T>& data) : data_or_error_(data) {}
Maybe(std::shared_ptr<T>&& data) : data_or_error_(std::move(data)) {} Maybe(std::shared_ptr<T>&& data) : data_or_error_(std::move(data)) {}
Maybe(const std::shared_ptr<ErrorProto>& error) : data_or_error_(error) {} Maybe(const std::shared_ptr<StackedError>& error) : data_or_error_(error) {}
Maybe(const Maybe&) = default; Maybe(const Maybe&) = default;
Maybe(Maybe&& other) : data_or_error_(std::move(other.data_or_error_)) {} Maybe(Maybe&& other) : data_or_error_(std::move(other.data_or_error_)) {}
~Maybe() = default; ~Maybe() = default;
...@@ -56,65 +56,69 @@ class Maybe<T, typename std::enable_if<!(std::is_same<T, void>::value || IsScala ...@@ -56,65 +56,69 @@ class Maybe<T, typename std::enable_if<!(std::is_same<T, void>::value || IsScala
std::shared_ptr<T> Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() const { std::shared_ptr<T> Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() const {
return data_or_error_.template Get<T>(); return data_or_error_.template Get<T>();
} }
std::shared_ptr<ErrorProto> error() const { return data_or_error_.template Get<ErrorProto>(); } std::shared_ptr<StackedError> stacked_error() const {
return data_or_error_.template Get<StackedError>();
}
std::shared_ptr<const ErrorProto> error() const { return stacked_error()->error_proto(); }
std::string GetSerializedError() const { std::string GetSerializedError() const {
CHECK(!IsOk()); CHECK(!IsOk());
return GetFormatedSerializedError(this->error()); return GetFormatedSerializedError(this->stacked_error());
} }
template<typename Type = T> template<typename Type = T>
Type GetDataAndSerializedErrorProto(std::string* error_str, const Type& default_for_error) const { Type GetDataAndSerializedStackedError(std::string* error_str,
const Type& default_for_error) const {
static_assert(std::is_same<T, Type>::value, "error type for argument 1"); static_assert(std::is_same<T, Type>::value, "error type for argument 1");
if (IsOk()) { if (IsOk()) {
*error_str = ErrorProto().DebugString(); *error_str = StackedError().DebugString();
return *Data_YouAreNotAllowedToCallThisFuncOutsideThisFile(); return *Data_YouAreNotAllowedToCallThisFuncOutsideThisFile();
} else { } else {
*error_str = this->error()->DebugString(); *error_str = this->stacked_error()->DebugString();
return default_for_error; return default_for_error;
} }
} }
template<typename Type = T> template<typename Type = T>
std::pair<Type, std::shared_ptr<ErrorProto>> GetDataAndErrorProto( std::pair<Type, std::shared_ptr<StackedError>> GetDataAndStackedError(
const Type& default_for_error) const { const Type& default_for_error) const {
if (IsOk()) { if (IsOk()) {
return std::make_pair(*Data_YouAreNotAllowedToCallThisFuncOutsideThisFile(), return std::make_pair(*Data_YouAreNotAllowedToCallThisFuncOutsideThisFile(),
std::shared_ptr<ErrorProto>()); std::shared_ptr<StackedError>());
} else { } else {
return std::make_pair(default_for_error, error()); return std::make_pair(default_for_error, stacked_error());
} }
} }
std::pair<std::shared_ptr<T>, std::shared_ptr<ErrorProto>> GetDataPtrAndErrorProto() const { std::pair<std::shared_ptr<T>, std::shared_ptr<StackedError>> GetDataPtrAndStackedError() const {
if (IsOk()) { if (IsOk()) {
return std::make_pair(Data_YouAreNotAllowedToCallThisFuncOutsideThisFile(), return std::make_pair(Data_YouAreNotAllowedToCallThisFuncOutsideThisFile(),
std::shared_ptr<ErrorProto>()); std::shared_ptr<StackedError>());
} else { } else {
return std::make_pair(std::shared_ptr<T>(), error()); return std::make_pair(std::shared_ptr<T>(), stacked_error());
} }
} }
template<typename Type = T> template<typename Type = T>
Type GetOrThrow() const { Type GetOrThrow() const {
if (!IsOk()) { ThrowError(error()); } if (!IsOk()) { ThrowError(stacked_error()); }
return *Data_YouAreNotAllowedToCallThisFuncOutsideThisFile(); return *Data_YouAreNotAllowedToCallThisFuncOutsideThisFile();
} }
std::shared_ptr<T> GetPtrOrThrow() const { std::shared_ptr<T> GetPtrOrThrow() const {
if (!IsOk()) { ThrowError(error()); } if (!IsOk()) { ThrowError(stacked_error()); }
return Data_YouAreNotAllowedToCallThisFuncOutsideThisFile(); return Data_YouAreNotAllowedToCallThisFuncOutsideThisFile();
} }
private: private:
EitherPtr<T, ErrorProto> data_or_error_; EitherPtr<T, StackedError> data_or_error_;
}; };
template<typename T> template<typename T>
class Maybe<T, typename std::enable_if<std::is_same<T, void>::value>::type> final { class Maybe<T, typename std::enable_if<std::is_same<T, void>::value>::type> final {
public: public:
Maybe(const Error& error) : error_or_scalar_(error.error_proto()) { CheckError(); } Maybe(const Error& error) : error_or_scalar_(error.stacked_error()) { CheckError(); }
Maybe(const std::shared_ptr<ErrorProto>& error) : error_or_scalar_(error) { CheckError(); } Maybe(const std::shared_ptr<StackedError>& error) : error_or_scalar_(error) { CheckError(); }
Maybe(const Maybe&) = default; Maybe(const Maybe&) = default;
Maybe(Maybe&&) = default; Maybe(Maybe&&) = default;
~Maybe() = default; ~Maybe() = default;
...@@ -123,31 +127,32 @@ class Maybe<T, typename std::enable_if<std::is_same<T, void>::value>::type> fina ...@@ -123,31 +127,32 @@ class Maybe<T, typename std::enable_if<std::is_same<T, void>::value>::type> fina
bool IsOk() const { return error_or_scalar_.IsScalar(); } bool IsOk() const { return error_or_scalar_.IsScalar(); }
void Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() const {} void Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() const {}
std::shared_ptr<ErrorProto> error() const { return error_or_scalar_.shared_ptr(); } std::shared_ptr<StackedError> stacked_error() const { return error_or_scalar_.shared_ptr(); }
std::shared_ptr<const ErrorProto> error() const { return stacked_error()->error_proto(); }
std::string GetSerializedError() const { std::string GetSerializedError() const {
CHECK(!IsOk()); CHECK(!IsOk());
return GetFormatedSerializedError(this->error()); return GetFormatedSerializedError(this->stacked_error());
} }
void GetDataAndSerializedErrorProto(std::string* error_str) const { void GetDataAndSerializedStackedError(std::string* error_str) const {
if (IsOk()) { if (IsOk()) {
*error_str = ErrorProto().DebugString(); *error_str = StackedError().DebugString();
} else { } else {
*error_str = this->error()->DebugString(); *error_str = this->stacked_error()->DebugString();
} }
} }
std::shared_ptr<ErrorProto> GetDataAndErrorProto() const { std::shared_ptr<StackedError> GetDataAndStackedError() const {
if (IsOk()) { if (IsOk()) {
return std::shared_ptr<ErrorProto>(); return std::shared_ptr<StackedError>();
} else { } else {
return error(); return stacked_error();
} }
} }
void GetOrThrow() const { void GetOrThrow() const {
if (!IsOk()) { ThrowError(error()); } if (!IsOk()) { ThrowError(stacked_error()); }
return Data_YouAreNotAllowedToCallThisFuncOutsideThisFile(); return Data_YouAreNotAllowedToCallThisFuncOutsideThisFile();
} }
...@@ -157,12 +162,12 @@ class Maybe<T, typename std::enable_if<std::is_same<T, void>::value>::type> fina ...@@ -157,12 +162,12 @@ class Maybe<T, typename std::enable_if<std::is_same<T, void>::value>::type> fina
CHECK_NE(this->error()->error_type_case(), ErrorProto::ERROR_TYPE_NOT_SET); CHECK_NE(this->error()->error_type_case(), ErrorProto::ERROR_TYPE_NOT_SET);
} }
SharedOrScalar<ErrorProto, void*> error_or_scalar_; SharedOrScalar<StackedError, void*> error_or_scalar_;
}; };
inline const std::shared_ptr<ErrorProto>& UninitializedValueError() { inline const std::shared_ptr<StackedError>& UninitializedValueError() {
static thread_local const auto& error = static thread_local const auto& error =
Error::InvalidValueError("uninitialized value").error_proto(); (Error::InvalidValueError() << "uninitialized value").stacked_error();
return error; return error;
} }
...@@ -170,8 +175,8 @@ template<typename T> ...@@ -170,8 +175,8 @@ template<typename T>
class Maybe<T, typename std::enable_if<IsScalarType<T>::value>::type> final { class Maybe<T, typename std::enable_if<IsScalarType<T>::value>::type> final {
public: public:
Maybe(T data) : error_or_scalar_(data) {} Maybe(T data) : error_or_scalar_(data) {}
Maybe(const Error& error) : error_or_scalar_(error.error_proto()) { CheckError(); } Maybe(const Error& error) : error_or_scalar_(error.stacked_error()) { CheckError(); }
Maybe(const std::shared_ptr<ErrorProto>& error) : error_or_scalar_(error) { CheckError(); } Maybe(const std::shared_ptr<StackedError>& error) : error_or_scalar_(error) { CheckError(); }
Maybe() : error_or_scalar_(UninitializedValueError()) {} Maybe() : error_or_scalar_(UninitializedValueError()) {}
Maybe(const Maybe&) = default; Maybe(const Maybe&) = default;
Maybe(Maybe&&) = default; Maybe(Maybe&&) = default;
...@@ -183,34 +188,36 @@ class Maybe<T, typename std::enable_if<IsScalarType<T>::value>::type> final { ...@@ -183,34 +188,36 @@ class Maybe<T, typename std::enable_if<IsScalarType<T>::value>::type> final {
T Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() const { T Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() const {
return error_or_scalar_.scalar_value(); return error_or_scalar_.scalar_value();
} }
std::shared_ptr<ErrorProto> error() const { return error_or_scalar_.shared_ptr(); } std::shared_ptr<StackedError> stacked_error() const { return error_or_scalar_.shared_ptr(); }
std::shared_ptr<const ErrorProto> error() const { return stacked_error()->error_proto(); }
std::string GetSerializedError() const { std::string GetSerializedError() const {
CHECK(!IsOk()); CHECK(!IsOk());
return GetFormatedSerializedError(this->error()); return GetFormatedSerializedError(this->stacked_error());
} }
T GetDataAndSerializedErrorProto(std::string* error_str, const T& default_for_error) const { T GetDataAndSerializedStackedError(std::string* error_str, const T& default_for_error) const {
if (IsOk()) { if (IsOk()) {
*error_str = ErrorProto().DebugString(); *error_str = StackedError().DebugString();
return Data_YouAreNotAllowedToCallThisFuncOutsideThisFile(); return Data_YouAreNotAllowedToCallThisFuncOutsideThisFile();
} else { } else {
*error_str = this->error()->DebugString(); *error_str = this->stacked_error()->DebugString();
return default_for_error; return default_for_error;
} }
} }
std::pair<T, std::shared_ptr<ErrorProto>> GetDataAndErrorProto(const T& default_for_error) const { std::pair<T, std::shared_ptr<StackedError>> GetDataAndStackedError(
const T& default_for_error) const {
if (IsOk()) { if (IsOk()) {
return std::make_pair(Data_YouAreNotAllowedToCallThisFuncOutsideThisFile(), return std::make_pair(Data_YouAreNotAllowedToCallThisFuncOutsideThisFile(),
std::shared_ptr<ErrorProto>()); std::shared_ptr<StackedError>());
} else { } else {
return std::make_pair(default_for_error, error()); return std::make_pair(default_for_error, stacked_error());
} }
} }
T GetOrThrow() const { T GetOrThrow() const {
if (!IsOk()) { ThrowError(error()); } if (!IsOk()) { ThrowError(stacked_error()); }
return Data_YouAreNotAllowedToCallThisFuncOutsideThisFile(); return Data_YouAreNotAllowedToCallThisFuncOutsideThisFile();
} }
...@@ -219,7 +226,7 @@ class Maybe<T, typename std::enable_if<IsScalarType<T>::value>::type> final { ...@@ -219,7 +226,7 @@ class Maybe<T, typename std::enable_if<IsScalarType<T>::value>::type> final {
CHECK_NE(this->error()->error_type_case(), ErrorProto::ERROR_TYPE_NOT_SET); CHECK_NE(this->error()->error_type_case(), ErrorProto::ERROR_TYPE_NOT_SET);
} }
SharedOrScalar<ErrorProto, T> error_or_scalar_; SharedOrScalar<StackedError, T> error_or_scalar_;
}; };
template<typename T> template<typename T>
...@@ -232,7 +239,7 @@ class Maybe<T, typename std::enable_if<!(std::is_same<T, void>::value || IsScala ...@@ -232,7 +239,7 @@ class Maybe<T, typename std::enable_if<!(std::is_same<T, void>::value || IsScala
public: public:
Maybe(T data) : maybe_ptr_(&data) {} Maybe(T data) : maybe_ptr_(&data) {}
Maybe(const Error& error) : maybe_ptr_(error) {} Maybe(const Error& error) : maybe_ptr_(error) {}
Maybe(const std::shared_ptr<ErrorProto>& error) : maybe_ptr_(error) {} Maybe(const std::shared_ptr<StackedError>& error) : maybe_ptr_(error) {}
Maybe(const Maybe&) = default; Maybe(const Maybe&) = default;
Maybe(Maybe&&) = default; Maybe(Maybe&&) = default;
~Maybe() = default; ~Maybe() = default;
...@@ -241,19 +248,20 @@ class Maybe<T, typename std::enable_if<!(std::is_same<T, void>::value || IsScala ...@@ -241,19 +248,20 @@ class Maybe<T, typename std::enable_if<!(std::is_same<T, void>::value || IsScala
T Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() const { T Data_YouAreNotAllowedToCallThisFuncOutsideThisFile() const {
return *maybe_ptr_.Data_YouAreNotAllowedToCallThisFuncOutsideThisFile(); return *maybe_ptr_.Data_YouAreNotAllowedToCallThisFuncOutsideThisFile();
} }
std::shared_ptr<ErrorProto> error() const { return maybe_ptr_.error(); } std::shared_ptr<StackedError> stacked_error() const { return maybe_ptr_.stacked_error(); }
std::shared_ptr<const ErrorProto> error() const { return stacked_error()->error_proto(); }
std::string GetSerializedError() const { std::string GetSerializedError() const {
CHECK(!IsOk()); CHECK(!IsOk());
return maybe_ptr_.GetSerializedError(); return maybe_ptr_.GetSerializedError();
} }
T GetDataAndSerializedErrorProto(std::string* error_str) const { T GetDataAndSerializedStackedError(std::string* error_str) const {
return *maybe_ptr_.GetDataAndSerializedErrorProto(error_str, static_cast<PtrT>(nullptr)); return *maybe_ptr_.GetDataAndSerializedStackedError(error_str, static_cast<PtrT>(nullptr));
} }
T GetOrThrow() const { T GetOrThrow() const {
if (!IsOk()) { ThrowError(error()); } if (!IsOk()) { ThrowError(stacked_error()); }
return Data_YouAreNotAllowedToCallThisFuncOutsideThisFile(); return Data_YouAreNotAllowedToCallThisFuncOutsideThisFile();
} }
...@@ -262,10 +270,10 @@ class Maybe<T, typename std::enable_if<!(std::is_same<T, void>::value || IsScala ...@@ -262,10 +270,10 @@ class Maybe<T, typename std::enable_if<!(std::is_same<T, void>::value || IsScala
}; };
namespace { namespace {
std::string GetFormatedSerializedError(const std::shared_ptr<ErrorProto>& error_proto) { std::string GetFormatedSerializedError(const std::shared_ptr<StackedError>& stacked_error) {
// return error msg got from formatted function or debugstring. // return error msg got from formatted function or debugstring.
const auto& maybe_error = TRY(FormatErrorStr(error_proto)); const auto& maybe_error = TRY(FormatErrorStr(stacked_error));
const auto& error_str = maybe_error.GetDataAndErrorProto(error_proto->DebugString()); const auto& error_str = maybe_error.GetDataAndStackedError(stacked_error->DebugString());
return error_str.first; return error_str.first;
} }
} // namespace } // namespace
...@@ -276,18 +284,32 @@ std::string GetFormatedSerializedError(const std::shared_ptr<ErrorProto>& error_ ...@@ -276,18 +284,32 @@ std::string GetFormatedSerializedError(const std::shared_ptr<ErrorProto>& error_
GOOGLE_PREDICT_BRANCH_NOT_TAKEN(!maybe.IsOk());) \ GOOGLE_PREDICT_BRANCH_NOT_TAKEN(!maybe.IsOk());) \
LOG(FATAL) << OF_PP_STRINGIZE(__VA_ARGS__) << " is not OK:\n" << maybe.GetSerializedError() LOG(FATAL) << OF_PP_STRINGIZE(__VA_ARGS__) << " is not OK:\n" << maybe.GetSerializedError()
#define OF_RETURN_IF_ERROR(...) \ #define OF_RETURN_IF_ERROR(...) \
for (auto&& maybe_##__LINE__ = __JustStackCheckWrapper__(__VA_ARGS__); \ for (auto&& maybe_##__LINE__ = __JustStackCheckWrapper__(__VA_ARGS__); \
!maybe_##__LINE__.IsOk();) \ !maybe_##__LINE__.IsOk();) \
return Error(maybe_##__LINE__.error()).AddStackFrame(__FILE__, __LINE__, __FUNCTION__) return Error(maybe_##__LINE__.stacked_error()).AddStackFrame([](const char* function) { \
thread_local static auto frame = SymbolOf(ErrorStackFrame(__FILE__, __LINE__, function)); \
#define OF_TODO() return Error::TodoError().AddStackFrame(__FILE__, __LINE__, __FUNCTION__) return frame; \
#define OF_UNIMPLEMENTED() \ }(__FUNCTION__))
return Error::UnimplementedError().AddStackFrame(__FILE__, __LINE__, __FUNCTION__)
#define OF_TODO() \
#define OF_RUNTIME_ERROR() \ return Error::TodoError().AddStackFrame([](const char* function) { \
return Error::RuntimeError().AddStackFrame(__FILE__, __LINE__, __FUNCTION__) << "RuntimeError " \ thread_local static auto frame = SymbolOf(ErrorStackFrame(__FILE__, __LINE__, function)); \
": " return frame; \
}(__FUNCTION__))
#define OF_UNIMPLEMENTED() \
return Error::UnimplementedError().AddStackFrame([](const char* function) { \
thread_local static auto frame = SymbolOf(ErrorStackFrame(__FILE__, __LINE__, function)); \
return frame; \
}(__FUNCTION__))
#define OF_RUNTIME_ERROR() \
return Error::RuntimeError().AddStackFrame([](const char* function) { \
thread_local static auto frame = SymbolOf(ErrorStackFrame(__FILE__, __LINE__, function)); \
return frame; \
}(__FUNCTION__)) \
<< "RuntimeError " \
": "
#define RETURN_ERROR_WITH_BUG_PROMPT() OF_RUNTIME_ERROR() << kOfBugIssueUploadPrompt #define RETURN_ERROR_WITH_BUG_PROMPT() OF_RUNTIME_ERROR() << kOfBugIssueUploadPrompt
#define OF_LOG_ONCE(x) \ #define OF_LOG_ONCE(x) \
...@@ -299,32 +321,51 @@ std::string GetFormatedSerializedError(const std::shared_ptr<ErrorProto>& error_ ...@@ -299,32 +321,51 @@ std::string GetFormatedSerializedError(const std::shared_ptr<ErrorProto>& error_
} \ } \
} }
#define OF_COMPLIE_OPTION_ERROR() \ #define OF_COMPLIE_OPTION_ERROR() \
return Error::CompileOptionWrongError().AddStackFrame(__FILE__, __LINE__, __FUNCTION__) \ return Error::CompileOptionWrongError().AddStackFrame([](const char* function) { \
thread_local static auto frame = SymbolOf(ErrorStackFrame(__FILE__, __LINE__, function)); \
return frame; \
}(__FUNCTION__)) \
<< "Compile option wrong: " << "Compile option wrong: "
#define CHECK_OR_RETURN(expr) \ #define CHECK_OR_RETURN(expr) \
if (!(expr)) \ if (!(expr)) \
return Error::CheckFailedError().AddStackFrame(__FILE__, __LINE__, __FUNCTION__) \ return Error::CheckFailedError().AddStackFrame([](const char* function) { \
<< "Check failed: " << OF_PP_STRINGIZE(expr) << " " thread_local static auto frame = SymbolOf(ErrorStackFrame(__FILE__, __LINE__, function)); \
return frame; \
#define CHECK_EQ_OR_RETURN(lhs, rhs) \ }(__FUNCTION__)) \
CHECK_OR_RETURN((lhs) == (rhs)) << "(" << (lhs) << " vs " << (rhs) << ") " << "Check failed: " << OF_PP_STRINGIZE(expr) << " " << Error::kOverrideThenMergeMessage
#define CHECK_GE_OR_RETURN(lhs, rhs) \ #define CHECK_OR_RETURN_ERROR(expr) \
CHECK_OR_RETURN((lhs) >= (rhs)) << "(" << (lhs) << " vs " << (rhs) << ") " if (!(expr)) \
return Error::CheckFailedError().AddStackFrame([](const char* function) { \
#define CHECK_GT_OR_RETURN(lhs, rhs) \ thread_local static auto frame = SymbolOf(ErrorStackFrame(__FILE__, __LINE__, function)); \
CHECK_OR_RETURN((lhs) > (rhs)) << "(" << (lhs) << " vs " << (rhs) << ") " return frame; \
}(__FUNCTION__))
#define CHECK_LE_OR_RETURN(lhs, rhs) \
CHECK_OR_RETURN((lhs) <= (rhs)) << "(" << (lhs) << " vs " << (rhs) << ") " #define CHECK_EQ_OR_RETURN(lhs, rhs) \
CHECK_OR_RETURN((lhs) == (rhs)) << "(" << (lhs) << " vs " << (rhs) << ") " \
#define CHECK_LT_OR_RETURN(lhs, rhs) \ << Error::kOverrideThenMergeMessage
CHECK_OR_RETURN((lhs) < (rhs)) << "(" << (lhs) << " vs " << (rhs) << ") "
#define CHECK_GE_OR_RETURN(lhs, rhs) \
#define CHECK_NE_OR_RETURN(lhs, rhs) \ CHECK_OR_RETURN((lhs) >= (rhs)) << "(" << (lhs) << " vs " << (rhs) << ") " \
CHECK_OR_RETURN((lhs) != (rhs)) << "(" << (lhs) << " vs " << (rhs) << ") " << Error::kOverrideThenMergeMessage
#define CHECK_GT_OR_RETURN(lhs, rhs) \
CHECK_OR_RETURN((lhs) > (rhs)) << "(" << (lhs) << " vs " << (rhs) << ") " \
<< Error::kOverrideThenMergeMessage
#define CHECK_LE_OR_RETURN(lhs, rhs) \
CHECK_OR_RETURN((lhs) <= (rhs)) << "(" << (lhs) << " vs " << (rhs) << ") " \
<< Error::kOverrideThenMergeMessage
#define CHECK_LT_OR_RETURN(lhs, rhs) \
CHECK_OR_RETURN((lhs) < (rhs)) << "(" << (lhs) << " vs " << (rhs) << ") " \
<< Error::kOverrideThenMergeMessage
#define CHECK_NE_OR_RETURN(lhs, rhs) \
CHECK_OR_RETURN((lhs) != (rhs)) << "(" << (lhs) << " vs " << (rhs) << ") " \
<< Error::kOverrideThenMergeMessage
#define CHECK_STREQ_OR_RETURN(lhs, rhs) CHECK_EQ_OR_RETURN(std::string(lhs), std::string(rhs)) #define CHECK_STREQ_OR_RETURN(lhs, rhs) CHECK_EQ_OR_RETURN(std::string(lhs), std::string(rhs))
......
...@@ -17,6 +17,7 @@ limitations under the License. ...@@ -17,6 +17,7 @@ limitations under the License.
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include <gtest/gtest-death-test.h> #include <gtest/gtest-death-test.h>
#include <memory> #include <memory>
#include "oneflow/core/common/exception.h"
#include "oneflow/core/common/util.h" #include "oneflow/core/common/util.h"
namespace oneflow { namespace oneflow {
...@@ -24,7 +25,7 @@ namespace test { ...@@ -24,7 +25,7 @@ namespace test {
TEST(Maybe, JUST_MSG) { TEST(Maybe, JUST_MSG) {
auto f = [](int x) -> Maybe<int> { auto f = [](int x) -> Maybe<int> {
if (x > 10) { return Error::InvalidValueError("") << "input value " << x; } if (x > 10) { return Error::InvalidValueError() << "input value " << x; }
return 233; return 233;
}; };
...@@ -44,18 +45,22 @@ TEST(Maybe, JUST_MSG) { ...@@ -44,18 +45,22 @@ TEST(Maybe, JUST_MSG) {
auto data = CHECK_JUST(i(1)); auto data = CHECK_JUST(i(1));
ASSERT_EQ(data, 233); ASSERT_EQ(data, 233);
auto err = i(10.123).error(); auto err = i(10.123).stacked_error();
ASSERT_EQ(err->msg(), "input value 53"); ASSERT_EQ(err->error_proto()->msg(), R"(input value 53)");
ASSERT_EQ(err->stack_frame(0).error_msg(), "f(y): input value g(10)"); ASSERT_GE(err->stack_frame().size(), 2);
ASSERT_EQ(err->stack_frame(1).error_msg(), "h(y): input value int(10.123)"); ASSERT_EQ(err->stack_frame().at(0)->code_text(), "f(y)");
ASSERT_EQ(err->stack_frame().at(1)->code_text(), "h(y)");
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto)
ASSERT_EXIT(CHECK_JUST(i(10.234)), testing::KilledBySignal(SIGABRT), R"(input value 53)"); try {
CHECK_JUST(i(10.234));
} catch (const RuntimeException& e) {
EXPECT_TRUE(std::string(e.what()).find(R"(input value 53)") != std::string::npos);
}
} }
TEST(Maybe, CHECK_OK) { TEST(Maybe, CHECK_OK) {
auto f = [](int x) -> Maybe<int> { auto f = [](int x) -> Maybe<int> {
if (x > 10) { return Error::InvalidValueError("") << "input value " << x; } if (x > 10) { return Error::InvalidValueError() << "input value " << x; }
return 233; return 233;
}; };
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/mem_util.h"
#include <unistd.h>
#include <sys/sysinfo.h>
#include <fstream>
#include <sstream>
#include <string>
namespace oneflow {
namespace {

// Number of whitespace-separated fields in a /proc/[pid]/stat line that sit between `comm`
// (field 2) and `vsize` (field 23): state, ppid, pgrp, session, tty_nr, tpgid, flags, minflt,
// cminflt, majflt, cmajflt, utime, stime, cutime, cstime, priority, nice, num_threads,
// itrealvalue, starttime.
constexpr int kNumFieldsBetweenCommAndVsize = 20;

// Extracts `vsize` and `rss` (fields 23 and 24) from one line in /proc/[pid]/stat format.
//
// The `comm` field is wrapped in parentheses and may itself contain spaces or parentheses, so
// plain whitespace tokenisation of the whole line can shift every later field. Parsing therefore
// starts after the LAST ')' in the line.
//
// Returns true and writes both outputs on success; returns false and leaves the outputs
// untouched when the line has no ')' or does not contain enough fields.
bool ParseProcStatLine(const std::string& stat_line, unsigned long* vsize, long* rss) {
  const std::string::size_type comm_end = stat_line.rfind(')');
  if (comm_end == std::string::npos) { return false; }
  std::istringstream fields(stat_line.substr(comm_end + 1));
  std::string skipped;
  for (int i = 0; i < kNumFieldsBetweenCommAndVsize; ++i) {
    if (!(fields >> skipped)) { return false; }
  }
  unsigned long parsed_vsize = 0;
  long parsed_rss = 0;
  if (!(fields >> parsed_vsize >> parsed_rss)) { return false; }
  *vsize = parsed_vsize;
  *rss = parsed_rss;
  return true;
}

}  // namespace

// Reference: https://stackoverflow.com/questions/669438/how-to-get-memory-usage-at-runtime-using-c
//
// Reports the memory usage of the calling process, in whole MB (fractions are truncated).
//   vm_usage:     receives the virtual memory size.
//   resident_set: receives the resident set size.
// Both outputs are 0 when not built for Linux, or when /proc/self/stat cannot be read or parsed.
void ProcessMemUsage(double* vm_usage, double* resident_set) {
  *vm_usage = 0.0;
  *resident_set = 0.0;
#ifdef __linux__
  // 'file' stat seems to give the most reliable results
  std::ifstream stat_stream("/proc/self/stat", std::ios_base::in);
  std::string stat_line;
  if (!std::getline(stat_stream, stat_line)) { return; }

  unsigned long vsize_bytes = 0;
  long rss_pages = 0;
  if (!ParseProcStatLine(stat_line, &vsize_bytes, &rss_pages)) { return; }

  // return with MB
  *vm_usage = vsize_bytes >> 20;

  // rss is counted in pages; sysconf reports the page size in bytes (it is not always 4KB, e.g.
  // when x86-64 is configured to use 2MB pages) and returns -1 on failure.
  const long page_size_bytes = sysconf(_SC_PAGE_SIZE);
  // return with MB
  if (page_size_bytes > 0) { *resident_set = (rss_pages * page_size_bytes) >> 20; }
#endif  // __linux__
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_MEM_UTIL_H_
#define ONEFLOW_CORE_COMMON_MEM_UTIL_H_
#include <chrono>
#include <sstream>
#include <string>
#include "oneflow/core/common/util.h"
namespace oneflow {
// Writes the calling process's virtual memory size to *vm_usage and its resident set size to
// *resident_set, both in MB. Both outputs are reset to 0 first and stay 0 when the code is not
// compiled for Linux.
void ProcessMemUsage(double* vm_usage, double* resident_set);
} // namespace oneflow
// Logs the resident set size of the current process (in MB) at verbose level 1, tagged with the
// file, line and function of the call site.
//
// The macro expands to ONE streaming expression: the measurement happens inside an immediately
// invoked lambda instead of in variables declared in the caller's scope. As a result it can be
// used several times in the same scope, as the body of an unbraced `if`/`else`, and from outside
// `namespace oneflow`; callers may still append further `<<` operands.
#define LOG_MEM(...)                                                                 \
  VLOG(1) << "File " __FILE__ << ", Line " << __LINE__ << ", Func " << __FUNCTION__ \
          << ", Mem size RSS " << ([]() -> double {                                  \
               double vm_mb = 0;                                                     \
               double rss_mb = 0;                                                    \
               ::oneflow::ProcessMemUsage(&vm_mb, &rss_mb);                          \
               return rss_mb;                                                        \
             }())                                                                    \
          << "MB."
#endif // ONEFLOW_CORE_COMMON_MEM_UTIL_H_
...@@ -24,7 +24,8 @@ namespace oneflow { ...@@ -24,7 +24,8 @@ namespace oneflow {
template<typename T, int N> template<typename T, int N>
class NdIndexOffsetHelper { class NdIndexOffsetHelper {
public: public:
NdIndexOffsetHelper() {} OF_DEVICE_FUNC NdIndexOffsetHelper() = default;
template<class... Ts> template<class... Ts>
OF_DEVICE_FUNC explicit NdIndexOffsetHelper(T d0, Ts... dims) { OF_DEVICE_FUNC explicit NdIndexOffsetHelper(T d0, Ts... dims) {
constexpr int n = 1 + sizeof...(dims); constexpr int n = 1 + sizeof...(dims);
...@@ -53,15 +54,14 @@ class NdIndexOffsetHelper { ...@@ -53,15 +54,14 @@ class NdIndexOffsetHelper {
InitStrides(dims_arr, n); InitStrides(dims_arr, n);
} }
~NdIndexOffsetHelper() = default; virtual ~NdIndexOffsetHelper() = default;
OF_DEVICE_FUNC T NdIndexToOffset(const T* index) const { OF_DEVICE_FUNC T NdIndexToOffset(const T* index) const {
T offset = 0; T offset = 0;
#ifdef __CUDA_ARCH__ #ifdef __CUDA_ARCH__
#pragma unroll #pragma unroll
#endif #endif
for (int i = 0; i < N - 1; ++i) { offset += index[i] * stride_[i]; } for (int i = 0; i < N; ++i) { offset += index[i] * stride_[i]; }
offset += index[N - 1];
return offset; return offset;
} }
...@@ -146,7 +146,7 @@ class NdIndexOffsetHelper { ...@@ -146,7 +146,7 @@ class NdIndexOffsetHelper {
OF_DEVICE_FUNC constexpr int Size() const { return N; } OF_DEVICE_FUNC constexpr int Size() const { return N; }
private: protected:
OF_DEVICE_FUNC void InitStrides(const T* dims, const int n) { OF_DEVICE_FUNC void InitStrides(const T* dims, const int n) {
for (int i = n - 1; i < N; ++i) { stride_[i] = 1; } for (int i = n - 1; i < N; ++i) { stride_[i] = 1; }
for (int i = n - 2; i >= 0; --i) { stride_[i] = dims[i + 1] * stride_[i + 1]; } for (int i = n - 2; i >= 0; --i) { stride_[i] = dims[i + 1] * stride_[i + 1]; }
...@@ -155,6 +155,36 @@ class NdIndexOffsetHelper { ...@@ -155,6 +155,36 @@ class NdIndexOffsetHelper {
T stride_[N]; T stride_[N];
}; };
// Variant of NdIndexOffsetHelper whose per-axis strides are supplied directly by the caller
// instead of being derived from dimension sizes.
//
// T is the index/stride type and N the number of stride slots held by the helper.
template<typename T, int N>
class NdIndexStrideOffsetHelper : public NdIndexOffsetHelper<T, N> {
 public:
  // Leaves the strides unset, like the base class default constructor.
  OF_DEVICE_FUNC NdIndexStrideOffsetHelper() = default;

  // Copies N strides from `strides`, which must point to at least N elements.
  OF_DEVICE_FUNC explicit NdIndexStrideOffsetHelper(const T* strides) {
    for (int i = 0; i < N; ++i) { stride_[i] = strides[i]; }
  }

  // Same as above for strides stored as another arithmetic type U; each value is cast to T.
  template<typename U>
  OF_DEVICE_FUNC explicit NdIndexStrideOffsetHelper(const U* strides) {
    for (int i = 0; i < N; ++i) { stride_[i] = static_cast<T>(strides[i]); }
  }

  // Copies the first min(n, N) strides from `strides`. Slots at index >= n are set to 1 (the
  // value the base class's InitStrides uses for unused trailing axes) so that no slot is left
  // uninitialised when the helper is later used over all N axes.
  OF_DEVICE_FUNC explicit NdIndexStrideOffsetHelper(const T* strides, int n) {
    for (int i = 0; i < N; ++i) { stride_[i] = (i < n) ? strides[i] : static_cast<T>(1); }
  }

  // Same as above for strides stored as another arithmetic type U; each value is cast to T.
  template<typename U>
  OF_DEVICE_FUNC explicit NdIndexStrideOffsetHelper(const U* strides, int n) {
    for (int i = 0; i < N; ++i) {
      stride_[i] = (i < n) ? static_cast<T>(strides[i]) : static_cast<T>(1);
    }
  }

 private:
  // stride_ lives in a dependent base class, so it must be named explicitly to be visible here.
  using NdIndexOffsetHelper<T, N>::stride_;
};
} // namespace oneflow } // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_ND_INDEX_OFFSET_HELPER_H_ #endif // ONEFLOW_CORE_COMMON_ND_INDEX_OFFSET_HELPER_H_
...@@ -19,6 +19,7 @@ limitations under the License. ...@@ -19,6 +19,7 @@ limitations under the License.
#include <sstream> #include <sstream>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#define private public #define private public
#define protected public
#include "oneflow/core/common/nd_index_offset_helper.h" #include "oneflow/core/common/nd_index_offset_helper.h"
namespace oneflow { namespace oneflow {
...@@ -142,6 +143,35 @@ TEST(NdIndexOffsetHelper, constructor) { ...@@ -142,6 +143,35 @@ TEST(NdIndexOffsetHelper, constructor) {
test_constructor<int64_t>(); test_constructor<int64_t>();
} }
// Builds the strides {30, 6, 1} once as T and once as U, feeds them to each of the four
// NdIndexStrideOffsetHelper<T, 3> constructors (same-type / other-type pointer, with and without
// an explicit count) and checks that every helper stored exactly those strides.
template<typename T, typename U>
void test_stride_constructor() {
  constexpr int kNumAxes = 3;

  const T rows_t = 5;
  const T cols_t = 6;
  std::vector<T> strides({rows_t * cols_t, cols_t, 1});

  const U rows_u = 5;
  const U cols_u = 6;
  std::vector<U> strides_u({rows_u * cols_u, cols_u, 1});

  // Same element type as the helper.
  const NdIndexStrideOffsetHelper<T, kNumAxes> helper1(strides.data());
  const NdIndexStrideOffsetHelper<T, kNumAxes> helper2(strides.data(), strides.size());
  // Different element type: goes through the converting constructors.
  const NdIndexStrideOffsetHelper<T, kNumAxes> helper3(strides_u.data());
  const NdIndexStrideOffsetHelper<T, kNumAxes> helper4(strides_u.data(), strides_u.size());

  for (int i = 0; i < kNumAxes; ++i) {
    ASSERT_EQ(helper1.stride_[i], strides[i]);
    ASSERT_EQ(helper2.stride_[i], strides[i]);
    ASSERT_EQ(helper3.stride_[i], strides_u[i]);
    ASSERT_EQ(helper4.stride_[i], strides_u[i]);
  }
}
// Runs the stride-constructor check in both directions of the int32_t/int64_t pairing, so the
// converting constructors are exercised for widening and for narrowing.
TEST(NdIndexStrideOffsetHelper, constructor) {
test_stride_constructor<int32_t, int64_t>();
test_stride_constructor<int64_t, int32_t>();
}
} // namespace test } // namespace test
} // namespace oneflow } // namespace oneflow
...@@ -14,6 +14,8 @@ See the License for the specific language governing permissions and ...@@ -14,6 +14,8 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
*/ */
#include "oneflow/core/common/notifier.h" #include "oneflow/core/common/notifier.h"
#include "oneflow/core/common/foreign_lock_helper.h"
#include "oneflow/core/common/env_var/env_var.h"
namespace oneflow { namespace oneflow {
...@@ -37,6 +39,30 @@ NotifierStatus Notifier::WaitAndClearNotifiedCnt() { ...@@ -37,6 +39,30 @@ NotifierStatus Notifier::WaitAndClearNotifiedCnt() {
return kNotifierStatusSuccess; return kNotifierStatusSuccess;
} }
// Waits up to `timeout_seconds` for a notification, then clears the notification counter.
//
// Returns Ok after resetting notified_cnt_ to 0 when at least one notification is pending.
// Returns an error when the wait times out (Error::TimeoutError() is streamed into the failed
// check) or when the wait ends with notified_cnt_ still 0, i.e. the notifier was closed
// ("notifier closed.").
Maybe<void> Notifier::TimedWaitAndClearNotifiedCnt(size_t timeout_seconds) {
// The whole wait runs inside ForeignLockHelper::WithScopedRelease.
// NOTE(review): presumably this releases a foreign (e.g. interpreter) lock while blocking --
// confirm against ForeignLockHelper.
return Singleton<ForeignLockHelper>::Get()->WithScopedRelease([&, this]() -> Maybe<void> {
std::chrono::duration<size_t> seconds(timeout_seconds);
std::unique_lock<std::mutex> lock(mutex_);
// wait_for returns false when the timeout expires while the predicate is still false.
CHECK_OR_RETURN(cond_.wait_for(lock, seconds, [this]() {
return notified_cnt_ > 0 || is_closed_;
})) << Error::TimeoutError();
// The predicate is also satisfied by is_closed_, so a zero count here means "closed".
CHECK_GT_OR_RETURN(notified_cnt_, 0) << "notifier closed.";
notified_cnt_ = 0;
return Maybe<void>::Ok();
});
}
// Repeatedly performs a timed wait of EnvInteger<ONEFLOW_TIMEOUT_SECONDS>() seconds until a
// notification arrives, a non-timeout error occurs, or `StopWaitingAfterTimeout` asks to stop.
//
// After each timeout `StopWaitingAfterTimeout` is invoked: if it yields true the timeout status
// is returned to the caller, if it yields false the wait is retried, and if it fails its error
// is propagated by JUST.
Maybe<void> Notifier::TimedWaitAndClearNotifiedCnt(
const std::function<Maybe<bool>()>& StopWaitingAfterTimeout) {
while (true) {
auto status = TRY(TimedWaitAndClearNotifiedCnt(EnvInteger<ONEFLOW_TIMEOUT_SECONDS>()));
// Success: a notification was consumed.
if (status.IsOk()) { return status; }
// Any failure other than a timeout (e.g. the notifier was closed) is returned as is.
if (!status.error()->has_timeout_error()) { return status; }
// Timed out: let the caller decide whether to give up or wait for another period.
if (JUST(StopWaitingAfterTimeout())) { return status; }
}
// Not reachable: the loop above only exits through a return.
UNIMPLEMENTED_THEN_RETURN();
}
void Notifier::Close() { void Notifier::Close() {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
is_closed_ = true; is_closed_ = true;
......
...@@ -32,6 +32,10 @@ class Notifier final { ...@@ -32,6 +32,10 @@ class Notifier final {
NotifierStatus WaitAndClearNotifiedCnt(); NotifierStatus WaitAndClearNotifiedCnt();
void Close(); void Close();
Maybe<void> TimedWaitAndClearNotifiedCnt(size_t timeout_seconds);
Maybe<void> TimedWaitAndClearNotifiedCnt(
const std::function<Maybe<bool>()>& StopWaitingAfterTimeout);
private: private:
size_t notified_cnt_; size_t notified_cnt_;
std::mutex mutex_; std::mutex mutex_;
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_OP_ARGS_RESERVED_SIZE_H_
#define ONEFLOW_CORE_COMMON_OP_ARGS_RESERVED_SIZE_H_
namespace oneflow {

// Size parameter shared by op-argument containers; OpArgsVector passes it as the second template
// argument of small_vector.
static constexpr int kOpArgsReservedSize = 4;

}  // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_OP_ARGS_RESERVED_SIZE_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_OP_ARGS_VECTOR_H_
#define ONEFLOW_CORE_COMMON_OP_ARGS_VECTOR_H_
#include "oneflow/core/common/small_vector.h"
#include "oneflow/core/common/op_args_reserved_size.h"
namespace oneflow {
// Vector type for op arguments: a small_vector of T parameterised with kOpArgsReservedSize.
// NOTE(review): presumably the second parameter is the inline (non-heap) capacity -- confirm
// against small_vector.h.
template<typename T>
using OpArgsVector = small_vector<T, kOpArgsReservedSize>;
}
#endif // ONEFLOW_CORE_COMMON_OP_ARGS_VECTOR_H_
...@@ -227,31 +227,31 @@ struct hash<oneflow::DataType> { ...@@ -227,31 +227,31 @@ struct hash<oneflow::DataType> {
template<> template<>
struct hash<oneflow::LogicalBlobId> { struct hash<oneflow::LogicalBlobId> {
size_t operator()(const oneflow::LogicalBlobId& lbi) const { size_t operator()(const oneflow::LogicalBlobId& lbi) const {
const auto& str_hash = std::hash<std::string>(); using namespace oneflow;
return str_hash(lbi.op_name()) ^ str_hash(lbi.blob_name()); return Hash(lbi.op_name(), lbi.blob_name());
} }
}; };
template<> template<>
struct hash<oneflow::OpBlobArg> { struct hash<oneflow::OpBlobArg> {
size_t operator()(const oneflow::OpBlobArg& oba) const { size_t operator()(const oneflow::OpBlobArg& oba) const {
const auto& str_hash = std::hash<std::string>(); using namespace oneflow;
return str_hash(oba.op_name()) ^ str_hash(oba.bn_in_op()); return Hash(oba.op_name(), oba.bn_in_op());
} }
}; };
template<> template<>
struct hash<oneflow::SbpParallel> { struct hash<oneflow::SbpParallel> {
size_t operator()(const oneflow::SbpParallel& sbp_parallel) const { size_t operator()(const oneflow::SbpParallel& sbp_parallel) const {
const auto& str_hash = std::hash<std::string>(); using namespace oneflow;
size_t ret = 0; size_t ret = 0;
if (sbp_parallel.has_broadcast_parallel()) { if (sbp_parallel.has_broadcast_parallel()) {
ret ^= str_hash("B"); AddHash(&ret, std::string("B"));
} else if (sbp_parallel.has_partial_sum_parallel()) { } else if (sbp_parallel.has_partial_sum_parallel()) {
ret ^= str_hash("P"); AddHash(&ret, std::string("P"));
} else if (sbp_parallel.has_split_parallel()) { } else if (sbp_parallel.has_split_parallel()) {
ret ^= str_hash("S"); AddHash(&ret, std::string("S"));
ret ^= std::hash<int64_t>()(sbp_parallel.split_parallel().axis()); AddHash(&ret, sbp_parallel.split_parallel().axis());
} else { } else {
UNIMPLEMENTED(); UNIMPLEMENTED();
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment