Commit 21d47d0e authored by yuguo's avatar yuguo
Browse files

Oneflow 0.8 for DCU

parents
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_READ_HELPER_H_
#define ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_READ_HELPER_H_
#include "oneflow/core/comm_network/epoll/socket_message.h"
#ifdef OF_PLATFORM_POSIX
namespace oneflow {
class SocketReadHelper final {
public:
OF_DISALLOW_COPY_AND_MOVE(SocketReadHelper);
SocketReadHelper() = delete;
~SocketReadHelper();
SocketReadHelper(int sockfd);
void NotifyMeSocketReadable();
private:
void SwitchToMsgHeadReadHandle();
void ReadUntilSocketNotReadable();
bool MsgHeadReadHandle();
bool MsgBodyReadHandle();
bool DoCurRead(void (SocketReadHelper::*set_cur_read_done)());
void SetStatusWhenMsgHeadDone();
void SetStatusWhenMsgBodyDone();
#define MAKE_ENTRY(x, y) void SetStatusWhen##x##MsgHeadDone();
OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, SOCKET_MSG_TYPE_SEQ);
#undef MAKE_ENTRY
int sockfd_;
SocketMsg cur_msg_;
bool (SocketReadHelper::*cur_read_handle_)();
char* read_ptr_;
size_t read_size_;
};
} // namespace oneflow
#endif // OF_PLATFORM_POSIX
#endif // ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_READ_HELPER_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifdef __linux__
#include "oneflow/core/comm_network/epoll/socket_write_helper.h"
#include "oneflow/core/comm_network/epoll/socket_memory_desc.h"
#include <sys/eventfd.h>
namespace oneflow {
SocketWriteHelper::~SocketWriteHelper() {
delete cur_msg_queue_;
cur_msg_queue_ = nullptr;
{
std::unique_lock<std::mutex> lck(pending_msg_queue_mtx_);
delete pending_msg_queue_;
pending_msg_queue_ = nullptr;
}
}
SocketWriteHelper::SocketWriteHelper(int sockfd, IOEventPoller* poller) {
sockfd_ = sockfd;
queue_not_empty_fd_ = eventfd(0, 0);
PCHECK(queue_not_empty_fd_ != -1);
poller->AddFdWithOnlyReadHandler(queue_not_empty_fd_,
std::bind(&SocketWriteHelper::ProcessQueueNotEmptyEvent, this));
cur_msg_queue_ = new std::queue<SocketMsg>;
pending_msg_queue_ = new std::queue<SocketMsg>;
cur_write_handle_ = &SocketWriteHelper::InitMsgWriteHandle;
write_ptr_ = nullptr;
write_size_ = 0;
}
void SocketWriteHelper::AsyncWrite(const SocketMsg& msg) {
pending_msg_queue_mtx_.lock();
bool need_send_event = pending_msg_queue_->empty();
pending_msg_queue_->push(msg);
pending_msg_queue_mtx_.unlock();
if (need_send_event) { SendQueueNotEmptyEvent(); }
}
void SocketWriteHelper::NotifyMeSocketWriteable() { WriteUntilMsgQueueEmptyOrSocketNotWriteable(); }
void SocketWriteHelper::SendQueueNotEmptyEvent() {
uint64_t event_num = 1;
PCHECK(write(queue_not_empty_fd_, &event_num, 8) == 8);
}
void SocketWriteHelper::ProcessQueueNotEmptyEvent() {
uint64_t event_num = 0;
PCHECK(read(queue_not_empty_fd_, &event_num, 8) == 8);
WriteUntilMsgQueueEmptyOrSocketNotWriteable();
}
void SocketWriteHelper::WriteUntilMsgQueueEmptyOrSocketNotWriteable() {
while ((this->*cur_write_handle_)()) {}
}
bool SocketWriteHelper::InitMsgWriteHandle() {
if (cur_msg_queue_->empty()) {
{
std::unique_lock<std::mutex> lck(pending_msg_queue_mtx_);
std::swap(cur_msg_queue_, pending_msg_queue_);
}
if (cur_msg_queue_->empty()) { return false; }
}
cur_msg_ = cur_msg_queue_->front();
cur_msg_queue_->pop();
write_ptr_ = reinterpret_cast<const char*>(&cur_msg_);
write_size_ = sizeof(cur_msg_);
cur_write_handle_ = &SocketWriteHelper::MsgHeadWriteHandle;
return true;
}
bool SocketWriteHelper::MsgHeadWriteHandle() {
return DoCurWrite(&SocketWriteHelper::SetStatusWhenMsgHeadDone);
}
bool SocketWriteHelper::MsgBodyWriteHandle() {
return DoCurWrite(&SocketWriteHelper::SetStatusWhenMsgBodyDone);
}
bool SocketWriteHelper::DoCurWrite(void (SocketWriteHelper::*set_cur_write_done)()) {
ssize_t n = write(sockfd_, write_ptr_, write_size_);
if (n == write_size_) {
(this->*set_cur_write_done)();
return true;
} else if (n >= 0) {
write_ptr_ += n;
write_size_ -= n;
return true;
} else {
CHECK_EQ(n, -1);
PCHECK(errno == EAGAIN || errno == EWOULDBLOCK);
return false;
}
}
void SocketWriteHelper::SetStatusWhenMsgHeadDone() {
switch (cur_msg_.msg_type) {
#define MAKE_ENTRY(x, y) \
case SocketMsgType::k##x: return SetStatusWhen##x##MsgHeadDone();
OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, SOCKET_MSG_TYPE_SEQ);
#undef MAKE_ENTRY
default: UNIMPLEMENTED();
}
}
void SocketWriteHelper::SetStatusWhenMsgBodyDone() {
cur_write_handle_ = &SocketWriteHelper::InitMsgWriteHandle;
}
void SocketWriteHelper::SetStatusWhenRequestWriteMsgHeadDone() {
cur_write_handle_ = &SocketWriteHelper::InitMsgWriteHandle;
}
void SocketWriteHelper::SetStatusWhenRequestReadMsgHeadDone() {
const void* src_token = cur_msg_.request_read_msg.src_token;
auto src_mem_desc = static_cast<const SocketMemDesc*>(src_token);
write_ptr_ = reinterpret_cast<const char*>(src_mem_desc->mem_ptr);
write_size_ = src_mem_desc->byte_size;
cur_write_handle_ = &SocketWriteHelper::MsgBodyWriteHandle;
}
void SocketWriteHelper::SetStatusWhenActorMsgHeadDone() {
cur_write_handle_ = &SocketWriteHelper::InitMsgWriteHandle;
}
void SocketWriteHelper::SetStatusWhenTransportMsgHeadDone() {
cur_write_handle_ = &SocketWriteHelper::InitMsgWriteHandle;
}
} // namespace oneflow
#endif // __linux__
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_WRITE_HELPER_H_
#define ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_WRITE_HELPER_H_
#include "oneflow/core/comm_network/epoll/io_event_poller.h"
#include "oneflow/core/comm_network/epoll/socket_message.h"
#ifdef OF_PLATFORM_POSIX
namespace oneflow {
class SocketWriteHelper final {
public:
OF_DISALLOW_COPY_AND_MOVE(SocketWriteHelper);
SocketWriteHelper() = delete;
~SocketWriteHelper();
SocketWriteHelper(int sockfd, IOEventPoller* poller);
void AsyncWrite(const SocketMsg& msg);
void NotifyMeSocketWriteable();
private:
void SendQueueNotEmptyEvent();
void ProcessQueueNotEmptyEvent();
void WriteUntilMsgQueueEmptyOrSocketNotWriteable();
bool InitMsgWriteHandle();
bool MsgHeadWriteHandle();
bool MsgBodyWriteHandle();
bool DoCurWrite(void (SocketWriteHelper::*set_cur_write_done)());
void SetStatusWhenMsgHeadDone();
void SetStatusWhenMsgBodyDone();
#define MAKE_ENTRY(x, y) void SetStatusWhen##x##MsgHeadDone();
OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, SOCKET_MSG_TYPE_SEQ);
#undef MAKE_ENTRY
int sockfd_;
int queue_not_empty_fd_;
std::queue<SocketMsg>* cur_msg_queue_;
std::mutex pending_msg_queue_mtx_;
std::queue<SocketMsg>* pending_msg_queue_;
SocketMsg cur_msg_;
bool (SocketWriteHelper::*cur_write_handle_)();
const char* write_ptr_;
size_t write_size_;
};
} // namespace oneflow
#endif // OF_PLATFORM_POSIX
#endif // ONEFLOW_CORE_COMM_NETWORK_EPOLL_SOCKET_WRITE_HELPER_H_
syntax = "proto2";
package oneflow;
message IBVerbsConnectionInfo {
required uint32 lid = 1;
required uint32 qp_num = 2;
required uint64 subnet_prefix = 3;
required uint64 interface_id = 4;
required uint32 port_num = 5;
required int32 mtu = 6;
}
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/comm_network/ibverbs/ibverbs_comm_network.h"
#include "oneflow/core/control/ctrl_client.h"
#include "oneflow/core/control/global_process_ctx.h"
#include "oneflow/core/job/resource_desc.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/platform/include/ibv.h"
#include "oneflow/core/lazy/actor/actor_message_bus.h"
#if defined(WITH_RDMA) && defined(OF_PLATFORM_POSIX)
namespace oneflow {
namespace {
std::string GenTokensMsgKey(int64_t machine_id) {
return "IBVerbsTokensMsg/" + std::to_string(machine_id);
}
std::string GenConnInfoKey(int64_t src_machine_id, int64_t dst_machine_id) {
return "IBVerbsConnInfo/" + std::to_string(src_machine_id) + "/" + std::to_string(dst_machine_id);
}
void IBVForkInit() {
if (ibv::IsAvailable()) {
if (ibv::wrapper.ibv_fork_init() != 0) { std::cerr << "ibv_fork_init failed\n"; }
} else {
std::cerr << "libibverbs not available, ibv_fork_init skipped\n";
}
}
void ParseUserDevicePort(std::string* device_name, int* port) {
std::string user_device_port = GetStringFromEnv("ONEFLOW_COMM_NET_IB_HCA", "");
if (user_device_port.empty()) {
*device_name = "";
*port = 0;
return;
} else {
const std::string::size_type pos = user_device_port.find(':', 0);
if (pos == std::string::npos) {
*device_name = user_device_port;
*port = 0;
return;
} else {
*device_name = user_device_port.substr(0, pos);
*port = std::strtol(user_device_port.data() + pos + 1, nullptr, 10);
return;
}
}
}
} // namespace
IBVerbsCommNet::~IBVerbsCommNet() {
while (poll_exit_flag_.test_and_set() == true) {}
poll_thread_.join();
for (IBVerbsQP* qp : qp_vec_) {
if (qp) { delete qp; }
}
CHECK_EQ(ibv::wrapper.ibv_destroy_cq(cq_), 0);
CHECK_EQ(ibv::wrapper.ibv_dealloc_pd(pd_), 0);
CHECK_EQ(ibv::wrapper.ibv_close_device(context_), 0);
}
void IBVerbsCommNet::SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) {
ActorMsg new_msg = msg;
if (msg.IsDataRegstMsgToConsumer()) {
CHECK_EQ(msg.user_data_size(), 0);
auto* mem_desc = reinterpret_cast<IBVerbsMemDesc*>(msg.regst()->comm_net_token());
CHECK(mem_desc != nullptr);
IBVerbsCommNetRMADesc rma_desc{};
rma_desc.mem_ptr = reinterpret_cast<uint64_t>(mem_desc->mem_ptr());
rma_desc.mem_size = mem_desc->mem_size();
rma_desc.mr_rkey = mem_desc->mr()->rkey;
static_assert(sizeof(IBVerbsCommNetRMADesc) <= kActorMsgUserDataMaxSize, "");
new_msg.AddUserData(sizeof(IBVerbsCommNetRMADesc), &rma_desc);
}
qp_vec_.at(dst_machine_id)->PostSendRequest(new_msg);
}
void IBVerbsCommNet::RecvActorMsg(const ActorMsg& msg) {
ActorMsg new_msg = msg;
if (msg.IsDataRegstMsgToConsumer()) {
std::lock_guard<std::mutex> lock(remote_regst2rma_desc_mutex_);
auto& desc = remote_regst2rma_desc_[std::make_pair(msg.src_actor_id(),
reinterpret_cast<uint64_t>(msg.regst()))];
if (!desc) { desc.reset(new IBVerbsCommNetRMADesc); }
CHECK_EQ(msg.user_data_size(), sizeof(IBVerbsCommNetRMADesc));
std::memcpy(desc.get(), msg.user_data(), sizeof(IBVerbsCommNetRMADesc));
new_msg.set_comm_net_token(desc.get());
}
Singleton<ActorMsgBus>::Get()->SendMsgWithoutCommNet(new_msg);
}
IBVerbsCommNet::IBVerbsCommNet() : CommNetIf(), poll_exit_flag_(ATOMIC_FLAG_INIT) {
int num_device;
ibv_device** device_list = ibv::wrapper.ibv_get_device_list(&num_device);
CHECK_GT(num_device, 0) << "No IB device found";
PCHECK(device_list);
std::string user_device;
int user_port;
ParseUserDevicePort(&user_device, &user_port);
ibv_device* device = nullptr;
if (user_device.empty()) {
device = device_list[0];
} else {
for (int i = 0; i < num_device; ++i) {
if (device_list[i]->name == user_device) {
device = device_list[i];
break;
}
}
CHECK(device != nullptr) << "No IB device match " << user_device;
}
context_ = ibv::wrapper.ibv_open_device(device);
CHECK(context_);
ibv::wrapper.ibv_free_device_list(device_list);
pd_ = ibv::wrapper.ibv_alloc_pd(context_);
CHECK(pd_);
ibv_device_attr device_attr{};
CHECK_EQ(ibv::wrapper.ibv_query_device(context_, &device_attr), 0);
cq_ = ibv::wrapper.ibv_create_cq(context_, device_attr.max_cqe, nullptr, nullptr, 0);
CHECK(cq_);
ibv_port_attr port_attr{};
const uint8_t port = user_port == 0 ? 1 : user_port;
CHECK_EQ(ibv::wrapper.ibv_query_port_wrap(context_, port, &port_attr), 0);
ibv_gid gid{};
const int64_t gid_index = ParseIntegerFromEnv("ONEFLOW_COMM_NET_IB_GID_INDEX", 0);
CHECK_EQ(ibv::wrapper.ibv_query_gid(context_, port, gid_index, &gid), 0);
VLOG(1) << "Using IB device " << device->name << " port " << static_cast<int32_t>(port)
<< " gid index " << gid_index;
int64_t this_machine_id = GlobalProcessCtx::Rank();
qp_vec_.assign(Singleton<ResourceDesc, ForEnv>::Get()->process_ranks().size(), nullptr);
for (int64_t peer_id : peer_machine_id()) {
IBVerbsQP* cur_qp = new IBVerbsQP(context_, pd_, port_attr, port, cq_, cq_);
qp_vec_.at(peer_id) = cur_qp;
IBVerbsConnectionInfo conn_info;
conn_info.set_lid(port_attr.lid);
conn_info.set_qp_num(cur_qp->qp_num());
conn_info.set_subnet_prefix(gid.global.subnet_prefix);
conn_info.set_interface_id(gid.global.interface_id);
conn_info.set_port_num(port);
conn_info.set_mtu(static_cast<int>(port_attr.active_mtu));
Singleton<CtrlClient>::Get()->PushKV(GenConnInfoKey(this_machine_id, peer_id), conn_info);
}
for (int64_t peer_id : peer_machine_id()) {
IBVerbsConnectionInfo conn_info;
Singleton<CtrlClient>::Get()->PullKV(GenConnInfoKey(peer_id, this_machine_id), &conn_info);
if (conn_info.lid() == 0) {
VLOG(2) << "Connecting to peer " << peer_id << " port " << conn_info.port_num() << " qpn "
<< conn_info.qp_num() << " gid index " << gid_index << " spn "
<< conn_info.subnet_prefix() << " iid " << conn_info.interface_id() << " mtu "
<< conn_info.mtu();
} else {
VLOG(2) << "Connecting to peer " << peer_id << " port " << conn_info.port_num() << " qpn "
<< conn_info.qp_num() << " lid " << conn_info.interface_id() << " mtu "
<< conn_info.mtu();
}
qp_vec_.at(peer_id)->Connect(conn_info);
VLOG(1) << "Connected to peer " << peer_id;
}
OF_ENV_BARRIER();
for (int64_t peer_id : peer_machine_id()) {
qp_vec_.at(peer_id)->PostAllRecvRequest();
Singleton<CtrlClient>::Get()->ClearKV(GenConnInfoKey(this_machine_id, peer_id));
}
OF_ENV_BARRIER();
poll_thread_ = std::thread(&IBVerbsCommNet::PollCQ, this);
OF_ENV_BARRIER();
}
void IBVerbsCommNet::DoRead(void* read_id, int64_t src_machine_id, void* src_token,
void* dst_token) {
qp_vec_.at(src_machine_id)
->PostReadRequest(*reinterpret_cast<IBVerbsCommNetRMADesc*>(src_token),
*static_cast<const IBVerbsMemDesc*>(dst_token), read_id);
}
void IBVerbsCommNet::PollCQ() {
std::vector<ibv_wc> wc_vec(max_poll_wc_num_);
while (poll_exit_flag_.test_and_set() == false) {
poll_exit_flag_.clear();
int32_t found_wc_num = ibv_poll_cq(cq_, max_poll_wc_num_, wc_vec.data());
CHECK_GE(found_wc_num, 0);
FOR_RANGE(int32_t, i, 0, found_wc_num) {
const ibv_wc& wc = wc_vec.at(i);
CHECK_EQ(wc.status, IBV_WC_SUCCESS) << wc.opcode;
WorkRequestId* wr_id = reinterpret_cast<WorkRequestId*>(wc.wr_id);
IBVerbsQP* qp = wr_id->qp;
switch (wc.opcode) {
case IBV_WC_RDMA_READ: {
qp->ReadDone(wr_id);
break;
}
case IBV_WC_SEND: {
qp->SendDone(wr_id);
break;
}
case IBV_WC_RECV: {
qp->RecvDone(wr_id);
break;
}
default: UNIMPLEMENTED();
}
}
}
}
const int32_t IBVerbsCommNet::max_poll_wc_num_ = 32;
COMMAND(IBVForkInit());
} // namespace oneflow
#endif // WITH_RDMA && OF_PLATFORM_POSIX
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMM_NETWORK_IBVERBS_IBVERBS_COMM_NETWORK_H_
#define ONEFLOW_CORE_COMM_NETWORK_IBVERBS_IBVERBS_COMM_NETWORK_H_
#include "oneflow/core/common/platform.h"
#include "oneflow/core/comm_network/comm_network.h"
#include "oneflow/core/comm_network/ibverbs/ibverbs_memory_desc.h"
#include "oneflow/core/comm_network/ibverbs/ibverbs_qp.h"
#if defined(WITH_RDMA) && defined(OF_PLATFORM_POSIX)
#include <netdb.h>
#include <arpa/inet.h>
namespace oneflow {
struct IBVerbsCommNetRMADesc {
uint64_t mem_ptr;
uint64_t mem_size;
uint32_t mr_rkey;
};
class IBVerbsCommNet final : public CommNetIf<IBVerbsMemDesc> {
public:
OF_DISALLOW_COPY_AND_MOVE(IBVerbsCommNet);
~IBVerbsCommNet();
void SendActorMsg(int64_t dst_machine_id, const ActorMsg& msg) override;
void RecvActorMsg(const ActorMsg& msg);
private:
friend class Singleton<IBVerbsCommNet>;
IBVerbsCommNet();
IBVerbsMemDesc* NewMemDesc(void* ptr, size_t byte_size) override {
return new IBVerbsMemDesc(pd_, ptr, byte_size);
}
void DoRead(void* read_id, int64_t src_machine_id, void* src_token, void* dst_token) override;
void PollCQ();
static const int32_t max_poll_wc_num_;
ibv_context* context_;
ibv_pd* pd_;
ibv_cq* cq_;
std::vector<IBVerbsQP*> qp_vec_;
std::atomic_flag poll_exit_flag_;
std::thread poll_thread_;
HashMap<std::pair<int64_t, uint64_t>, std::shared_ptr<IBVerbsCommNetRMADesc>>
remote_regst2rma_desc_;
std::mutex remote_regst2rma_desc_mutex_;
};
} // namespace oneflow
#endif // WITH_RDMA && OF_PLATFORM_POSIX
#endif // ONEFLOW_CORE_COMM_NETWORK_IBVERBS_IBVERBS_COMM_NETWORK_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/comm_network/ibverbs/ibverbs_memory_desc.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/platform/include/ibv.h"
#if defined(WITH_RDMA) && defined(OF_PLATFORM_POSIX)
namespace oneflow {
IBVerbsMemDesc::IBVerbsMemDesc(ibv_pd* pd, void* mem_ptr, size_t byte_size)
: mem_ptr_(mem_ptr), mem_size_(byte_size) {
mr_ = ibv::wrapper.ibv_reg_mr_wrap(
pd, mem_ptr, byte_size,
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ);
CHECK(mr_);
}
IBVerbsMemDesc::~IBVerbsMemDesc() { CHECK_EQ(ibv::wrapper.ibv_dereg_mr(mr_), 0); }
} // namespace oneflow
#endif // WITH_RDMA && OF_PLATFORM_POSIX
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMM_NETWORK_IBVERBS_IBVERBS_MEMORY_DESC_H_
#define ONEFLOW_CORE_COMM_NETWORK_IBVERBS_IBVERBS_MEMORY_DESC_H_
#include "oneflow/core/common/platform.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/comm_network/ibverbs/ibverbs.pb.h"
#if defined(WITH_RDMA) && defined(OF_PLATFORM_POSIX)
#include <infiniband/verbs.h>
namespace oneflow {
class IBVerbsMemDesc final {
public:
OF_DISALLOW_COPY_AND_MOVE(IBVerbsMemDesc);
IBVerbsMemDesc() = delete;
IBVerbsMemDesc(ibv_pd* pd, void* mem_ptr, size_t byte_size);
~IBVerbsMemDesc();
void* mem_ptr() const { return mem_ptr_; }
size_t mem_size() const { return mem_size_; }
const ibv_mr* mr() const { return mr_; }
private:
ibv_mr* mr_;
void* mem_ptr_;
uint64_t mem_size_;
};
} // namespace oneflow
#endif // WITH_RDMA && OF_PLATFORM_POSIX
#endif // ONEFLOW_CORE_COMM_NETWORK_IBVERBS_IBVERBS_MEMORY_DESC_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/comm_network/ibverbs/ibverbs_qp.h"
#include "oneflow/core/comm_network/comm_network.h"
#include "oneflow/core/lazy/actor/actor_message_bus.h"
#include "oneflow/core/job/resource_desc.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/platform/include/ibv.h"
#include "oneflow/core/comm_network/ibverbs/ibverbs_comm_network.h"
#if defined(WITH_RDMA) && defined(OF_PLATFORM_POSIX)
namespace oneflow {
namespace {
constexpr uint32_t kDefaultQueueDepth = 1024;
constexpr uint64_t kDefaultMemBlockSize = 8388608; // 8M
} // namespace
IBVerbsQP::IBVerbsQP(ibv_context* ctx, ibv_pd* pd, const struct ibv_port_attr& port_attr,
uint8_t port_num, ibv_cq* send_cq, ibv_cq* recv_cq) {
// ctx_, pd_
ctx_ = ctx;
pd_ = pd;
port_num_ = port_num;
// qp_
ibv_device_attr device_attr{};
CHECK_EQ(ibv::wrapper.ibv_query_device(ctx, &device_attr), 0);
const int64_t user_queue_depth =
ParseIntegerFromEnv("ONEFLOW_COMM_NET_IB_QUEUE_DEPTH", kDefaultQueueDepth);
const uint32_t queue_depth = std::min<uint32_t>(device_attr.max_qp_wr, user_queue_depth);
ibv_qp_init_attr qp_init_attr{};
qp_init_attr.qp_context = nullptr;
qp_init_attr.send_cq = send_cq;
qp_init_attr.recv_cq = recv_cq;
qp_init_attr.srq = nullptr;
qp_init_attr.cap.max_send_wr = queue_depth;
qp_init_attr.cap.max_recv_wr = queue_depth;
qp_init_attr.cap.max_send_sge = 1;
qp_init_attr.cap.max_recv_sge = 1;
qp_init_attr.cap.max_inline_data = 0;
qp_init_attr.qp_type = IBV_QPT_RC;
qp_init_attr.sq_sig_all = 1;
qp_ = ibv::wrapper.ibv_create_qp(pd, &qp_init_attr);
CHECK(qp_);
// recv_msg_buf_
recv_msg_buf_.assign(queue_depth, nullptr);
FOR_RANGE(size_t, i, 0, recv_msg_buf_.size()) { recv_msg_buf_.at(i) = new ActorMsgMR(pd_); }
// send_msg_buf_
CHECK(send_msg_buf_.empty());
num_outstanding_send_wr_ = 0;
max_outstanding_send_wr_ = queue_depth;
read_block_size_ =
ParseIntegerFromEnv("ONEFLOW_COMM_NET_IB_MEM_BLOCK_SIZE", kDefaultMemBlockSize);
mtu_ = static_cast<int32_t>(port_attr.active_mtu);
}
IBVerbsQP::~IBVerbsQP() {
CHECK_EQ(ibv::wrapper.ibv_destroy_qp(qp_), 0);
while (send_msg_buf_.empty() == false) {
delete send_msg_buf_.front();
send_msg_buf_.pop();
}
for (ActorMsgMR* msg_mr : recv_msg_buf_) { delete msg_mr; }
}
void IBVerbsQP::Connect(const IBVerbsConnectionInfo& peer_info) {
ibv_qp_attr qp_attr{};
// IBV_QPS_INIT
memset(&qp_attr, 0, sizeof(ibv_qp_attr));
qp_attr.qp_state = IBV_QPS_INIT;
// TODO(liujuncheng): Make pkey_index configurable
qp_attr.pkey_index = 0;
qp_attr.port_num = port_num_;
qp_attr.qp_access_flags =
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ;
CHECK_EQ(ibv::wrapper.ibv_modify_qp(
qp_, &qp_attr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS),
0);
// IBV_QPS_RTR
memset(&qp_attr, 0, sizeof(ibv_qp_attr));
qp_attr.qp_state = IBV_QPS_RTR;
// TODO(liujuncheng): Make sl configurable;
qp_attr.ah_attr.sl = 0;
qp_attr.ah_attr.src_path_bits = 0;
if (peer_info.lid() == 0) {
qp_attr.ah_attr.is_global = 1;
qp_attr.ah_attr.grh.dgid.global.subnet_prefix = peer_info.subnet_prefix();
qp_attr.ah_attr.grh.dgid.global.interface_id = peer_info.interface_id();
qp_attr.ah_attr.grh.flow_label = 0;
const int64_t gid_index = ParseIntegerFromEnv("ONEFLOW_COMM_NET_IB_GID_INDEX", 0);
qp_attr.ah_attr.grh.sgid_index = gid_index;
qp_attr.ah_attr.grh.hop_limit = 255;
// TODO(liujuncheng): Make traffic_class configurable;
qp_attr.ah_attr.grh.traffic_class = 0;
} else {
qp_attr.ah_attr.is_global = 0;
qp_attr.ah_attr.dlid = peer_info.lid();
}
qp_attr.ah_attr.port_num = peer_info.port_num();
qp_attr.path_mtu = static_cast<ibv_mtu>(std::min(peer_info.mtu(), mtu_));
qp_attr.dest_qp_num = peer_info.qp_num();
qp_attr.rq_psn = 0;
qp_attr.max_dest_rd_atomic = 1;
qp_attr.min_rnr_timer = 12;
CHECK_EQ(ibv::wrapper.ibv_modify_qp(qp_, &qp_attr,
IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN
| IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC
| IBV_QP_MIN_RNR_TIMER),
0);
// IBV_QPS_RTS
memset(&qp_attr, 0, sizeof(ibv_qp_attr));
qp_attr.qp_state = IBV_QPS_RTS;
qp_attr.sq_psn = 0;
qp_attr.max_rd_atomic = 1;
qp_attr.retry_cnt = 7;
qp_attr.rnr_retry = 7;
qp_attr.timeout = 14;
CHECK_EQ(ibv::wrapper.ibv_modify_qp(qp_, &qp_attr,
IBV_QP_STATE | IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC
| IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_TIMEOUT),
0);
}
void IBVerbsQP::PostAllRecvRequest() {
for (ActorMsgMR* msg_mr : recv_msg_buf_) { PostRecvRequest(msg_mr); }
}
void IBVerbsQP::PostReadRequest(const IBVerbsCommNetRMADesc& remote_mem,
const IBVerbsMemDesc& local_mem, void* read_id) {
CHECK_EQ(remote_mem.mem_size, local_mem.mem_size());
WorkRequestId* wr_id = NewWorkRequestId();
const size_t block_num = RoundUp(remote_mem.mem_size, read_block_size_) / read_block_size_;
wr_id->outstanding_sge_cnt = static_cast<int32_t>(block_num);
wr_id->read_id = read_id;
FOR_RANGE(size_t, i, 0, block_num) {
ibv_send_wr wr{};
ibv_sge sge{};
sge.addr = reinterpret_cast<uint64_t>(local_mem.mem_ptr()) + i * read_block_size_;
sge.length = std::min(read_block_size_, local_mem.mem_size() - i * read_block_size_);
sge.lkey = local_mem.mr()->lkey;
wr.wr_id = reinterpret_cast<uint64_t>(wr_id);
wr.next = nullptr;
wr.sg_list = &sge;
wr.num_sge = 1;
wr.opcode = IBV_WR_RDMA_READ;
wr.send_flags = 0;
wr.imm_data = 0;
wr.wr.rdma.remote_addr = remote_mem.mem_ptr + i * read_block_size_;
wr.wr.rdma.rkey = remote_mem.mr_rkey;
EnqueuePostSendReadWR(wr, sge);
}
}
void IBVerbsQP::PostSendRequest(const ActorMsg& msg) {
ActorMsgMR* msg_mr = GetOneSendMsgMRFromBuf();
msg_mr->set_msg(msg);
WorkRequestId* wr_id = NewWorkRequestId();
wr_id->msg_mr = msg_mr;
ibv_send_wr wr{};
ibv_sge sge{};
sge.addr = reinterpret_cast<uint64_t>(msg_mr->mem_desc().mem_ptr());
sge.length = msg_mr->mem_desc().mem_size();
sge.lkey = msg_mr->mem_desc().mr()->lkey;
wr.wr_id = reinterpret_cast<uint64_t>(wr_id);
wr.next = nullptr;
wr.sg_list = &sge;
wr.num_sge = 1;
wr.opcode = IBV_WR_SEND;
wr.send_flags = 0;
wr.imm_data = 0;
memset(&(wr.wr), 0, sizeof(wr.wr));
EnqueuePostSendReadWR(wr, sge);
}
void IBVerbsQP::EnqueuePostSendReadWR(ibv_send_wr wr, ibv_sge sge) {
std::unique_lock<std::mutex> pending_send_wr_lock_(pending_send_wr_mutex_);
if (num_outstanding_send_wr_ < max_outstanding_send_wr_) {
num_outstanding_send_wr_++;
ibv_send_wr* bad_wr = nullptr;
CHECK_EQ(ibv_post_send(qp_, &wr, &bad_wr), 0);
} else {
std::pair<ibv_send_wr, ibv_sge> ibv_send_wr_sge = std::make_pair(wr, sge);
pending_send_wr_queue_.push(ibv_send_wr_sge);
}
}
void IBVerbsQP::ReadDone(WorkRequestId* wr_id) {
CHECK_GE(wr_id->outstanding_sge_cnt, 1);
wr_id->outstanding_sge_cnt -= 1;
if (wr_id->outstanding_sge_cnt == 0) {
Singleton<CommNet>::Get()->ReadDone(wr_id->read_id);
DeleteWorkRequestId(wr_id);
}
PostPendingSendWR();
}
void IBVerbsQP::SendDone(WorkRequestId* wr_id) {
{
std::unique_lock<std::mutex> lck(send_msg_buf_mtx_);
send_msg_buf_.push(wr_id->msg_mr);
}
DeleteWorkRequestId(wr_id);
PostPendingSendWR();
}
void IBVerbsQP::RecvDone(WorkRequestId* wr_id) {
auto* ibv_comm_net = dynamic_cast<IBVerbsCommNet*>(Singleton<CommNet>::Get());
CHECK(ibv_comm_net != nullptr);
ibv_comm_net->RecvActorMsg(wr_id->msg_mr->msg());
PostRecvRequest(wr_id->msg_mr);
DeleteWorkRequestId(wr_id);
}
void IBVerbsQP::PostPendingSendWR() {
std::unique_lock<std::mutex> pending_send_wr_lock_(pending_send_wr_mutex_);
if (pending_send_wr_queue_.empty() == false) {
std::pair<ibv_send_wr, ibv_sge> ibv_send_wr_sge = std::move(pending_send_wr_queue_.front());
ibv_send_wr wr = ibv_send_wr_sge.first;
wr.sg_list = &ibv_send_wr_sge.second;
pending_send_wr_queue_.pop();
ibv_send_wr* bad_wr = nullptr;
CHECK_EQ(ibv_post_send(qp_, &wr, &bad_wr), 0);
} else {
if (num_outstanding_send_wr_ > 0) { num_outstanding_send_wr_--; }
}
}
void IBVerbsQP::PostRecvRequest(ActorMsgMR* msg_mr) {
WorkRequestId* wr_id = NewWorkRequestId();
wr_id->msg_mr = msg_mr;
ibv_recv_wr wr{};
ibv_sge sge{};
sge.addr = reinterpret_cast<uint64_t>(msg_mr->mem_desc().mem_ptr());
sge.length = msg_mr->mem_desc().mem_size();
sge.lkey = msg_mr->mem_desc().mr()->lkey;
wr.wr_id = reinterpret_cast<uint64_t>(wr_id);
wr.next = nullptr;
wr.sg_list = &sge;
wr.num_sge = 1;
ibv_recv_wr* bad_wr = nullptr;
CHECK_EQ(ibv_post_recv(qp_, &wr, &bad_wr), 0);
}
ActorMsgMR* IBVerbsQP::GetOneSendMsgMRFromBuf() {
std::unique_lock<std::mutex> lck(send_msg_buf_mtx_);
if (send_msg_buf_.empty()) { send_msg_buf_.push(new ActorMsgMR(pd_)); }
ActorMsgMR* msg_mr = send_msg_buf_.front();
send_msg_buf_.pop();
return msg_mr;
}
WorkRequestId* IBVerbsQP::NewWorkRequestId() {
WorkRequestId* wr_id = new WorkRequestId;
wr_id->qp = this;
wr_id->outstanding_sge_cnt = 0;
wr_id->read_id = nullptr;
wr_id->msg_mr = nullptr;
return wr_id;
}
void IBVerbsQP::DeleteWorkRequestId(WorkRequestId* wr_id) {
CHECK_EQ(wr_id->qp, this);
delete wr_id;
}
} // namespace oneflow
#endif // WITH_RDMA && OF_PLATFORM_POSIX
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMM_NETWORK_IBVERBS_IBVERBS_QP_H_
#define ONEFLOW_CORE_COMM_NETWORK_IBVERBS_IBVERBS_QP_H_
#include "oneflow/core/comm_network/ibverbs/ibverbs_memory_desc.h"
#include "oneflow/core/lazy/actor/actor_message.h"
#if defined(WITH_RDMA) && defined(OF_PLATFORM_POSIX)
namespace oneflow {
class ActorMsgMR final {
public:
OF_DISALLOW_COPY_AND_MOVE(ActorMsgMR);
ActorMsgMR() = delete;
ActorMsgMR(ibv_pd* pd) { mem_desc_.reset(new IBVerbsMemDesc(pd, &msg_, sizeof(msg_))); }
~ActorMsgMR() { mem_desc_.reset(); }
const ActorMsg& msg() const { return msg_; }
void set_msg(const ActorMsg& val) { msg_ = val; }
const IBVerbsMemDesc& mem_desc() const { return *mem_desc_; }
private:
ActorMsg msg_;
std::unique_ptr<IBVerbsMemDesc> mem_desc_;
};
class IBVerbsQP;
struct WorkRequestId {
IBVerbsQP* qp;
int32_t outstanding_sge_cnt;
void* read_id;
ActorMsgMR* msg_mr;
};
struct IBVerbsCommNetRMADesc;
class IBVerbsQP final {
public:
OF_DISALLOW_COPY_AND_MOVE(IBVerbsQP);
IBVerbsQP() = delete;
IBVerbsQP(ibv_context*, ibv_pd*, const struct ibv_port_attr&, uint8_t port_num, ibv_cq* send_cq,
ibv_cq* recv_cq);
~IBVerbsQP();
uint32_t qp_num() const { return qp_->qp_num; }
void Connect(const IBVerbsConnectionInfo& peer_info);
void PostAllRecvRequest();
void PostReadRequest(const IBVerbsCommNetRMADesc& remote_mem, const IBVerbsMemDesc& local_mem,
void* read_id);
void PostSendRequest(const ActorMsg& msg);
void ReadDone(WorkRequestId*);
void SendDone(WorkRequestId*);
void RecvDone(WorkRequestId*);
private:
void EnqueuePostSendReadWR(ibv_send_wr wr, ibv_sge sge);
void PostPendingSendWR();
WorkRequestId* NewWorkRequestId();
void DeleteWorkRequestId(WorkRequestId* wr_id);
ActorMsgMR* GetOneSendMsgMRFromBuf();
void PostRecvRequest(ActorMsgMR*);
ibv_context* ctx_;
ibv_pd* pd_;
uint8_t port_num_;
ibv_qp* qp_;
std::vector<ActorMsgMR*> recv_msg_buf_;
std::mutex send_msg_buf_mtx_;
std::queue<ActorMsgMR*> send_msg_buf_;
std::mutex pending_send_wr_mutex_;
uint32_t num_outstanding_send_wr_;
uint32_t max_outstanding_send_wr_;
std::queue<std::pair<ibv_send_wr, ibv_sge>> pending_send_wr_queue_;
size_t read_block_size_;
int32_t mtu_;
};
} // namespace oneflow
#endif // WITH_RDMA && OF_PLATFORM_POSIX
#endif // ONEFLOW_CORE_COMM_NETWORK_IBVERBS_IBVERBS_QP_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_ARRAY_REF_H_
#define ONEFLOW_CORE_COMMON_ARRAY_REF_H_
#include "llvm/ADT/ArrayRef.h"
namespace oneflow {
template<typename T>
using ArrayRef = llvm::ArrayRef<T>;
template<typename T>
using MutableArrayRef = llvm::MutableArrayRef<T>;
} // namespace oneflow
#endif
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_AUTO_REGISTRATION_FACTORY_H_
#define ONEFLOW_CORE_COMMON_AUTO_REGISTRATION_FACTORY_H_
#include "oneflow/core/common/util.h"
namespace oneflow {
template<typename Key, typename Base, typename... Args>
struct AutoRegistrationFactory {
public:
using Creator = std::function<Base*(Args&&...)>;
template<typename Derived>
struct RawRegisterType {
RawRegisterType(Key k) {
CHECK((AutoRegistrationFactory<Key, Base, Args...>::Get()
.mutable_creators()
->emplace(k, [](Args&&...) { return new Derived; })
.second))
<< k;
}
};
struct CreatorRegisterType {
CreatorRegisterType(Key k, Creator v) {
CHECK((AutoRegistrationFactory<Key, Base, Args...>::Get()
.mutable_creators()
->emplace(k, v)
.second))
<< k;
}
};
Base* New(Key k, Args&&... args) const {
auto creators_it = creators().find(k);
CHECK(creators_it != creators().end())
<< "Unregistered: key: " << k << " Base type name:" << typeid(Base).name()
<< " Key type name" << typeid(Key).name();
return creators_it->second(std::forward<Args>(args)...);
}
bool IsClassRegistered(Key k, Args&&... args) const {
return creators().find(k) != creators().end();
}
static AutoRegistrationFactory<Key, Base, Args...>& Get() {
static AutoRegistrationFactory<Key, Base, Args...> obj;
return obj;
}
private:
std::unique_ptr<HashMap<Key, Creator>> creators_;
bool has_creators() const { return creators_.get() != nullptr; }
const HashMap<Key, Creator>& creators() const {
CHECK(has_creators()) << "Unregistered key type: " << typeid(Key).name();
return *creators_.get();
}
HashMap<Key, Creator>* mutable_creators() {
if (!creators_) { creators_.reset(new HashMap<Key, Creator>); }
return creators_.get();
}
};
#define REGISTER_VAR_NAME OF_PP_CAT(g_registry_var, __COUNTER__)
#define REGISTER_CLASS(Key, k, Base, Derived) \
static AutoRegistrationFactory<Key, Base>::RawRegisterType<Derived> REGISTER_VAR_NAME(k)
#define REGISTER_CLASS_WITH_ARGS(Key, k, Base, Derived, ...) \
static AutoRegistrationFactory<Key, Base, __VA_ARGS__>::RawRegisterType<Derived> \
REGISTER_VAR_NAME(k)
#define REGISTER_CLASS_CREATOR(Key, k, Base, f, ...) \
static AutoRegistrationFactory<Key, Base, ##__VA_ARGS__>::CreatorRegisterType REGISTER_VAR_NAME( \
k, f)
template<typename Key, typename Base, typename... Args>
inline Base* NewObj(Key k, Args&&... args) {
return AutoRegistrationFactory<Key, Base, Args...>::Get().New(k, std::forward<Args>(args)...);
}
template<typename Key, typename Base, typename... Args>
inline std::unique_ptr<Base> NewObjUniquePtr(Key k, Args&&... args) {
return std::unique_ptr<Base>(
AutoRegistrationFactory<Key, Base, Args...>::Get().New(k, std::forward<Args>(args)...));
}
template<typename Key, typename Base, typename... Args>
inline bool IsClassRegistered(Key k, Args&&... args) {
return AutoRegistrationFactory<Key, Base, Args...>::Get().IsClassRegistered(
k, std::forward<Args>(args)...);
}
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_AUTO_REGISTRATION_FACTORY_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/balanced_splitter.h"
namespace oneflow {
BalancedSplitter::BalancedSplitter(int64_t total_num, int64_t split_num) {
base_part_size_ = total_num / split_num;
base_begin_idx_ = total_num % split_num;
split_num_ = split_num;
}
Range BalancedSplitter::At(int64_t idx) const {
CHECK_LT(idx, split_num_);
int64_t left_bound = -1;
int64_t right_bound = -1;
if (idx < base_begin_idx_) {
left_bound = (base_part_size_ + 1) * idx;
right_bound = left_bound + (base_part_size_ + 1);
} else {
left_bound =
(base_part_size_ + 1) * base_begin_idx_ + base_part_size_ * (idx - base_begin_idx_);
right_bound = left_bound + base_part_size_;
}
return Range(left_bound, right_bound);
}
Range BalancedSplitter::At(int64_t first_idx, int64_t last_idx) const {
CHECK_LE(first_idx, last_idx);
CHECK_LT(last_idx, split_num_);
Range first_range = At(first_idx);
Range last_range = At(last_idx);
return Range(first_range.begin(), last_range.end());
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_BALANCED_SPLITTER_H_
#define ONEFLOW_CORE_COMMON_BALANCED_SPLITTER_H_
#include <stdint.h>
#include "oneflow/core/common/range.h"
#include "oneflow/core/common/util.h"
namespace oneflow {
// For example
// BalancedSplitter splitter(20, 6)
// the result of splitter.At
// 0 [0, 4)
// 1 [4, 8)
// 2 [8, 11)
// 3 [11, 14)
// 4 [14, 17)
// 5 [17, 20)
class BalancedSplitter final {
public:
// OF_DISALLOW_COPY_AND_MOVE(BalancedSplitter);
BalancedSplitter() = delete;
~BalancedSplitter() = default;
BalancedSplitter(int64_t total_num, int64_t split_num);
Range At(int64_t idx) const;
Range At(int64_t first_idx, int64_t last_idx) const;
private:
int64_t base_part_size_;
int64_t base_begin_idx_;
int64_t split_num_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_BALANCED_SPLITTER_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "gtest/gtest.h"
#include "oneflow/core/common/balanced_splitter.h"
namespace oneflow {
TEST(BalancedSplitter, split_20_to_6_part) {
BalancedSplitter splitter(20, 6);
ASSERT_TRUE(splitter.At(0) == Range(0, 4));
ASSERT_TRUE(splitter.At(1) == Range(4, 8));
ASSERT_TRUE(splitter.At(2) == Range(8, 11));
ASSERT_TRUE(splitter.At(3) == Range(11, 14));
ASSERT_TRUE(splitter.At(4) == Range(14, 17));
ASSERT_TRUE(splitter.At(5) == Range(17, 20));
}
TEST(BalancedSplitter, split_2_to_3_part) {
BalancedSplitter splitter(2, 3);
ASSERT_TRUE(splitter.At(0) == Range(0, 1));
ASSERT_TRUE(splitter.At(1) == Range(1, 2));
ASSERT_TRUE(splitter.At(2) == Range(2, 2));
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_BLAS_H_
#define ONEFLOW_CORE_COMMON_BLAS_H_
#include <type_traits>
#include <utility>
#ifdef WITH_CUDA
#include <cuda_fp16.h>
#endif // WITH_CUDA
#ifdef WITH_ROCM
#include <hip/hip_fp16.h>
#endif // WITH_ROCM
#include "oneflow/core/common/cblas.h"
#include "oneflow/core/common/preprocessor.h"
namespace oneflow {
#define BLAS_NAME_SEQ \
OF_PP_MAKE_TUPLE_SEQ(dot) \
OF_PP_MAKE_TUPLE_SEQ(swap) \
OF_PP_MAKE_TUPLE_SEQ(copy) \
OF_PP_MAKE_TUPLE_SEQ(axpy) \
OF_PP_MAKE_TUPLE_SEQ(scal) \
OF_PP_MAKE_TUPLE_SEQ(gemv) \
OF_PP_MAKE_TUPLE_SEQ(gemm) \
OF_PP_MAKE_TUPLE_SEQ(gemmBatched) \
OF_PP_MAKE_TUPLE_SEQ(gemmStridedBatched)
#define CBLAS_TEMPLATE(name) \
template<typename T, typename... Args> \
auto cblas_##name(Args&&... args) \
->typename std::enable_if<std::is_same<T, float>::value, \
decltype(cblas_##s##name(std::forward<Args>(args)...))>::type { \
return cblas_##s##name(std::forward<Args>(args)...); \
} \
template<typename T, typename... Args> \
auto cblas_##name(Args&&... args) \
->typename std::enable_if<std::is_same<T, double>::value, \
decltype(cblas_##d##name(std::forward<Args>(args)...))>::type { \
return cblas_##d##name(std::forward<Args>(args)...); \
}
OF_PP_FOR_EACH_TUPLE(CBLAS_TEMPLATE, BLAS_NAME_SEQ);
#undef CBLAS_TEMPLATE
#undef BLAS_NAME_SEQ
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_BLAS_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/blocking_counter.h"
#include "oneflow/core/common/foreign_lock_helper.h"
#include "oneflow/core/common/singleton.h"
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/common/env_var/env_var.h"
namespace oneflow {
int64_t BlockingCounter::Increase() {
std::unique_lock<std::mutex> lck(mtx_);
CHECK_GT(cnt_val_, 0);
cnt_val_ += 1;
return cnt_val_;
}
int64_t BlockingCounter::Decrease() {
std::unique_lock<std::mutex> lck(mtx_);
cnt_val_ -= 1;
if (cnt_val_ == 0) { cond_.notify_all(); }
return cnt_val_;
}
Maybe<void> BlockingCounter::WaitUntilCntEqualZero(size_t timeout_seconds) {
return Singleton<ForeignLockHelper>::Get()->WithScopedRelease([&, this]() -> Maybe<void> {
std::chrono::duration<size_t> seconds(timeout_seconds);
std::unique_lock<std::mutex> lck(mtx_);
CHECK_OR_RETURN(cond_.wait_for(lck, seconds, [this]() { return cnt_val_ == 0; }))
<< Error::TimeoutError();
return Maybe<void>::Ok();
});
}
void BlockingCounter::WaitForeverUntilCntEqualZero() {
CHECK_JUST(WaitUntilCntEqualZero([]() -> Maybe<bool> { return false; }));
}
Maybe<void> BlockingCounter::WaitUntilCntEqualZero(
const std::function<Maybe<bool>()>& StopWaitingAfterTimeout) {
while (true) {
auto status = TRY(WaitUntilCntEqualZero(EnvInteger<ONEFLOW_TIMEOUT_SECONDS>()));
if (status.IsOk()) { return status; }
if (!status.error()->has_timeout_error()) { return status; }
if (JUST(StopWaitingAfterTimeout())) { return status; }
}
UNIMPLEMENTED_THEN_RETURN();
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_BLOCKING_COUNTER_H_
#define ONEFLOW_CORE_COMMON_BLOCKING_COUNTER_H_
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/maybe.h"
namespace oneflow {
class BlockingCounter final {
public:
OF_DISALLOW_COPY_AND_MOVE(BlockingCounter);
BlockingCounter() = delete;
~BlockingCounter() = default;
BlockingCounter(int64_t cnt_val) { cnt_val_ = cnt_val; }
int64_t Increase();
int64_t Decrease();
void WaitForeverUntilCntEqualZero();
Maybe<void> WaitUntilCntEqualZero(size_t timeout_seconds);
Maybe<void> WaitUntilCntEqualZero(const std::function<Maybe<bool>()>& StopWaitingAfterTimeout);
private:
std::mutex mtx_;
std::condition_variable cond_;
int64_t cnt_val_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_BLOCKING_COUNTER_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_BLOCKING_THEN_BUSY_H_
#define ONEFLOW_CORE_COMMON_BLOCKING_THEN_BUSY_H_
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/blocking_counter.h"
#include "oneflow/core/common/spin_counter.h"
namespace oneflow {
class BlockingThenBusy final {
public:
BlockingThenBusy(const BlockingThenBusy&) = delete;
BlockingThenBusy(BlockingThenBusy&&) = delete;
BlockingThenBusy() = delete;
explicit BlockingThenBusy(int cnt) : blocking_counter_(cnt), spin_counter_(cnt) {}
BlockingCounter* mut_blocking_counter() { return &blocking_counter_; }
SpinCounter* mut_spin_counter() { return &spin_counter_; }
Maybe<void> WaitUntilCntEqualZero(const std::function<Maybe<bool>()>& StopAfterTimeout) {
JUST(blocking_counter_.WaitUntilCntEqualZero(StopAfterTimeout));
JUST(spin_counter_.WaitUntilCntEqualZero());
return Maybe<void>::Ok();
}
private:
BlockingCounter blocking_counter_;
SpinCounter spin_counter_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_BLOCKING_THEN_BUSY_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_COMMON_BUFFER_H_
#define ONEFLOW_CORE_COMMON_BUFFER_H_
#include "oneflow/core/common/util.h"
namespace oneflow {
enum BufferStatus { kBufferStatusSuccess = 0, kBufferStatusErrorClosed, kBufferStatusEmpty };
template<typename T>
class Buffer final {
public:
OF_DISALLOW_COPY_AND_MOVE(Buffer);
Buffer(size_t max_len) : max_len_(max_len), is_closed_(false) {}
~Buffer() = default;
template<typename U>
BufferStatus Push(U&& item);
BufferStatus Pull(T* item);
BufferStatus TryReceive(T* item);
void Close();
private:
std::queue<T> queue_;
mutable std::mutex mutex_;
size_t max_len_;
bool is_closed_;
std::condition_variable cond_;
};
template<typename T>
template<typename U>
BufferStatus Buffer<T>::Push(U&& item) {
std::unique_lock<std::mutex> lock(mutex_);
cond_.wait(lock, [this]() { return queue_.size() < max_len_ || is_closed_; });
if (is_closed_) { return kBufferStatusErrorClosed; }
queue_.push(std::forward<U>(item));
cond_.notify_one();
return kBufferStatusSuccess;
}
template<typename T>
BufferStatus Buffer<T>::Pull(T* item) {
std::unique_lock<std::mutex> lock(mutex_);
cond_.wait(lock, [this]() { return (!queue_.empty()) || is_closed_; });
if (queue_.empty()) { return kBufferStatusErrorClosed; }
*item = std::move(queue_.front());
queue_.pop();
if (queue_.size() < max_len_) { cond_.notify_all(); }
return kBufferStatusSuccess;
}
template<typename T>
BufferStatus Buffer<T>::TryReceive(T* item) {
std::unique_lock<std::mutex> lock(mutex_);
if (queue_.empty()) { return is_closed_ ? kBufferStatusErrorClosed : kBufferStatusEmpty; }
*item = std::move(queue_.front());
queue_.pop();
if (queue_.size() < max_len_) { cond_.notify_all(); }
return kBufferStatusSuccess;
}
template<typename T>
void Buffer<T>::Close() {
std::unique_lock<std::mutex> lock(mutex_);
is_closed_ = true;
cond_.notify_all();
}
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_BUFFER_H_
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment