Commit 21d47d0e, authored by yuguo
Browse files

Oneflow 0.8 for DCU

parents
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifdef WITH_MLIR
#include "oneflow/ir/include/OneFlow/Extension.h"
#include "oneflow/ir/oneflow-extension/include/OneFlow/OneFlowRoundTrip.h"
#include <glog/logging.h>
namespace oneflow {
// Register the MLIR round-trip job passes: one variant runs before automatic
// differentiation ("IRRoundTripBeforeAD"), the other after ("IRRoundTrip").
// Both are only available when the build enables MLIR support (WITH_MLIR).
REGISTER_JOB_PASS("IRRoundTripBeforeAD", IRRoundTrip<kBeforeAD>);
REGISTER_JOB_PASS("IRRoundTrip", IRRoundTrip<kAfterAD>);
} // namespace oneflow
#endif // WITH_MLIR
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_API_COMMON_JOB_BUILD_AND_INFER_CTX_H_
#define ONEFLOW_API_COMMON_JOB_BUILD_AND_INFER_CTX_H_
#include "oneflow/core/job/job.pb.h"
#include "oneflow/core/job/job_build_and_infer_ctx_mgr.h"
namespace oneflow {
// Fetches the Job currently being built from the lazy job build-and-infer
// context manager. Fails (via Maybe) if no build context is active.
inline Maybe<Job> GetCurrentJob() {
  auto* mgr = Singleton<LazyJobBuildAndInferCtxMgr>::Get();
  CHECK_NOTNULL_OR_RETURN(mgr);
  const std::string job_name = *JUST(mgr->GetCurrentJobName());
  auto* ctx = JUST(mgr->FindJobBuildAndInferCtx(job_name));
  CHECK_NOTNULL_OR_RETURN(ctx);
  return ctx->job();
}
} // namespace oneflow
#endif // ONEFLOW_API_COMMON_JOB_BUILD_AND_INFER_CTX_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_API_COMMON_OFBLOB_H_
#define ONEFLOW_API_COMMON_OFBLOB_H_
#include "oneflow/core/common/just.h"
#include "oneflow/core/register/ofblob.h"
namespace oneflow {
// Copies raw host buffers into/out of an OfBlob addressed by its raw pointer
// value (passed as uint64_t across the API boundary).
template<typename T>
struct BlobBufferCopyUtil {
  // Host buffer -> blob.
  static Maybe<void> From(uint64_t of_blob_ptr, const T* buf_ptr, size_t size) {
    reinterpret_cast<OfBlob*>(of_blob_ptr)->AutoMemCopyFrom<T>(buf_ptr, size);
    return Maybe<void>::Ok();
  }
  // Blob -> host buffer.
  static Maybe<void> To(uint64_t of_blob_ptr, T* buf_ptr, size_t size) {
    reinterpret_cast<OfBlob*>(of_blob_ptr)->AutoMemCopyTo<T>(buf_ptr, size);
    return Maybe<void>::Ok();
  }
};
// Untyped (void) specialization: byte-wise copies with no element type.
template<>
struct BlobBufferCopyUtil<void> {
  // Host buffer -> blob.
  static Maybe<void> From(uint64_t of_blob_ptr, const void* buf_ptr, size_t size) {
    reinterpret_cast<OfBlob*>(of_blob_ptr)->AutoMemCopyFrom<void>(buf_ptr, size);
    return Maybe<void>::Ok();
  }
  // Blob -> host buffer.
  static Maybe<void> To(uint64_t of_blob_ptr, void* buf_ptr, size_t size) {
    reinterpret_cast<OfBlob*>(of_blob_ptr)->AutoMemCopyTo<void>(buf_ptr, size);
    return Maybe<void>::Ok();
  }
};
} // namespace oneflow
#endif // !ONEFLOW_API_COMMON_OFBLOB_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_API_COMMON_SBP_H_
#define ONEFLOW_API_COMMON_SBP_H_
#include "oneflow/core/job/sbp_parallel.pb.h"
#include "oneflow/core/job/sbp_parallel.h"
#include "oneflow/core/common/symbol.h"
#include "oneflow/core/common/maybe.h"
namespace oneflow {
namespace api {
// Renders a single SBP symbol in Python API notation, e.g.
// "oneflow.sbp.broadcast" or "oneflow.sbp.split(dim=0)".
// The three parallel kinds are mutually exclusive; anything else is an error.
inline Maybe<std::string> SbpToString(Symbol<SbpParallel> sbp_sym) {
  std::string result = "oneflow.sbp.";
  if (sbp_sym->has_split_parallel()) {
    result += "split(dim=" + std::to_string(sbp_sym->split_parallel().axis()) + ")";
  } else if (sbp_sym->has_broadcast_parallel()) {
    result += "broadcast";
  } else if (sbp_sym->has_partial_sum_parallel()) {
    result += "partial_sum";
  } else {
    UNIMPLEMENTED_THEN_RETURN();
  }
  return result;
}
// Renders an NdSbp as a Python-style tuple of SBP strings, keeping the
// trailing comma for 1-element tuples (mirroring Python tuple syntax).
inline Maybe<std::string> NdSbpToString(Symbol<NdSbp> nd_sbp) {
  const int dims = nd_sbp->sbp_parallel_size();
  std::string result("(");
  for (int idx = 0; idx < dims; ++idx) {
    if (idx != 0) { result += ", "; }
    result += *JUST(SbpToString(SymbolOf(nd_sbp->sbp_parallel(idx))));
  }
  if (dims == 1) { result += ","; }
  result += ")";
  return result;
}
} // namespace api
} // namespace oneflow
#endif // !ONEFLOW_API_COMMON_SBP_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_API_COMMON_VARIABLE_TENSOR_MGR_H_
#define ONEFLOW_API_COMMON_VARIABLE_TENSOR_MGR_H_
#include "oneflow/core/common/singleton.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/variable_tensor_mgr.h"
namespace oneflow {
// Populates the process-global VariableTensorMgr with parallel lists of
// variable op names and their backing tensors.
inline Maybe<void> FillVariableTensorMgr(
    const std::vector<std::string>& variable_op_names,
    const std::vector<std::shared_ptr<one::Tensor>>& variable_tensors) {
  return Singleton<VariableTensorMgr>::Get()->Fill(variable_op_names, variable_tensors);
}
// Drops every entry from the process-global VariableTensorMgr.
inline void ClearVariableTensorMgr() { Singleton<VariableTensorMgr>::Get()->Clear(); }
// Returns the manager's contents as parallel (names, tensors) vectors.
inline std::tuple<std::vector<std::string>, std::vector<std::shared_ptr<one::Tensor>>>
DumpVariableTensorMgr() {
  return Singleton<VariableTensorMgr>::Get()->Dump();
}
} // namespace oneflow
#endif // ONEFLOW_API_COMMON_VARIABLE_TENSOR_MGR_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_API_CPP_API_H_
#define ONEFLOW_API_CPP_API_H_
#include "env.h"
#include "framework.h"
#include "nn.h"
#endif // !ONEFLOW_API_CPP_API_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <glog/logging.h>
#include "oneflow/api/cpp/env.h"
#include "oneflow/api/cpp/env_impl.h"
#include "oneflow/core/framework/shut_down_util.h"
#include "oneflow/core/thread/thread_consistent_id.h"
namespace oneflow_api {
// Creates the global OneFlowEnv singleton if absent (idempotent) and clears
// the shutting-down flag so subsequent API calls are accepted.
void initialize() {
  auto* env = of::Singleton<OneFlowEnv>::Get();
  if (env == nullptr) { of::Singleton<OneFlowEnv>::New(); }
  of::SetShuttingDown(false);
}
// Destroys the global OneFlowEnv singleton if present, marks the runtime as
// shutting down, and resets this thread's consistent id.
void release() {
  auto* env = of::Singleton<OneFlowEnv>::Get();
  if (env != nullptr) { of::Singleton<OneFlowEnv>::Delete(); }
  of::SetShuttingDown();
  of::ResetThisThreadUniqueConsistentId().GetOrThrow();
}
} // namespace oneflow_api
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_API_CPP_ENV_H_
#define ONEFLOW_API_CPP_ENV_H_
namespace oneflow_api {
void initialize();
void release();
} // namespace oneflow_api
#endif // !ONEFLOW_API_CPP_ENV_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <glog/logging.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <random>
#include <type_traits>
#include "oneflow/api/cpp/env_impl.h"
#include "oneflow/core/common/singleton.h"
#include "oneflow/core/common/just.h"
#include "oneflow/core/common/optional.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/framework/session_util.h"
#include "oneflow/core/job/env.pb.h"
#include "oneflow/core/job/cluster_instruction.h"
#include "oneflow/core/control/ctrl_bootstrap.h"
#include "oneflow/core/job/session.h"
#include "oneflow/core/rpc/include/base.h"
#include "oneflow/core/vm/vm_util.h"
namespace oneflow_api {
namespace of = oneflow;
namespace { // for inltialize
// True once the global EnvGlobalObjectsScope singleton exists (env initialized).
inline bool IsEnvInited() { return of::Singleton<of::EnvGlobalObjectsScope>::Get() != nullptr; }
// Whether the environment variable `key` is set (an empty value still counts).
bool HasEnvVar(const std::string& key) { return getenv(key.c_str()) != nullptr; }
// Returns the value of environment variable `key`, or `default_value` when unset.
std::string GetEnvVar(const std::string& key, const std::string& default_value) {
  if (const char* value = getenv(key.c_str())) { return std::string(value); }
  return default_value;
}
// Returns the value of `key` parsed as a signed 64-bit integer (atoll
// semantics: non-numeric text yields 0), or `default_value` when unset.
int64_t GetEnvVar(const std::string& key, int64_t default_value) {
  if (const char* value = getenv(key.c_str())) { return std::atoll(value); }
  return default_value;
}
// Probes random ports in [5001, 6000] on `addr` until one can be bound, and
// returns it (-1 on non-Linux builds). Aborts via CHECK if no port is found
// after 200 attempts.
// Fix: the probe socket fd was never closed — it leaked on the failure path
// and, on success, stayed bound to the chosen port, which could make the
// caller's subsequent real bind of that port fail. Close it on both paths.
int32_t FindFreePort(const std::string& addr) {
#ifdef __linux__
  int sock = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
  CHECK_GE(sock, 0) << "fail to find a free port.";
  int optval = 1;
  setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval));
  std::mt19937 rng;
  rng.seed(std::random_device()());
  std::uniform_int_distribution<std::mt19937::result_type> dist(1, 1000);
  int count = 0;
  const int num_attempts = 200;
  do {
    int port = 5000 + dist(rng);
    struct sockaddr_in sockaddr {};
    memset(&sockaddr, 0, sizeof(sockaddr));
    sockaddr.sin_family = AF_INET;
    sockaddr.sin_port = htons(port);
    sockaddr.sin_addr.s_addr = inet_addr(addr.c_str());
    int error = bind(sock, (struct sockaddr*)&sockaddr, sizeof(sockaddr));
    if (error == 0) {
      close(sock);  // release both the fd and the binding before handing out the port
      return port;
    }
    ++count;
  } while (count < num_attempts);
  close(sock);  // was leaked here previously
  CHECK_NE(count, num_attempts) << "fail to find a free port.";
#endif  // __linux__
  return -1;
}
// Fills `env_proto` from environment variables: master address/port, world
// size, rank, and optional glog settings.
void CompleteEnvProto(of::EnvProto& env_proto) {
  auto bootstrap_conf = env_proto.mutable_ctrl_bootstrap_conf();
  auto master_addr = bootstrap_conf->mutable_master_addr();
  const std::string addr = GetEnvVar("MASTER_ADDR", "127.0.0.1");
  master_addr->set_host(addr);
  // Fix: FindFreePort(addr) used to be evaluated unconditionally as the
  // default argument, probing sockets even when MASTER_PORT was set. Only
  // probe when no explicit port is configured; the resulting port is the same.
  master_addr->set_port(HasEnvVar("MASTER_PORT") ? GetEnvVar("MASTER_PORT", -1)
                                                 : FindFreePort(addr));
  bootstrap_conf->set_world_size(GetEnvVar("WORLD_SIZE", 1));
  bootstrap_conf->set_rank(GetEnvVar("RANK", 0));
  auto cpp_logging_conf = env_proto.mutable_cpp_logging_conf();
  if (HasEnvVar("GLOG_log_dir")) { cpp_logging_conf->set_log_dir(GetEnvVar("GLOG_log_dir", "")); }
  if (HasEnvVar("GLOG_logtostderr")) {
    cpp_logging_conf->set_logtostderr(GetEnvVar("GLOG_logtostderr", -1));
  }
  if (HasEnvVar("GLOG_logbuflevel")) {
    cpp_logging_conf->set_logbuflevel(GetEnvVar("GLOG_logbuflevel", -1));
  }
}
} // namespace
// Builds the process-wide environment scope first, then a session context on
// top of it; the session depends on the env scope, so this order matters.
OneFlowEnv::OneFlowEnv() {
of::EnvProto env_proto;
CompleteEnvProto(env_proto);
env_ctx_ = std::make_shared<of::EnvGlobalObjectsScope>(env_proto);
of::ConfigProto config_proto;
config_proto.mutable_resource()->set_cpu_device_num(1); // useless, will be set in TryInit
// Allocate and register a fresh session id for this environment.
// NOTE: "RegsiterSession" is the actual (misspelled) upstream API name.
const int64_t session_id = of::NewSessionId();
CHECK_JUST(of::RegsiterSession(session_id));
config_proto.set_session_id(session_id);
session_ctx_ = std::make_shared<of::MultiClientSessionContext>(env_ctx_);
CHECK_JUST(session_ctx_->TryInit(config_proto));
}
// Tear down in reverse construction order: the session context must be
// released before the environment scope it was built on.
OneFlowEnv::~OneFlowEnv() {
session_ctx_.reset();
env_ctx_.reset();
}
} // namespace oneflow_api
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <memory>
#include "oneflow/core/framework/multi_client_session_context.h"
#include "oneflow/core/job/env_global_objects_scope.h"
#ifndef ONEFLOW_API_CPP_ENV_IMPL_H_
#define ONEFLOW_API_CPP_ENV_IMPL_H_
namespace oneflow_api {
namespace of = oneflow;
// Owns the global environment scope and the multi-client session context used
// by the C++ API. Managed as a singleton (see initialize()/release()).
class OneFlowEnv {
public:
OF_DISALLOW_COPY(OneFlowEnv);
OneFlowEnv();
~OneFlowEnv();
// Session context handed to NNGraph instances created through the C++ API.
std::shared_ptr<of::MultiClientSessionContext> GetSessionCtx() { return session_ctx_; }
private:
// env_ctx_ must outlive session_ctx_ (destroyed in reverse order in ~OneFlowEnv).
std::shared_ptr<of::EnvGlobalObjectsScope> env_ctx_;
std::shared_ptr<of::MultiClientSessionContext> session_ctx_;
};
} // namespace oneflow_api
#endif // ONEFLOW_API_CPP_ENV_IMPL_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_API_CPP_FRAMEWORK_H_
#define ONEFLOW_API_CPP_FRAMEWORK_H_
#include "framework/device.h"
#include "framework/shape.h"
#include "framework/dtype.h"
#include "framework/tensor.h"
#include "framework/ivalue.h"
#include "framework/graph.h"
#endif // ONEFLOW_API_CPP_FRAMEWORK_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/api/cpp/framework/device.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/symbol.h"
#include "oneflow/core/framework/device.h"
namespace oneflow_api {
namespace of = oneflow;
// Parses "type" or "type:device_id" (e.g. "cuda:1"); throws on invalid input.
Device::Device(const std::string& type_or_type_with_device_id)
: device_(std::make_shared<of::Symbol<of::Device>>(
of::Device::ParseAndNew(type_or_type_with_device_id).GetOrThrow())) {}
// Builds a device from an explicit type and device index; throws on failure.
Device::Device(const std::string& type, int64_t device_id)
: device_(
std::make_shared<of::Symbol<of::Device>>(of::Device::New(type, device_id).GetOrThrow())) {}
// Thin accessors forwarding to the wrapped of::Device symbol.
const std::string& Device::type() const { return (*device_)->type(); }
int64_t Device::device_id() const { return (*device_)->device_id(); }
// Equality compares the underlying device symbols.
bool Device::operator==(const Device& rhs) const { return *device_ == *rhs.device_; }
bool Device::operator!=(const Device& rhs) const { return *device_ != *rhs.device_; }
} // namespace oneflow_api
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_API_CPP_FRAMEWORK_DEVICE_H_
#define ONEFLOW_API_CPP_FRAMEWORK_DEVICE_H_
#include <string>
#include <memory>
namespace oneflow {
class Device;
template<typename T>
class Symbol;
} // namespace oneflow
namespace oneflow_api {
// Public-API wrapper around an oneflow::Symbol<oneflow::Device>.
// Copyable value type; Tensor and Graph access the wrapped symbol directly.
class Device final {
friend class Tensor;
friend class Graph;
public:
// Accepts "type" or "type:device_id", e.g. "cpu" or "cuda:1".
explicit Device(const std::string& type_or_type_with_device_id);
explicit Device(const std::string& type, int64_t device_id);
[[nodiscard]] const std::string& type() const;
[[nodiscard]] int64_t device_id() const;
[[nodiscard]] bool operator==(const Device& rhs) const;
[[nodiscard]] bool operator!=(const Device& rhs) const;
private:
std::shared_ptr<oneflow::Symbol<oneflow::Device>> device_ = nullptr;
};
} // namespace oneflow_api
#endif // !ONEFLOW_API_CPP_FRAMEWORK_DEVICE_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/api/cpp/framework/dtype.h"
#include <map>
namespace oneflow_api {
namespace {
// Byte width of each supported element type; unlisted types map to 0.
constexpr int32_t DTypeSizeOf(DType dtype) {
  switch (dtype) {
    case DType::kFloat: return sizeof(float);
    case DType::kDouble: return sizeof(double);
    case DType::kInt8: return sizeof(int8_t);
    case DType::kInt32: return sizeof(int32_t);
    case DType::kInt64: return sizeof(int64_t);
    case DType::kBool: return sizeof(bool);
    default: return 0;  // same observable result as the old map's value-init fallback
  }
}
}  // namespace
// Returns the element size in bytes, or 0 for types without a registered size.
// Fix: the previous implementation looked up a mutable global std::map via
// operator[], which silently inserted a node for every unknown dtype and was
// not thread-safe; the switch is const, allocation-free, and race-free.
int32_t GetDTypeSize(DType dtype) { return DTypeSizeOf(dtype); }
} // namespace oneflow_api
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_API_CPP_FRAMEWORK_DTYPE_H_
#define ONEFLOW_API_CPP_FRAMEWORK_DTYPE_H_
#include <cstdint>
namespace oneflow_api {
// Element data types exposed by the C++ API. The numeric values are cast
// directly from the core proto DataType (see the static_cast in graph code),
// so they must stay in sync with it — do not renumber.
enum class DType {
kInvalidDataType = 0,
kChar = 1,
kFloat = 2,
kDouble = 3,
kInt8 = 4,
kInt32 = 5,
kInt64 = 6,
kUInt8 = 7,
kOFRecord = 8,
kFloat16 = 9,
kTensorBuffer = 10,
kBFloat16 = 11,
kBool = 12,
kMaxDataType = 13
};
// Element size in bytes; returns 0 for types without a registered size.
[[nodiscard]] int32_t GetDTypeSize(DType dtype);
} // namespace oneflow_api
#endif // ONEFLOW_API_CPP_FRAMEWORK_DTYPE_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/api/common/ofblob.h"
#include "oneflow/api/common/variable_tensor_mgr.h"
#include "oneflow/api/cpp/env_impl.h"
#include "oneflow/api/cpp/framework/device.h"
#include "oneflow/api/cpp/framework/dtype.h"
#include "oneflow/api/cpp/framework/graph.h"
#include "oneflow/api/cpp/framework/ivalue.h"
#include "oneflow/api/cpp/framework/shape.h"
#include "oneflow/api/cpp/framework/tensor.h"
#include "oneflow/api/common/job_build_and_infer_ctx.h"
#include "oneflow/api/python/job_build/job_build_and_infer.h"
#include "oneflow/core/common/data_type.pb.h"
#include "oneflow/core/common/singleton.h"
#include "oneflow/core/common/hash_container.h"
#include "oneflow/core/common/just.h"
#include "oneflow/core/common/shape.h"
#include "oneflow/core/common/symbol.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/framework/device.h"
#include "oneflow/core/framework/dtype.h"
#include "oneflow/core/framework/multi_client_session_context.h"
#include "oneflow/core/framework/nn_graph.h"
#include "oneflow/core/framework/scope_util.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/framework/tensor_util.h"
#include "oneflow/core/functional/functional_api.yaml.h"
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/job/job.pb.h"
#include "oneflow/core/job/job_build_and_infer_ctx.h"
#include "oneflow/core/job/job_build_and_infer_ctx_mgr.h"
#include "oneflow/core/job/job_conf.pb.h"
#include "oneflow/core/job/job_ir.h"
#include "oneflow/core/job/job_set.pb.h"
#include "oneflow/core/job/lazy_mode.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/job/scope.h"
#include "oneflow/core/job/session.h"
#include "oneflow/core/operator/interface_blob_conf.pb.h"
#include "oneflow/core/operator/op_conf.pb.h"
#include "oneflow/core/register/logical_blob_id.pb.h"
#include "oneflow/core/vm/vm_util.h"
namespace oneflow_api {
namespace of = oneflow;
namespace {
// RAII guard for lazily building a job:
//  - enables lazy mode for its whole lifetime (member guard below),
//  - pushes a compile Scope and opens a JobBuildAndInferCtx on construction,
//  - closes the ctx and pops the scope, in reverse order, on destruction.
// The ctor/dtor ordering is load-bearing; do not reorder these calls.
class CompileScope {
public:
CompileScope(const of::JobConfigProto& job_config, const of::Device& device) {
of::JobConfigProto mut_job_config = job_config;
const std::shared_ptr<of::Scope> scope = CHECK_JUST(MakeScope(mut_job_config, device));
CHECK_JUST(of::ThreadLocalScopeStackPush(scope));
CHECK_JUST(of::JobBuildAndInferCtx_Open(mut_job_config.job_name()));
CHECK_JUST(CHECK_JUST(of::GetCurInferCtx())->SetJobConf(mut_job_config));
}
~CompileScope() {
CHECK_JUST(of::JobBuildAndInferCtx_Close());
CHECK_JUST(of::ThreadLocalScopeStackPop());
}
private:
// Declared first so lazy mode spans the entire construction/destruction.
of::LazyMode::Guard lazy_mode_enabled_guard{true};
};
// Wraps a vector of tensors into a freshly allocated TensorTuple.
std::shared_ptr<of::one::TensorTuple> ConvertToTensorTuple(
    const std::vector<std::shared_ptr<of::one::Tensor>>& tensors) {
  auto tuple = std::make_shared<of::one::TensorTuple>();
  for (const auto& t : tensors) { tuple->emplace_back(t); }
  return tuple;
}
// The OperatorConf device tag is simply the device type string (e.g. "cpu", "cuda").
std::string GetDeviceTag(const Device& device) { return device.type(); }
// Splits a hash map into parallel key/value vectors, preserving the map's
// iteration order pairwise (keys[i] corresponds to values[i]).
template<class T1, class T2>
const std::pair<std::vector<T1>, std::vector<T2>> Unzip(const of::HashMap<T1, T2>& hash_map) {
  std::vector<T1> keys;
  std::vector<T2> values;
  // Fix: reserve up front to avoid repeated reallocation while filling.
  keys.reserve(hash_map.size());
  values.reserve(hash_map.size());
  for (const auto& [key, value] : hash_map) {
    keys.emplace_back(key);
    values.emplace_back(value);
  }
  return std::make_pair(keys, values);
}
// Converts an internal of::Shape into the public-API Shape by copying dims.
Shape OfShapeToOfApiShape(const of::Shape& of_shape) {
  const auto& dim_vec = of_shape.dim_vec();
  return Shape(std::vector<int64_t>(dim_vec.begin(), dim_vec.end()));
}
} // namespace
// Private implementation of Graph: loads a serialized MLIR job, compiles it
// into an NNGraph on first Forward(), and runs inference afterwards.
// Move-only (copy is deleted).
class Graph::GraphImpl final {
public:
explicit GraphImpl(const std::string& model_path, const Device& device = Device("cpu"));
GraphImpl(const GraphImpl& graph) = delete;
GraphImpl(GraphImpl&& graph) = default;
~GraphImpl();
GraphImpl& operator=(const GraphImpl& graph) = delete;
GraphImpl& operator=(GraphImpl&& graph) = default;
// Metadata about the job's input/output interface ops.
InputOutputInfos GetInputInfos();
InputOutputInfos GetOutputInfos();
// Compiles lazily on first call, then runs the graph.
std::vector<Tensor> Forward(const std::vector<Tensor>& inputs);
// Overrides dim 0 of input blob shapes at compile time; 0 means "leave as-is".
void set_batch_size(int batch_size) { batch_size_ = batch_size; }
// Registers a serialized-job -> serialized-job transformation; must be called
// before the first Forward() (i.e. before compilation).
of::Maybe<void> RegisterJobPass(
const std::function<std::string(const std::string& job)>& pass_fn);
private:
of::Maybe<void> CollectInputOutputInfos();
of::Maybe<void> Compile(const std::vector<Tensor>& inputs);
of::Maybe<std::vector<Tensor>> Run(const std::vector<Tensor>& inputs) const;
of::Maybe<void> AddOp(of::OperatorConf op_conf);
of::Maybe<void> BuildGraph();
of::Maybe<void> LoadCheckpoint();
of::Maybe<void> RegisterTensors(const std::vector<Tensor>& inputs);
// Applies every registered pass, in registration order, to a copy of `job`.
of::Maybe<of::Job> ApplyJobPasses(const of::Job& job);
std::shared_ptr<of::NNGraph> graph_ = nullptr;
std::string model_path_;
bool is_compiled_ = false;
int batch_size_ = 0;
Device device_;
of::Job job_;
InputOutputInfos input_infos_;
InputOutputInfos output_infos_;
// Preallocated output/variable tensors keyed by op name, filled in BuildGraph.
of::HashMap<std::string, std::shared_ptr<of::one::Tensor>> output_name_to_tensor_;
of::HashMap<std::string, std::shared_ptr<of::one::Tensor>> variable_op_name_to_tensor_;
std::shared_ptr<of::one::TensorTuple> output_tensor_tuple_;
std::shared_ptr<of::one::TensorTuple> parameter_tensor_tuple_;
std::vector<std::function<std::string(const std::string&)>> registered_job_passes_;
};
// Public Graph is a thin pimpl wrapper; all work happens in GraphImpl.
Graph::Graph(const std::string& model_path, const Device& device)
: graph_(std::make_unique<GraphImpl>(model_path, device)) {}
Graph::~Graph() = default;
// Move-only: transfers ownership of the impl.
Graph::Graph(Graph&& graph) noexcept : graph_(std::move(graph.graph_)) {}
Graph& Graph::operator=(Graph&& graph) noexcept {
if (&graph == this) { return *this; }
graph_ = std::move(graph.graph_);
return *this;
}
InputOutputInfos Graph::GetInputInfos() { return graph_->GetInputInfos(); }
InputOutputInfos Graph::GetOutputInfos() { return graph_->GetOutputInfos(); }
// Aborts (CHECK) if registration fails, e.g. after the graph is compiled.
void Graph::RegisterJobPass(const std::function<std::string(const std::string& job)>& pass_fn) {
CHECK_JUST(graph_->RegisterJobPass(pass_fn));
}
// Accepts None / Tensor / vector<Tensor>, runs the graph, and packs the
// outputs symmetrically: empty -> None, one -> Tensor, many -> vector.
IValue Graph::Forward(const IValue& inputs) {
  std::vector<Tensor> input_tensors;
  if (inputs.IsTensor()) {
    input_tensors.emplace_back(inputs.ToTensor());
  } else if (inputs.IsTensorVector()) {
    input_tensors = inputs.ToTensorVector();
  } else if (!inputs.IsNone()) {
    LOG(WARNING) << "Graph currently only support types: Tensor/vector(Tensor)/None";
  }
  std::vector<Tensor> output_tensors = graph_->Forward(input_tensors);
  switch (output_tensors.size()) {
    case 0: return IValue{};
    case 1: return IValue(output_tensors.at(0));
    default: return IValue(output_tensors);
  }
}
// Forwarder; see GraphImpl::set_batch_size (only effective before compilation).
void Graph::set_batch_size(int batch_size) { graph_->set_batch_size(batch_size); }
// Convenience factory: constructs a Graph from a saved model directory.
Graph Graph::Load(const std::string& model_path, const Device& device) {
Graph graph(model_path, device);
return graph;
}
// Loads the MLIR job from <model_path>/model.mlir, records input/output
// metadata, forces predict (inference) mode, and appends a unique suffix to
// the job name so multiple Graph instances can coexist in one session.
Graph::GraphImpl::GraphImpl(const std::string& model_path, const Device& device)
    : model_path_(model_path), device_(device) {
  CHECK_JUST(of::LoadJobFromIR(&job_, model_path + "/model.mlir"));
  // Fix: the returned Maybe<void> was silently discarded, so metadata
  // collection errors were lost; surface them like the LoadJobFromIR call.
  CHECK_JUST(CollectInputOutputInfos());
  if (of::ParseBooleanFromEnv("ONEFLOW_SERVING_DEBUG", false)) { LOG(ERROR) << job_.DebugString(); }
  job_.mutable_job_conf()->mutable_predict_conf();
  job_.mutable_job_conf()->set_job_name(job_.mutable_job_conf()->job_name() + of::NewUniqueId());
}
// Return copies of the I/O metadata collected from the job at load time.
InputOutputInfos Graph::GraphImpl::GetInputInfos() { return input_infos_; }
InputOutputInfos Graph::GraphImpl::GetOutputInfos() { return output_infos_; }
// Walks the op graph in topological order and records, for every input/output
// interface op, its dtype, shape, and position (0-based, in topo order).
of::Maybe<void> Graph::GraphImpl::CollectInputOutputInfos() {
const of::OpGraph op_graph(job_);
size_t input_order = 0;
size_t output_order = 0;
// NOTE(review): the Maybe returned by each visit appears to be discarded by
// TopoForEachNode — errors raised inside the lambda may be lost; confirm.
op_graph.TopoForEachNode([&](const of::OpNode* node) -> of::Maybe<void> {
const of::OperatorConf& op_conf = node->op().op_conf();
if (op_conf.has_input_conf()) {
of::InterfaceBlobConf blob_conf = op_conf.input_conf().blob_conf();
// DType values are cast straight from the proto data_type enum.
input_infos_[op_conf.name()] =
InputOutputAttribute(static_cast<DType>(blob_conf.data_type()),
OfShapeToOfApiShape(of::Shape(blob_conf.shape())), input_order);
input_order += 1;
} else if (op_conf.has_output_conf()) {
of::InterfaceBlobConf blob_conf = op_conf.output_conf().blob_conf();
output_infos_[op_conf.name()] =
InputOutputAttribute(static_cast<DType>(blob_conf.data_type()),
OfShapeToOfApiShape(of::Shape(blob_conf.shape())), output_order);
output_order += 1;
}
return of::Maybe<void>::Ok();
});
return of::Maybe<void>::Ok();
}
// Queues a serialized-job transformation. Passes are applied during
// compilation, so registering after the first Forward() is an error.
of::Maybe<void> Graph::GraphImpl::RegisterJobPass(
    const std::function<std::string(const std::string& job)>& pass_fn) {
  if (!is_compiled_) {
    registered_job_passes_.push_back(pass_fn);
    return of::Maybe<void>::Ok();
  }
  return of::Error::RuntimeError() << "job pass should be registered before compile and forward";
}
// Runs each registered pass, in registration order, over the serialized job,
// re-parsing and swapping in the transformed proto after every pass.
of::Maybe<of::Job> Graph::GraphImpl::ApplyJobPasses(const of::Job& job) {
  auto current_job = std::make_shared<of::Job>(job);
  for (const auto& pass : registered_job_passes_) {
    const std::string serialized = pass(current_job->SerializeAsString());
    of::Job parsed;
    if (!parsed.ParseFromString(serialized)) {
      return of::Error::RuntimeError() << "invalid serialized job after pass applied";
    }
    current_job->Swap(&parsed);
  }
  return current_job;
}
// Lazily compiles the graph on first use, then runs it.
// Fix: is_compiled_ was only checked before acquiring the mutex, so two
// threads racing on the first call could both execute Compile(); re-check the
// flag under the lock (double-checked locking).
// NOTE(review): is_compiled_ is still a plain bool, so the unlocked first read
// is formally a data race — consider std::atomic<bool> if concurrent first
// calls are expected; confirm intended threading contract.
std::vector<Tensor> Graph::GraphImpl::Forward(const std::vector<Tensor>& inputs) {
  if (!is_compiled_) {
    static std::mutex mtx;
    std::lock_guard<std::mutex> lock(mtx);
    if (!is_compiled_) {
      Compile(inputs).GetOrThrow();
      is_compiled_ = true;
    }
  }
  return Run(inputs).GetOrThrow();
}
// Full compilation pipeline: build the lazy graph from the job, register the
// input/output/parameter tensors, then compile and initialize the runtime.
of::Maybe<void> Graph::GraphImpl::Compile(const std::vector<Tensor>& inputs) {
JUST(BuildGraph());
JUST(RegisterTensors(inputs));
JUST(graph_->CompileAndInitRuntime());
return of::Maybe<void>::Ok();
}
// Executes the compiled graph: packs inputs into a TensorTuple, launches the
// lazy NN graph with the preallocated output/parameter tuples, then syncs the
// output buffers before wrapping them in public Tensor handles.
of::Maybe<std::vector<Tensor>> Graph::GraphImpl::Run(const std::vector<Tensor>& inputs) const {
const auto input_tensor_tuple = std::make_shared<of::one::TensorTuple>();
for (const auto& tensor : inputs) { input_tensor_tuple->emplace_back(tensor.tensor_); }
JUST(of::RunLazyNNGraph(*input_tensor_tuple, *output_tensor_tuple_, *parameter_tensor_tuple_,
graph_));
// Ensure the outputs are ready to read before handing them back.
JUST(of::SoftSyncNNGraphBuffers(*output_tensor_tuple_, graph_));
std::vector<Tensor> outputs;
for (const auto& tensor : *output_tensor_tuple_) { outputs.emplace_back(Tensor(tensor)); }
return outputs;
}
// Adds one operator to the job being built: stamps it with the current scope
// and device tag, optionally overrides the input batch dimension, then feeds
// it to the infer context (note: "AddAndInferConsistentOp" is the upstream name).
of::Maybe<void> Graph::GraphImpl::AddOp(of::OperatorConf op_conf) {
{
const std::shared_ptr<of::Scope> scope = JUST(of::GetCurrentScope());
// Missing scope symbol ids fall back to 0.
op_conf.set_scope_symbol_id(scope->symbol_id().value_or(0));
}
op_conf.set_device_tag(GetDeviceTag(device_));
// Rewrite dim 0 of input blobs when a batch size override was requested.
if (batch_size_ > 0 && op_conf.has_input_conf()) {
op_conf.mutable_input_conf()->mutable_blob_conf()->mutable_shape()->mutable_dim()->Set(
0, batch_size_);
}
auto* ctx = JUST(of::GetCurInferCtx());
JUST(ctx->AddAndInferConsistentOp(op_conf));
return of::Maybe<void>::Ok();
}
// Builds the NNGraph from the serialized job: re-adds every op under the
// compile scope, allocates variable/output tensors, loads the checkpoint and
// applies user-registered job passes. Order matters: variables must exist
// before LoadCheckpoint(), and the job must be completed before the output
// pass below reads final blob descriptors.
of::Maybe<void> Graph::GraphImpl::BuildGraph() {
  CompileScope build_graph_scope(job_.job_conf(), *device_.device_->shared_from_symbol());
  {
    // First traversal: register each op with the infer context and allocate an
    // empty tensor for every variable op (filled from the checkpoint later).
    const of::OpGraph op_graph(job_);
    op_graph.TopoForEachNode([&](const of::OpNode* node) -> of::Maybe<void> {
      const of::OperatorConf& op_conf = node->op().op_conf();
      JUST(AddOp(op_conf));
      if (op_conf.has_variable_conf()) {
        // Eager allocation: temporarily disable lazy mode.
        const of::LazyMode::Guard lazy_mode_disabled_guard{false};
        const of::VariableOpConf& variable_conf = op_conf.variable_conf();
        variable_op_name_to_tensor_[op_conf.name()] = JUST(of::one::functional::Empty(
            of::Shape(variable_conf.shape()),
            JUST(of::DType::Get(static_cast<of::DataType>(variable_conf.data_type()))),
            *device_.device_, /*pin_memory=*/false));
      }
      return of::Maybe<void>::Ok();
    });
  }
  JUST(LoadCheckpoint());
  JUST(of::CurJobBuildAndInferCtx_Complete());
  std::shared_ptr<of::Job> complete_job = JUST(of::GetCurrentJob());
  int64_t job_id = JUST(of::JobBuildAndInferCtx_GetCurrentJobId());
  CHECK(of::Singleton<OneFlowEnv>::Get() != nullptr);
  // apply custom job passes
  complete_job = JUST(ApplyJobPasses(*complete_job));
  graph_ = std::make_shared<of::NNGraph>(job_.job_conf().job_name(), *complete_job, job_id,
                                         of::Singleton<OneFlowEnv>::Get()->GetSessionCtx());
  {
    // Second traversal over the completed job: allocate an output tensor for
    // every output op, honoring the batch-size override on dim 0.
    const of::OpGraph complete_graph(*complete_job);
    complete_graph.TopoForEachNode([&](const of::OpNode* node) -> of::Maybe<void> {
      const of::LazyMode::Guard lazy_mode_disabled_guard{false};
      const of::OperatorConf& op_conf = node->op().op_conf();
      if (op_conf.has_output_conf()) {
        of::InterfaceBlobConf blob_conf = op_conf.output_conf().blob_conf();
        if (batch_size_ > 0) {
          // The actual batch size is taken from the inferred producer blob,
          // not from the user request, since passes may have changed it.
          const std::string input_lbi_str = op_conf.output_conf().in();
          const of::LogicalBlobId input_lbi = of::GenLogicalBlobId(input_lbi_str);
          int64_t batch_size = node->LogicalBlobDesc4Lbi(input_lbi).shape().At(0);
          blob_conf.mutable_shape()->set_dim(0, batch_size);
        }
        output_name_to_tensor_[op_conf.name()] = JUST(of::one::functional::Empty(
            of::Shape(blob_conf.shape()),
            JUST(of::DType::Get(static_cast<of::DataType>(blob_conf.data_type()))),
            *device_.device_, /*pin_memory=*/false));
      }
      return of::Maybe<void>::Ok();
    });
  }
  return of::Maybe<void>::Ok();
}
// Reads each variable's checkpoint file (<model_path>/<op_name>/out) into its
// pre-allocated tensor, then hands the variables to the variable tensor mgr.
// BUGFIX: a missing/unreadable checkpoint file previously hit a glog CHECK and
// aborted the whole process; it now propagates an error through the Maybe
// return, and the file size is validated against the tensor's byte size so a
// truncated checkpoint cannot cause an out-of-bounds read in the blob copy.
of::Maybe<void> Graph::GraphImpl::LoadCheckpoint() {
  for (const auto& variable_op_name_and_tensor : variable_op_name_to_tensor_) {
    const auto& variable_op_name = variable_op_name_and_tensor.first;
    const auto& variable_tensor = variable_op_name_and_tensor.second;
    const std::string variable_filename = model_path_ + "/" + variable_op_name + "/out";
    std::ifstream variable_file(variable_filename, std::ios::binary);
    CHECK_OR_RETURN(variable_file.is_open())
        << "cannot open checkpoint file: " << variable_filename;
    std::stringstream ss;
    ss << variable_file.rdbuf();
    const std::string buffer = ss.str();
    // Expected payload size in bytes for this variable tensor.
    const size_t byte_size = variable_tensor->shape()->elem_cnt()
                             * of::GetSizeOfDataType(variable_tensor->dtype()->data_type());
    CHECK_EQ_OR_RETURN(buffer.size(), byte_size)
        << "checkpoint file size mismatch for " << variable_filename;
    const auto& callback = [&](uint64_t of_blob_ptr) {
      CHECK_JUST(of::BlobBufferCopyUtil<void>::From(of_blob_ptr, buffer.data(), byte_size));
    };
    JUST(of::one::SyncAccessTensorWithTimeOut(variable_tensor, callback, "mut"));
  }
  const auto& pair = Unzip(variable_op_name_to_tensor_);
  JUST(of::FillVariableTensorMgr(pair.first, pair.second));
  return of::Maybe<void>::Ok();
}
// Binds input, output and variable tensors to the NNGraph and caches the
// output/parameter tuples used later by Run().
of::Maybe<void> Graph::GraphImpl::RegisterTensors(const std::vector<Tensor>& inputs) {
  // Inputs: order user tensors by the recorded input index.
  {
    std::vector<std::string> in_names(inputs.size());
    std::vector<std::shared_ptr<of::one::Tensor>> in_tensors(inputs.size());
    for (const auto& name_and_attr : input_infos_) {
      const size_t idx = name_and_attr.second.input_output_index_;
      in_names[idx] = name_and_attr.first;
      in_tensors[idx] = inputs.at(idx).tensor_;
    }
    JUST(graph_->RegisterInputOpNamesAndTensors(in_names, in_tensors));
  }
  // Outputs: register the tensors allocated in BuildGraph().
  {
    const auto& names_and_tensors = Unzip(output_name_to_tensor_);
    JUST(graph_->RegisterOutputOpNamesAndTensors(names_and_tensors.first,
                                                 names_and_tensors.second));
    output_tensor_tuple_ = ConvertToTensorTuple(names_and_tensors.second);
  }
  // Variables: take ownership of whatever the variable tensor mgr holds.
  {
    const auto& dumped = of::DumpVariableTensorMgr();
    JUST(graph_->RegisterVariableOpNamesAndTensors(std::get<0>(dumped), std::get<1>(dumped)));
    parameter_tensor_tuple_ = ConvertToTensorTuple(std::get<1>(dumped));
  }
  return of::Maybe<void>::Ok();
}
// Block until all in-flight lazy-runtime work finishes before the graph and
// its tensors are torn down.
Graph::GraphImpl::~GraphImpl() { of::vm::ClusterSync().GetOrThrow(); }
} // namespace oneflow_api
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_API_CPP_GRAPH_H_
#define ONEFLOW_API_CPP_GRAPH_H_
#include "dtype.h"
#include "shape.h"
#include "device.h"
#include "ivalue.h"
#include "tensor.h"
#include <cstddef>
#include <string>
#include <functional>
#include <unordered_map>
namespace oneflow {
class NNGraph;
} // namespace oneflow
namespace oneflow_api {
// Metadata describing one graph input or output: its dtype, shape and the
// positional index used when ordering tensors for Forward().
struct InputOutputAttribute {
  InputOutputAttribute(DType datatype, const Shape& input_output_shape, size_t input_output_index)
      : datatype_(datatype),
        input_output_shape_(input_output_shape),
        input_output_index_(input_output_index) {}
  // Default: invalid dtype, empty shape, index 0.
  InputOutputAttribute() : InputOutputAttribute(DType::kInvalidDataType, Shape(), 0) {}
  DType datatype_;              // element type of the blob
  Shape input_output_shape_;    // logical shape (dim 0 may be the batch size)
  size_t input_output_index_;   // position among the graph's inputs/outputs
};
using InputOutputInfos = std::unordered_map<std::string, InputOutputAttribute>;
// Public C++ API handle for a compiled OneFlow inference graph loaded from a
// saved model directory. Move-only (PImpl via GraphImpl).
class Graph {
 public:
  // Loads the model at `model_path` targeting `device` (defaults to CPU).
  explicit Graph(const std::string& model_path, const Device& device = Device("cpu"));
  ~Graph();
  Graph(const Graph& graph) = delete;
  Graph(Graph&& graph) noexcept;
  Graph& operator=(const Graph& graph) = delete;
  Graph& operator=(Graph&& graph) noexcept;
  // Name -> dtype/shape/index metadata for the graph's inputs and outputs.
  InputOutputInfos GetInputInfos();
  InputOutputInfos GetOutputInfos();
  // Runs the graph; `inputs` is an IValue holding a tensor or tensor vector.
  IValue Forward(const IValue& inputs);
  // Overrides the leading (batch) dimension; must be set before first Forward.
  void set_batch_size(int batch_size);
  // Registers a pass applied to the serialized Job during compilation; the
  // callback receives and returns a serialized Job proto.
  void RegisterJobPass(const std::function<std::string(const std::string& job)>& pass_fn);
  // Convenience factory, equivalent to the constructor.
  static Graph Load(const std::string& model_path, const Device& device = Device("cpu"));

 private:
  class GraphImpl;
  std::unique_ptr<GraphImpl> graph_;
};
} // namespace oneflow_api
#endif // ONEFLOW_API_CPP_GRAPH_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/api/cpp/framework/ivalue.h"
#include <glog/logging.h>
namespace oneflow_api {
namespace of = oneflow;
// Prints a tag as its underlying integer value (for CHECK/log messages).
std::ostream& operator<<(std::ostream& os, const IValue::Tag& tag) {
  return os << static_cast<int>(tag);
}
// Typed accessors. Each CHECK-fails (aborts with a diagnostic) when the stored
// tag does not match the requested type; callers should query Is*() first.
int64_t IValue::ToInt() const {
  CHECK_EQ(tag_, Tag::kInt) << "Current value is not an int.";
  return payload_.i.v_int;
}
double IValue::ToDouble() const {
  CHECK_EQ(tag_, Tag::kDouble) << "Current value is not a double.";
  return payload_.i.v_double;
}
bool IValue::ToBool() const {
  CHECK_EQ(tag_, Tag::kBool) << "Current value is not a bool.";
  return payload_.i.v_bool;
}
// Returns a reference into the union payload; valid only while this IValue
// is alive and still holds a tensor.
const Tensor& IValue::ToTensor() const {
  CHECK_EQ(tag_, Tag::kTensor) << "Current value is not a tensor.";
  return payload_.v_tensor;
}
const std::vector<Tensor>& IValue::ToTensorVector() const {
  CHECK_EQ(tag_, Tag::kTensorVector) << "Current value is not a vector of tensor.";
  return payload_.v_tensor_vector;
}
} // namespace oneflow_api
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_API_CPP_FRAMEWORK_IVALUE_H_
#define ONEFLOW_API_CPP_FRAMEWORK_IVALUE_H_
#include <cstdint>
#include <memory>
#include <vector>
#include "tensor.h"
namespace oneflow_api {
// A tagged union holding none, int64, double, bool, a Tensor, or a vector of
// Tensor — the heterogeneous value type passed through Graph::Forward().
// Non-trivial payloads (Tensor, vector<Tensor>) are managed manually with
// placement new / explicit destructor calls, keyed by tag_.
class IValue {
 public:
  IValue() : tag_(IValue::Tag::kNone) {}
  explicit IValue(int value) : tag_(IValue::Tag::kInt) { payload_.i.v_int = value; }
  explicit IValue(int64_t value) : tag_(IValue::Tag::kInt) { payload_.i.v_int = value; }
  explicit IValue(double value) : tag_(IValue::Tag::kDouble) { payload_.i.v_double = value; }
  explicit IValue(bool value) : tag_(IValue::Tag::kBool) { payload_.i.v_bool = value; }
  IValue(const Tensor& value) : tag_(IValue::Tag::kTensor) {  // NOLINT
    new (&payload_.v_tensor) Tensor(value);
  }
  IValue(Tensor&& value) : tag_(IValue::Tag::kTensor) {  // NOLINT
    new (&payload_.v_tensor) Tensor(std::move(value));
  }
  IValue(const std::vector<Tensor>& value) : tag_(IValue::Tag::kTensorVector) {  // NOLINT
    new (&payload_.v_tensor_vector) std::vector<Tensor>(value);
  }
  IValue(std::vector<Tensor>&& value) : tag_(IValue::Tag::kTensorVector) {  // NOLINT
    new (&payload_.v_tensor_vector) std::vector<Tensor>(std::move(value));
  }
  IValue(const IValue& value) : tag_(value.tag_) {
    if (IsTensor()) {
      new (&payload_.v_tensor) Tensor(value.payload_.v_tensor);
    } else if (IsTensorVector()) {
      new (&payload_.v_tensor_vector) std::vector<Tensor>(value.payload_.v_tensor_vector);
    } else {
      payload_.i = value.payload_.i;
    }
  }
  IValue(IValue&& value) noexcept : tag_(value.tag_) { MoveFrom(std::move(value)); }
  IValue& operator=(const IValue& value) {
    if (&value == this) { return *this; }
    // Copy-and-move: build a temporary copy, then reuse move-assignment.
    // BUGFIX: tag_ must NOT be overwritten before delegating — move-assignment
    // destroys the current payload under the *current* tag. The old code set
    // tag_ first, which leaked a held Tensor/vector (or ran ~Tensor on a
    // scalar payload) whenever the two sides held different types.
    *this = IValue(value);
    return *this;
  }
  IValue& operator=(IValue&& value) noexcept {
    if (&value == this) { return *this; }
    Destroy();  // release the old payload under the old tag
    this->tag_ = value.tag_;
    MoveFrom(std::move(value));
    return *this;
  }
  ~IValue() { Destroy(); }

  bool IsNone() const { return tag_ == Tag::kNone; }
  bool IsInt() const { return tag_ == Tag::kInt; }
  bool IsDouble() const { return tag_ == Tag::kDouble; }
  bool IsBool() const { return tag_ == Tag::kBool; }
  bool IsTensor() const { return tag_ == Tag::kTensor; }
  bool IsTensorVector() const { return tag_ == Tag::kTensorVector; }

  // Typed accessors; CHECK-fail on tag mismatch (defined in ivalue.cpp).
  int64_t ToInt() const;
  double ToDouble() const;
  bool ToBool() const;
  const Tensor& ToTensor() const;
  const std::vector<Tensor>& ToTensorVector() const;

 private:
  enum class Tag { kNone = 0, kInt = 1, kDouble = 2, kBool = 3, kTensor = 4, kTensorVector = 5 };
  friend std::ostream& operator<<(std::ostream&, const Tag&);

  union Payload {  // NOLINT
    union InternalPayload {
      InternalPayload() : v_int(0) {}
      int64_t v_int;
      double v_double;
      bool v_bool;
    } i;
    Tensor v_tensor;
    std::vector<Tensor> v_tensor_vector;
    // Members are non-trivial; IValue constructs/destroys the active one
    // manually, so the union itself does nothing on construction/destruction.
    Payload() : i() {}
    ~Payload() {}
  };

  Payload payload_;
  Tag tag_;

  // Destroys the active payload according to tag_. (Renamed from the original
  // misspelling "Destory"; private, so no callers outside this class.)
  inline void Destroy() {
    if (IsTensor()) { payload_.v_tensor.~Tensor(); }
    if (IsTensorVector()) { payload_.v_tensor_vector.~vector(); }
  }

  // Moves `value`'s payload into *this and resets `value` to kNone.
  // Precondition: tag_ has already been set to value.tag_.
  inline void MoveFrom(IValue&& value) {
    if (IsTensor()) {
      new (&payload_.v_tensor) Tensor(std::move(value.payload_.v_tensor));
    } else if (IsTensorVector()) {
      new (&payload_.v_tensor_vector)
          std::vector<Tensor>(std::move(value.payload_.v_tensor_vector));
    } else {
      payload_.i = value.payload_.i;
    }
    value.ClearToNone();
  }

  // Destroys the payload and returns to the empty (kNone) state.
  inline void ClearToNone() {
    Destroy();
    payload_.i.v_int = 0;
    tag_ = Tag::kNone;
  }
};
} // namespace oneflow_api
#endif // ONEFLOW_API_CPP_FRAMEWORK_IVALUE_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/api/cpp/framework/shape.h"
#include "oneflow/core/common/shape.h"
#include "oneflow/core/common/shape_vec.h"
namespace oneflow_api {
namespace of = oneflow;
namespace {
// Converts the public-API dimension vector into oneflow's internal DimVector.
// NOTE(review): "Vcetor" is a typo for "Vector"; kept as-is because every call
// site in this file uses the misspelled name.
of::DimVector ToOneflowDimVcetor(const std::vector<int64_t>& dim_vec) {
  return of::DimVector(dim_vec.begin(), dim_vec.end());
}
}  // namespace
// Default shape: one dimension of extent 0 (empty tensor).
Shape::Shape() : shape_(std::make_shared<of::Shape>(of::Shape({0}))) {}
// Builds a shape from an explicit list of dimension extents.
Shape::Shape(const std::vector<int64_t>& dim_vec)
    : shape_(std::make_shared<of::Shape>(ToOneflowDimVcetor(dim_vec))) {}
Shape::Shape(const std::initializer_list<int64_t>& dim_vec)
    : shape_(std::make_shared<of::Shape>(dim_vec)) {}
// Copy assignment: shares the underlying oneflow::Shape with `shape`
// (shallow copy, matching the shared_ptr member).
// BUGFIX: the previous implementation called shape_.reset() before reading
// shape.shape_, so self-assignment (s = s) nulled the pointer and any later
// access dereferenced nullptr. shared_ptr copy-assignment alone is
// self-assignment-safe and releases the old object automatically.
Shape& Shape::operator=(const Shape& shape) {
  this->shape_ = shape.shape_;
  return *this;
}
// Equality compares the pointed-to shapes by value, not by pointer identity.
bool Shape::operator==(const Shape& rhs) const { return *shape_ == *rhs.shape_; }
bool Shape::operator!=(const Shape& rhs) const { return !(*this == rhs); }
// Total number of elements (product of all dimension extents).
int64_t Shape::elem_cnt() const { return shape_->elem_cnt(); }
int64_t Shape::At(int64_t index) const { return shape_->At(index); }
// Mutates the shared underlying shape; visible to all Shape copies sharing it.
void Shape::Set(int64_t index, int64_t val) { shape_->Set(index, val); }
int64_t Shape::NumAxes() const { return shape_->NumAxes(); }
// Product of extents over [begin_axis, end_axis).
int64_t Shape::Count(int64_t begin_axis, int64_t end_axis) const {
  return shape_->Count(begin_axis, end_axis);
}
// Product of extents from begin_axis through the last axis.
int64_t Shape::Count(int64_t begin_axis) const { return shape_->Count(begin_axis); }
// Streams the shape's human-readable debug representation.
std::ostream& operator<<(std::ostream& os, const Shape& shape) {
  return os << shape.shape_->DebugStr();
}
} // namespace oneflow_api
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment