// custom_all_reduce.cu
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/all.h>

#include <cstdint>
#include <stdexcept>
#include <string>
#include <tuple>
#include <vector>

#include "custom_all_reduce.cuh"

// Opaque "fake pointer" type; it must stay identical to fptr_t in ops.h.
// Host pointers cross the Python binding boundary disguised as int64_t
// values, and spelling them as fptr_t marks every place where that happens.
using fptr_t = int64_t;

// The pointer <-> integer round trip is lossless only when both have the
// same width.
static_assert(sizeof(void*) == sizeof(fptr_t),
              "fptr_t must be exactly pointer-sized");
/**
 * Builds a vllm::CustomAllreduce communicator and returns it as an opaque
 * fptr_t handle (heap-allocated with `new`; release it with dispose()).
 *
 * @param fake_ipc_ptrs   one vllm::Signal pointer per rank, encoded as fptr_t;
 *                        its length is taken as the world size.
 * @param rank_data       tensor whose data_ptr() and numel() are forwarded to
 *                        the communicator.
 * @param rank            this process's rank, in [0, world size).
 * @param fully_connected forwarded unchanged to the CustomAllreduce
 *                        constructor (presumably "all GPUs directly linked" —
 *                        confirm against custom_all_reduce.cuh).
 * @throws std::invalid_argument if the world size exceeds 8, is odd, or the
 *         rank is out of range (checked in that order).
 */
fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
                      torch::Tensor& rank_data, int64_t rank,
                      bool fully_connected) {
  const int n_ranks = static_cast<int>(fake_ipc_ptrs.size());

  // Guard clauses: reject unsupported communicator shapes before decoding.
  if (n_ranks > 8) {
    throw std::invalid_argument("world size > 8 is not supported");
  }
  if (n_ranks % 2 != 0) {
    throw std::invalid_argument("Odd num gpus is not supported for now");
  }
  if (!(rank >= 0 && rank < n_ranks)) {
    throw std::invalid_argument("invalid rank passed in");
  }

  // Fixed-capacity array sized for the largest supported world size; only
  // the first n_ranks slots are populated.
  vllm::Signal* signals[8];
  for (int r = 0; r < n_ranks; ++r) {
    signals[r] = reinterpret_cast<vllm::Signal*>(fake_ipc_ptrs[r]);
  }

  auto* comm = new vllm::CustomAllreduce(signals, rank_data.data_ptr(),
                                         rank_data.numel(), rank, n_ranks,
                                         fully_connected);
  return reinterpret_cast<fptr_t>(comm);
}

/**
 * Make sure tensor t's data lies completely within ((char)t.data_ptr()) +
 * t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous()
 * because it allows transpose of contiguous slice (i.e. slicing the first
 * dimension). Currently, we require this because stride information is not
 * passed into the kernels and we treat input tensors as flat.
 *
 * Examples
 * A = torch.zeros(3, 3, 3)
 * 1. A: OK
 * 2. A[1:]: OK
 * 3. A.permute(2, 0, 1): OK
 * 4. A[1:].permute(2, 0, 1): OK
 * 5. A[None].expand(2, -1, -1, -1): Not OK
 * 6. A[:, 1:, 1:]: Not OK
 */
49
bool _is_weak_contiguous(torch::Tensor& t) {
50
51
52
53
54
  return t.is_contiguous() ||
         (t.storage().nbytes() - t.storage_offset() * t.element_size() ==
          t.numel() * t.element_size());
}

55
56
57
58
59
60
61
62
63
/**
 * Performs an out-of-place allreduce and stores result in out.
 *
 * If _reg_buffer is null, assumes inp.data_ptr() is already IPC-registered.
 * Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
 * copied into _reg_buffer.
 */
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
                fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
64
  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
65
66
67
68
69
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();

  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
70
  TORCH_CHECK(_is_weak_contiguous(out));
71
72
73
74
75
76
77
78
79
80
  TORCH_CHECK(_is_weak_contiguous(inp));
  auto input_size = inp.numel() * inp.element_size();
  auto reg_buffer = reinterpret_cast<void*>(_reg_buffer);
  if (reg_buffer) {
    TORCH_CHECK_LE(input_size, reg_buffer_sz_bytes);
    AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.data_ptr(), input_size,
                                  cudaMemcpyDeviceToDevice, stream));
  } else {
    reg_buffer = inp.data_ptr();
  }
81
82
  switch (out.scalar_type()) {
    case at::ScalarType::Float: {
83
      fa->allreduce<float>(stream, reinterpret_cast<float*>(reg_buffer),
84
                           reinterpret_cast<float*>(out.data_ptr()),
85
86
87
88
                           out.numel());
      break;
    }
    case at::ScalarType::Half: {
89
      fa->allreduce<half>(stream, reinterpret_cast<half*>(reg_buffer),
90
                          reinterpret_cast<half*>(out.data_ptr()), out.numel());
91
92
93
94
95
      break;
    }
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
    case at::ScalarType::BFloat16: {
      fa->allreduce<nv_bfloat16>(
96
          stream, reinterpret_cast<nv_bfloat16*>(reg_buffer),
97
          reinterpret_cast<nv_bfloat16*>(out.data_ptr()), out.numel());
98
99
100
101
102
103
104
105
106
107
      break;
    }
#endif
    default:
      throw std::runtime_error(
          "custom allreduce only supports float32, float16 and bfloat16");
  }
}

// Destroys a communicator previously created by init_custom_ar (which
// allocates it with `new`). A zero handle is harmless: deleting a null
// pointer does nothing.
void dispose(fptr_t _fa) {
  auto* comm = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  delete comm;
}
// Returns sizeof(vllm::Signal) in bytes. Presumably used by the caller to size
// the per-rank signal/metadata allocation — confirm against the Python side.
int64_t meta_size() {
  constexpr int64_t kSignalBytes = sizeof(vllm::Signal);
  return kSignalBytes;
}
/**
 * Registers one IPC-mapped buffer per rank with the communicator.
 *
 * @param _fa           communicator handle returned by init_custom_ar.
 * @param fake_ipc_ptrs per-rank buffer pointers encoded as fptr_t; must
 *                      contain exactly world_size_ entries.
 */
void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_);

  // world_size_ never exceeds 8 (enforced by init_custom_ar), so this
  // fixed-size array always has room for every entry.
  void* peer_ptrs[8];
  size_t slot = 0;
  for (const fptr_t raw : fake_ipc_ptrs) {
    peer_ptrs[slot++] = reinterpret_cast<void*>(raw);
  }
  fa->register_buffer(peer_ptrs);
}
// Use vector<int64_t> to represent byte data for python binding compatibility.
// Returns (handle bytes, offsets) as reported by
// CustomAllreduce::get_graph_buffer_ipc_meta. Each handle byte is widened to
// one int64_t element; offsets are passed through unchanged.
std::tuple<std::vector<int64_t>, std::vector<int64_t>>
get_graph_buffer_ipc_meta(fptr_t _fa) {
  auto* comm = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  auto [handle, offsets] = comm->get_graph_buffer_ipc_meta();

  std::vector<int64_t> handle_bytes;
  handle_bytes.assign(handle.begin(), handle.end());
  return std::make_tuple(handle_bytes, offsets);
}
// Use vector<int64_t> to represent byte data for python binding compatibility.
/**
 * Registers every rank's graph-buffer IPC metadata with the communicator.
 *
 * @param _fa     communicator handle returned by init_custom_ar.
 * @param handles one byte blob per rank, each byte stored as an int64_t
 *                (presumably the handle half of get_graph_buffer_ipc_meta,
 *                gathered from all ranks — confirm against the caller).
 * @param offsets one offset list per rank, passed through unchanged.
 *
 * Both outer vectors must hold exactly one entry per rank.
 */
void register_graph_buffers(fptr_t _fa,
                            const std::vector<std::vector<int64_t>>& handles,
                            const std::vector<std::vector<int64_t>>& offsets) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  // Mirrors register_buffer's world-size check so malformed input fails here
  // with a clear message.
  TORCH_CHECK(handles.size() == offsets.size(),
              "register_graph_buffers: got ", handles.size(),
              " handle entries but ", offsets.size(), " offset entries");
  TORCH_CHECK(handles.size() == static_cast<size_t>(fa->world_size_),
              "register_graph_buffers: expected ", fa->world_size_,
              " entries (one per rank), got ", handles.size());

  // Narrow each rank's int64 byte list back into the std::string form the
  // communicator takes.
  std::vector<std::string> bytes;
  bytes.reserve(handles.size());
  for (const auto& handle : handles) {
    bytes.emplace_back(handle.begin(), handle.end());
  }
  fa->register_graph_buffers(bytes, offsets);
}

std::tuple<fptr_t, torch::Tensor> allocate_shared_buffer_and_handle(
147
148
149
150
151
152
153
154
155
    int64_t size) {
  auto device_index = c10::cuda::current_device();
  at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index));
  void* buffer;
  cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
  auto stream = c10::cuda::getCurrentCUDAStream().stream();
  AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));

  // Allocate buffer
zhuwenwen's avatar
zhuwenwen committed
156
#if defined(USE_ROCM)
157
158
159
  // data buffers need to be "uncached" for signal on MI200
  AT_CUDA_CHECK(
      hipExtMallocWithFlags((void**)&buffer, size, hipDeviceMallocUncached));
zhuwenwen's avatar
zhuwenwen committed
160
#else
161
  AT_CUDA_CHECK(cudaMalloc((void**)&buffer, size));
zhuwenwen's avatar
zhuwenwen committed
162
#endif
163
164
165
166
167
168
169
170
171
172
173
174
175
176
  AT_CUDA_CHECK(cudaMemsetAsync(buffer, 0, size, stream));
  AT_CUDA_CHECK(cudaStreamSynchronize(stream));
  AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));

  // Create IPC memhandle for the allocated buffer.
  // Will use it in open_mem_handle.
  auto options =
      torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
  auto handle =
      torch::empty({static_cast<int64_t>(sizeof(cudaIpcMemHandle_t))}, options);
  AT_CUDA_CHECK(
      cudaIpcGetMemHandle((cudaIpcMemHandle_t*)handle.data_ptr(), buffer));

  return std::make_tuple(reinterpret_cast<fptr_t>(buffer), handle);
zhuwenwen's avatar
zhuwenwen committed
177
}
178

zhuwenwen's avatar
zhuwenwen committed
179
180
181
182
183
184
185
186
187
188
/**
 * Opens a CUDA IPC memory handle exported by another process and returns the
 * mapped device pointer encoded as fptr_t. Peer access is enabled lazily
 * (cudaIpcMemLazyEnablePeerAccess).
 *
 * @param mem_handle CPU, contiguous tensor containing at least
 *                   sizeof(cudaIpcMemHandle_t) bytes of raw handle data, such
 *                   as the tensor returned by allocate_shared_buffer_and_handle.
 */
fptr_t open_mem_handle(torch::Tensor& mem_handle) {
  // The handle bytes are read directly on the host, so the tensor must live
  // in CPU memory, be contiguous, and be large enough to hold a full handle.
  TORCH_CHECK(mem_handle.is_cpu(), "IPC memory handle must be a CPU tensor");
  TORCH_CHECK(mem_handle.is_contiguous(),
              "IPC memory handle tensor must be contiguous");
  TORCH_CHECK(mem_handle.nbytes() >= sizeof(cudaIpcMemHandle_t),
              "IPC memory handle tensor too small: expected at least ",
              sizeof(cudaIpcMemHandle_t), " bytes, got ", mem_handle.nbytes());

  void* ipc_ptr = nullptr;
  AT_CUDA_CHECK(cudaIpcOpenMemHandle(
      &ipc_ptr,
      *static_cast<const cudaIpcMemHandle_t*>(mem_handle.data_ptr()),
      cudaIpcMemLazyEnablePeerAccess));
  return reinterpret_cast<fptr_t>(ipc_ptr);
}

// Frees device memory whose address arrives encoded as an fptr_t; it is the
// counterpart of allocate_shared_buffer_and_handle's allocation.
// NOTE(review): pointers obtained from open_mem_handle are IPC mappings and
// would normally be released with cudaIpcCloseMemHandle, not this function.
void free_shared_buffer(fptr_t buffer) {
  void* const device_ptr = reinterpret_cast<void*>(buffer);
  AT_CUDA_CHECK(cudaFree(device_ptr));
}