/*!
 *  Copyright (c) 2019 by Contributors
 * \file tp_communicator.cc
 * \brief Tensorpipe Communicator for DGL distributed training.
 */

#include "tp_communicator.h"

#include <time.h>
#include <unistd.h>

#include <future>
#include <memory>
#include <utility>

#include "../rpc.h"

namespace dgl {
namespace rpc {

using namespace tensorpipe;

bool TPSender::ConnectReceiver(const std::string &addr, int recv_id) {
  if (pipes_.find(recv_id) != pipes_.end()) {
    LOG(WARNING) << "Duplicate recv_id[" << recv_id << "]. Ignoring...";
    return true;
  }
  std::shared_ptr<Pipe> pipe;
  pipe = context->connect(addr);
  auto done = std::make_shared<std::promise<bool>>();
  tensorpipe::Message tpmsg;
  tpmsg.metadata = "dglconnect";
  pipe->write(tpmsg, [done](const tensorpipe::Error &error) {
34
    done->set_value(!error);
35
36
  });
  if (!done->get_future().get()) {
37
    DLOG(WARNING) << "Failed to connect to receiver[" << addr << "].";
38
    return false;
39
  }
40
  pipes_[recv_id] = pipe;
41
42
43
  return true;
}

void TPSender::Send(const RPCMessage &msg, int recv_id) {
45
46
  auto pipe = pipes_[recv_id];
  tensorpipe::Message tp_msg;
47
  std::string *zerocopy_blob_ptr = &tp_msg.metadata;
48
49
50
  StreamWithBuffer zc_write_strm(zerocopy_blob_ptr, true);
  zc_write_strm.Write(msg);
  int32_t nonempty_ndarray_count = zc_write_strm.buffer_list().size();
51
  zerocopy_blob_ptr->append(reinterpret_cast<char *>(&nonempty_ndarray_count),
52
53
54
55
56
                            sizeof(int32_t));
  tp_msg.tensors.resize(nonempty_ndarray_count);
  // Hold the NDArray that ensure it's valid until write operation completes
  auto ndarray_holder = std::make_shared<std::vector<NDArray>>();
  ndarray_holder->resize(nonempty_ndarray_count);
57
58
59
  auto &buffer_list = zc_write_strm.buffer_list();
  for (size_t i = 0; i < buffer_list.size(); i++) {
    auto &ptr = buffer_list[i];
60
61
62
63
64
65
66
67
68
69
    (*ndarray_holder.get())[i] = ptr.tensor;
    tensorpipe::CpuBuffer cpu_buffer;
    cpu_buffer.ptr = ptr.data;
    tp_msg.tensors[i].buffer = cpu_buffer;
    tp_msg.tensors[i].length = ptr.size;
    if (ptr.size == 0) {
      LOG(FATAL) << "Cannot send a empty NDArray.";
    }
  }
  pipe->write(tp_msg,
70
              [ndarray_holder, recv_id](const tensorpipe::Error &error) {
71
72
73
74
75
76
77
                if (error) {
                  LOG(FATAL) << "Failed to send message to " << recv_id
                             << ". Details: " << error.what();
                }
              });
}

// Tear down the sender side: close every outgoing pipe and drop the
// receiver-id -> pipe table.
void TPSender::Finalize() {
  for (auto &entry : pipes_) {
    entry.second->close();
  }
  pipes_.clear();
}

// Tear down the receiver side: stop accepting new connections and close
// every accepted pipe.
void TPReceiver::Finalize() {
  // Wait() may never have been called, in which case listener_ is still
  // null — guard against the null dereference.
  if (listener_) {
    listener_->close();
  }
  for (auto &&p : pipes_) {
    p.second->close();
  }
  pipes_.clear();
}

// Start listening on `addr` for incoming sender connections. When
// `blocking` is true, waits until `num_sender` senders have completed the
// handshake (num_connected_ is advanced by OnAccepted on the tensorpipe
// callback thread). Always returns true; a repeated call is a warning
// no-op.
bool TPReceiver::Wait(const std::string &addr, int num_sender, bool blocking) {
  if (listener_) {
    LOG(WARNING) << "TPReceiver::Wait() has been called already. Ignoring...";
    return true;
  }
  LOG(INFO) << "TPReceiver starts to wait on [" << addr << "].";
  listener_ = context->listen({addr});
  listener_->accept([this](const Error &error, std::shared_ptr<Pipe> pipe) {
    OnAccepted(error, pipe);
  });
  while (blocking && (num_sender != num_connected_)) {
    // Yield the CPU instead of hard-spinning while connections trickle in;
    // 1ms keeps connection-setup latency negligible.
    usleep(1000);
  }
  return true;
}

// Callback invoked by the listener for every incoming connection attempt.
// On success it re-arms the listener (so later senders can connect),
// validates the "dglconnect" handshake, registers the pipe, and starts the
// per-pipe receive loop.
void TPReceiver::OnAccepted(const Error &error, std::shared_ptr<Pipe> pipe) {
  if (error) {
    if (error.isOfType<ListenerClosedError>()) {
      // Expected.
    } else {
      LOG(WARNING) << "Unexpected error when accepting incoming pipe: " << error.what();
    }
    return;
  }

  // Accept the next connection request
  listener_->accept([this](const Error &error, std::shared_ptr<Pipe> pipe) {
    OnAccepted(error, pipe);
  });

  // read the handshake message: "dglconnect"
  pipe->readDescriptor([pipe, this](const Error &error, Descriptor descriptor) {
    if (error) {
      LOG(ERROR) << "Unexpected error when reading from accepted pipe: " << error.what();
      return;
    }
    // Complete the read with an empty allocation: the handshake carries no
    // tensor payload, only the metadata string checked below.
    Allocation allocation;
    pipe->read(allocation, [](const Error &error) {});
    CHECK(descriptor.metadata == "dglconnect") << "Invalid connect message.";
    // Pipes are keyed by arrival order. NOTE(review): num_connected_ is
    // also read by Wait() on another thread — presumably it is an atomic
    // in the header; verify.
    pipes_[num_connected_] = pipe;
    // Start draining RPC messages from this pipe into the shared queue.
    ReceiveFromPipe(pipe, queue_);
    ++num_connected_;
  });
}

// Post one asynchronous read on `pipe`; on completion, deserialize the
// bytes into an RPCMessage, push it onto `queue`, and re-arm itself so the
// pipe is drained continuously until it is closed.
void TPReceiver::ReceiveFromPipe(std::shared_ptr<Pipe> pipe,
                                 std::shared_ptr<RPCMessageQueue> queue) {
  pipe->readDescriptor([pipe, queue = std::move(queue)](const Error &error,
                                                        Descriptor descriptor) {
    if (error) {
      // Error may happen when the pipe is closed
      return;
    }
    Allocation allocation;
    CHECK_EQ(descriptor.payloads.size(), 0) << "Invalid DGL RPC Message";

    // Allocate a destination buffer for every incoming tensor announced in
    // the descriptor.
    int tensorsize = descriptor.tensors.size();
    if (tensorsize > 0) {
      allocation.tensors.resize(tensorsize);
      for (size_t i = 0; i < descriptor.tensors.size(); i++) {
        tensorpipe::CpuBuffer cpu_buffer;
        // NOTE(review): raw new[] with no matching delete[] here —
        // presumably StreamWithBuffer (or the NDArray built from these
        // pointers) takes ownership and frees them; confirm, otherwise
        // this leaks one buffer per tensor per message.
        cpu_buffer.ptr = new char[descriptor.tensors[i].length];
        allocation.tensors[i].buffer = cpu_buffer;
      }
    }
    // Second stage: read the payload into `allocation`, then deserialize.
    // `descriptor` and `queue` are moved into the inner lambda; `allocation`
    // is captured so the buffers stay reachable until the read completes.
    pipe->read(allocation, [allocation, descriptor = std::move(descriptor),
                            queue = std::move(queue),
                            pipe](const Error &error) {
      if (error) {
        // Because we always have a read event posted to the epoll,
        // Therefore when pipe is closed, error will be raised.
        // But this error is expected.
        // Other error is not expected. But we cannot identify the error with
        // each Other for now. Thus here we skip handling for all errors
        return;
      }

      // The metadata blob is the serialized RPCMessage followed by a
      // trailing int32 tensor count (appended by TPSender::Send), hence
      // the stream length excludes sizeof(int32_t).
      char *meta_msg_begin = const_cast<char *>(&descriptor.metadata[0]);
      std::vector<void *> buffer_list(descriptor.tensors.size());
      for (size_t i = 0; i < descriptor.tensors.size(); i++) {
        buffer_list[i] = allocation.tensors[i].buffer.unwrap<CpuBuffer>().ptr;
      }
      StreamWithBuffer zc_read_strm(
          meta_msg_begin, descriptor.metadata.size() - sizeof(int32_t),
          buffer_list);
      RPCMessage msg;
      zc_read_strm.Read(&msg);
      queue->push(msg);
      // Re-arm: keep receiving messages from this pipe.
      TPReceiver::ReceiveFromPipe(pipe, queue);
    });
  });
}

// Pop the next received message into *msg. NOTE(review): presumably
// RPCMessageQueue::pop() blocks until a message is available — verify
// against the queue implementation.
void TPReceiver::Recv(RPCMessage *msg) { *msg = std::move(queue_->pop()); }

}  // namespace rpc
}  // namespace dgl