/*!
 *  Copyright (c) 2019 by Contributors
 * \file tp_communicator.cc
 * \brief Tensorpipe Communicator for DGL distributed training.
 */

#include "tp_communicator.h"

#include <time.h>
#include <unistd.h>

#include <future>
#include <memory>
#include <utility>

#include "../rpc.h"

namespace dgl {
namespace rpc {

using namespace tensorpipe;

bool TPSender::ConnectReceiver(const std::string &addr, int recv_id) {
  if (pipes_.find(recv_id) != pipes_.end()) {
    LOG(WARNING) << "Duplicate recv_id[" << recv_id << "]. Ignoring...";
    return true;
  }
  std::shared_ptr<Pipe> pipe;
  pipe = context->connect(addr);
  auto done = std::make_shared<std::promise<bool>>();
  tensorpipe::Message tpmsg;
  tpmsg.metadata = "dglconnect";
  pipe->write(tpmsg, [done](const tensorpipe::Error &error) {
    if (error) {
      LOG(WARNING) << "Error occurred when write to pipe: " << error.what();
      done->set_value(false);
    } else {
      done->set_value(true);
39
    }
40
41
42
43
  });
  if (!done->get_future().get()) {
    LOG(WARNING) << "Failed to connect to receiver[" << addr << "].";
    return false;
44
  }
45
  pipes_[recv_id] = pipe;
46
47
48
  return true;
}

void TPSender::Send(const RPCMessage &msg, int recv_id) {
50
51
  auto pipe = pipes_[recv_id];
  tensorpipe::Message tp_msg;
52
  std::string *zerocopy_blob_ptr = &tp_msg.metadata;
53
54
55
  StreamWithBuffer zc_write_strm(zerocopy_blob_ptr, true);
  zc_write_strm.Write(msg);
  int32_t nonempty_ndarray_count = zc_write_strm.buffer_list().size();
56
  zerocopy_blob_ptr->append(reinterpret_cast<char *>(&nonempty_ndarray_count),
57
58
59
60
61
                            sizeof(int32_t));
  tp_msg.tensors.resize(nonempty_ndarray_count);
  // Hold the NDArray that ensure it's valid until write operation completes
  auto ndarray_holder = std::make_shared<std::vector<NDArray>>();
  ndarray_holder->resize(nonempty_ndarray_count);
62
63
64
  auto &buffer_list = zc_write_strm.buffer_list();
  for (size_t i = 0; i < buffer_list.size(); i++) {
    auto &ptr = buffer_list[i];
65
66
67
68
69
70
71
72
73
74
    (*ndarray_holder.get())[i] = ptr.tensor;
    tensorpipe::CpuBuffer cpu_buffer;
    cpu_buffer.ptr = ptr.data;
    tp_msg.tensors[i].buffer = cpu_buffer;
    tp_msg.tensors[i].length = ptr.size;
    if (ptr.size == 0) {
      LOG(FATAL) << "Cannot send a empty NDArray.";
    }
  }
  pipe->write(tp_msg,
75
              [ndarray_holder, recv_id](const tensorpipe::Error &error) {
76
77
78
79
80
81
82
                if (error) {
                  LOG(FATAL) << "Failed to send message to " << recv_id
                             << ". Details: " << error.what();
                }
              });
}

void TPSender::Finalize() {
  // Close every outgoing pipe, then drop the whole connection table.
  for (auto it = pipes_.begin(); it != pipes_.end(); ++it) {
    it->second->close();
  }
  pipes_.clear();
}

void TPReceiver::Finalize() {
  // Shut down the listener and all accepted pipes.
  // Fix: guard against Wait() never having been called, in which case
  // listener_ is still null and the unconditional close() would dereference
  // a null pointer.
  if (listener_) {
    listener_->close();
  }
  for (auto &&p : pipes_) {
    p.second->close();
  }
  pipes_.clear();
}

bool TPReceiver::Wait(const std::string &addr, int num_sender, bool blocking) {
  // Start listening on `addr` and (optionally) block until `num_sender`
  // senders have completed the "dglconnect" handshake.
  if (listener_) {
    LOG(WARNING) << "TPReceiver::Wait() has been called already. Ignoring...";
    return true;
  }
  LOG(INFO) << "TPReceiver starts to wait on [" << addr << "].";
  listener_ = context->listen({addr});
  listener_->accept([this](const Error &error, std::shared_ptr<Pipe> pipe) {
    OnAccepted(error, pipe);
  });
  // Fix: the original loop busy-spun with an empty body, burning a full CPU
  // core while waiting; yield with a short sleep between polls instead.
  // NOTE(review): num_connected_ is incremented from tensorpipe callback
  // threads (see OnAccepted) and read here -- confirm it is declared atomic
  // in tp_communicator.h.
  while (blocking && (num_sender != num_connected_)) {
    usleep(1000);  // 1 ms between polls
  }
  return true;
}

// Callback invoked by the tensorpipe listener for every incoming connection.
// On success it re-arms the listener, validates the "dglconnect" handshake,
// registers the pipe, and starts the per-pipe receive loop.
void TPReceiver::OnAccepted(const Error &error, std::shared_ptr<Pipe> pipe) {
  if (error) {
    if (error.isOfType<ListenerClosedError>()) {
      // Expected.  Finalize() closes the listener, which surfaces this error.
    } else {
      LOG(WARNING) << "Unexpected error when accepting incoming pipe: " << error.what();
    }
    return;
  }

  // Accept the next connection request
  // (re-arm the listener so further senders can connect concurrently).
  listener_->accept([this](const Error &error, std::shared_ptr<Pipe> pipe) {
    OnAccepted(error, pipe);
  });

  // read the handshake message: "dglconnect"
  // The lambda captures `pipe` by value to keep it alive until the read fires.
  pipe->readDescriptor([pipe, this](const Error &error, Descriptor descriptor) {
    if (error) {
      LOG(WARNING) << "Unexpected error when reading from accepted pipe: " << error.what();
      return;
    }
    // The handshake carries metadata only, so complete the read with an
    // empty Allocation (no payload buffers to fill).
    Allocation allocation;
    pipe->read(allocation, [](const Error &error) {});
    CHECK(descriptor.metadata == "dglconnect") << "Invalid connect message.";
    // Register the pipe under the next sender index and start receiving
    // RPC messages on it.
    // NOTE(review): num_connected_ is also polled by Wait() -- confirm it is
    // atomic in the header.
    pipes_[num_connected_] = pipe;
    ReceiveFromPipe(pipe, queue_);
    ++num_connected_;
  });
}

void TPReceiver::ReceiveFromPipe(std::shared_ptr<Pipe> pipe,
                                 std::shared_ptr<RPCMessageQueue> queue) {
145
  pipe->readDescriptor([pipe, queue = std::move(queue)](const Error &error,
146
147
148
149
150
151
152
153
154
155
156
                                                        Descriptor descriptor) {
    if (error) {
      // Error may happen when the pipe is closed
      return;
    }
    Allocation allocation;
    CHECK_EQ(descriptor.payloads.size(), 0) << "Invalid DGL RPC Message";

    int tensorsize = descriptor.tensors.size();
    if (tensorsize > 0) {
      allocation.tensors.resize(tensorsize);
157
      for (size_t i = 0; i < descriptor.tensors.size(); i++) {
158
159
160
161
162
        tensorpipe::CpuBuffer cpu_buffer;
        cpu_buffer.ptr = new char[descriptor.tensors[i].length];
        allocation.tensors[i].buffer = cpu_buffer;
      }
    }
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
    pipe->read(allocation, [allocation, descriptor = std::move(descriptor),
                            queue = std::move(queue),
                            pipe](const Error &error) {
      if (error) {
        // Because we always have a read event posted to the epoll,
        // Therefore when pipe is closed, error will be raised.
        // But this error is expected.
        // Other error is not expected. But we cannot identify the error with
        // each Other for now. Thus here we skip handling for all errors
        return;
      }

      char *meta_msg_begin = const_cast<char *>(&descriptor.metadata[0]);
      std::vector<void *> buffer_list(descriptor.tensors.size());
      for (size_t i = 0; i < descriptor.tensors.size(); i++) {
        buffer_list[i] = allocation.tensors[i].buffer.unwrap<CpuBuffer>().ptr;
      }
      StreamWithBuffer zc_read_strm(
181
182
          meta_msg_begin, descriptor.metadata.size() - sizeof(int32_t),
          buffer_list);
183
184
185
186
187
      RPCMessage msg;
      zc_read_strm.Read(&msg);
      queue->push(msg);
      TPReceiver::ReceiveFromPipe(pipe, queue);
    });
188
189
190
  });
}

void TPReceiver::Recv(RPCMessage *msg) {
  // Blocking: pops the next deserialized RPC message off the shared queue.
  *msg = std::move(queue_->pop());
}

}  // namespace rpc
}  // namespace dgl