custom_all_reduce.cu 5.62 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>

#include "custom_all_reduce.cuh"

// Opaque handle type: an object address carried across the binding boundary
// as a plain 64-bit integer. The assertion guarantees the integer is exactly
// as wide as a native pointer, so the round trip loses no bits.
using fptr_t = uint64_t;
static_assert(sizeof(fptr_t) == sizeof(void *));

/**
 * Construct a vllm::CustomAllreduce and hand it back as an opaque handle.
 *
 * @param meta        Tensor whose data pointer is reinterpreted as a
 *                    vllm::Signal *.
 * @param rank_data   Tensor whose data pointer and element count are
 *                    forwarded to the constructor.
 * @param handles     One serialized cudaIpcMemHandle_t per rank. Each entry
 *                    must contain at least sizeof(cudaIpcMemHandle_t) bytes;
 *                    only that many leading bytes are copied.
 * @param offsets     One offset per rank; its length defines the world size.
 * @param rank        Index of the calling rank, in [0, world_size).
 * @param full_nvlink Forwarded unchanged to the constructor.
 * @return Address of the heap-allocated object as an fptr_t. The caller owns
 *         it and must eventually pass it to dispose().
 * @throws std::invalid_argument if the world size exceeds 8 or is odd, if
 *         handles and offsets differ in length, if rank is out of range, or
 *         if a handle is too short to hold a cudaIpcMemHandle_t.
 */
fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
                      const std::vector<std::string> &handles,
                      const std::vector<int64_t> &offsets, int rank,
                      bool full_nvlink) {
  // Range-check the unsigned size before narrowing it to int, so an
  // oversized vector can never wrap into an accepted value.
  if (offsets.size() > 8)
    throw std::invalid_argument("world size > 8 is not supported");
  const int world_size = static_cast<int>(offsets.size());
  if (world_size % 2 != 0)
    throw std::invalid_argument("Odd num gpus is not supported for now");
  // Compare the two unsigned sizes directly (no signed/unsigned mixing).
  if (handles.size() != offsets.size())
    throw std::invalid_argument(
        "handles length should equal to offsets length");
  if (rank < 0 || rank >= world_size)
    throw std::invalid_argument("invalid rank passed in");

  cudaIpcMemHandle_t ipc_handles[8];
  for (int i = 0; i < world_size; i++) {
    // The copy below reads a fixed number of bytes from the string; reject
    // shorter inputs instead of reading past the end of their storage.
    if (handles[i].size() < sizeof(cudaIpcMemHandle_t))
      throw std::invalid_argument(
          "ipc handle is too short to hold a cudaIpcMemHandle_t");
    std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t));
  }
  return reinterpret_cast<fptr_t>(new vllm::CustomAllreduce(
      reinterpret_cast<vllm::Signal *>(meta.data_ptr()), rank_data.data_ptr(),
      rank_data.numel(), ipc_handles, offsets, rank, full_nvlink));
}

/**
 * Check that tensor t's elements exactly fill the byte range that starts at
 * ((char *)t.data_ptr()) and spans t.numel() * t.element_size() bytes, i.e.
 * the range from t's storage offset to the end of its storage.
 *
 * This is a slightly weaker condition than t.is_contiguous(): it also accepts
 * a transpose of a contiguous slice (a slice along the first dimension).
 * The kernels receive no stride information and treat tensors as flat
 * buffers, which is why this property is required.
 *
 * Examples
 * A = torch.zeros(3, 3, 3)
 * 1. A: OK
 * 2. A[1:]: OK
 * 3. A.permute(2, 0, 1): OK
 * 4. A[1:].permute(2, 0, 1): OK
 * 5. A[None].expand(2, -1, -1, -1): Not OK
 * 6. A[:, 1:, 1:]: Not OK
 */
bool _is_weak_contiguous(torch::Tensor &t) {
  // Fast path: a contiguous tensor always qualifies.
  if (t.is_contiguous()) return true;

  // Otherwise the bytes remaining in the storage after the tensor's offset
  // must equal the bytes occupied by the tensor's elements.
  const auto elem_bytes = t.element_size();
  const auto bytes_after_offset =
      t.storage().nbytes() - t.storage_offset() * elem_bytes;
  return bytes_after_offset == t.numel() * elem_bytes;
}

/**
 * Decide whether the custom all-reduce path should handle `inp`.
 *
 * @param inp         Candidate input tensor.
 * @param max_size    Largest input size, in bytes, that is accepted.
 * @param world_size  Number of participating ranks.
 * @param full_nvlink Whether the caller reports full NVLink connectivity.
 * @return true only when the byte size of inp is a multiple of 16, inp passes
 *         _is_weak_contiguous, the configuration is either two ranks or
 *         full NVLink, and the byte size does not exceed max_size.
 */
bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
                      bool full_nvlink) {
  // Work in signed 64-bit so the comparison with max_size below never mixes
  // signed and unsigned operands (a negative max_size must admit nothing).
  const int64_t inp_size =
      static_cast<int64_t>(inp.numel() * inp.element_size());
  // custom allreduce requires input byte size to be multiples of 16
  if (inp_size % 16 != 0) return false;
  if (!_is_weak_contiguous(inp)) return false;
  if (world_size == 2 || full_nvlink)
    return inp_size <= static_cast<int64_t>(max_size);
  // for 4 or more non NVLink-capable GPUs, custom allreduce provides little
  // performance improvement over NCCL.
  return false;
}

/**
 * Dispatch an all-reduce on `stream` according to the element type of `out`.
 *
 * `_fa` is reinterpreted as a vllm::CustomAllreduce *. The data pointers of
 * `inp` and `out` are cast to the element type selected by
 * out.scalar_type(), and out.numel() is passed as the element count.
 * Fails a TORCH_CHECK if `out` does not satisfy _is_weak_contiguous, and
 * throws std::runtime_error for any element type other than the ones handled
 * below.
 */
void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
                 cudaStream_t stream) {
  TORCH_CHECK(_is_weak_contiguous(out));
  auto *reducer = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
  const auto dtype = out.scalar_type();

  if (dtype == at::ScalarType::Float) {
    auto *src = reinterpret_cast<float *>(inp.data_ptr());
    auto *dst = reinterpret_cast<float *>(out.data_ptr());
    reducer->allreduce<float>(stream, src, dst, out.numel());
    return;
  }
  if (dtype == at::ScalarType::Half) {
    auto *src = reinterpret_cast<half *>(inp.data_ptr());
    auto *dst = reinterpret_cast<half *>(out.data_ptr());
    reducer->allreduce<half>(stream, src, dst, out.numel());
    return;
  }
  // The bfloat16 branch is compiled only when __CUDA_ARCH__ is undefined or
  // is at least 800.
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
  if (dtype == at::ScalarType::BFloat16) {
    auto *src = reinterpret_cast<nv_bfloat16 *>(inp.data_ptr());
    auto *dst = reinterpret_cast<nv_bfloat16 *>(out.data_ptr());
    reducer->allreduce<nv_bfloat16>(stream, src, dst, out.numel());
    return;
  }
#endif
  throw std::runtime_error(
      "custom allreduce only supports float32, float16 and bfloat16");
}

/**
 * All-reduce `inp` into `out` on the current CUDA stream of inp's device,
 * passing `inp` to the reducer directly (no staging copy is made here).
 *
 * Fails a check if `inp` and `out` differ in scalar type or element count.
 */
void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) {
  // Switch to inp's device for the duration of this call.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  cudaStream_t current_stream = c10::cuda::getCurrentCUDAStream().stream();

  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());

  _all_reduce(_fa, inp, out, current_stream);
}

/**
 * All-reduce `inp` into `out` by first copying `inp` into `reg_buffer` and
 * then reducing from `reg_buffer`.
 *
 * Both the copy and the reduction are enqueued on the current CUDA stream of
 * inp's device. Fails a check if `inp` and `out` differ in scalar type or
 * element count, or if `reg_buffer` holds fewer bytes than `inp`.
 */
void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer,
                      torch::Tensor &out) {
  // Switch to inp's device for the duration of this call.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  cudaStream_t current_stream = c10::cuda::getCurrentCUDAStream().stream();

  const auto payload_bytes = inp.numel() * inp.element_size();
  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());

  const auto staging_bytes = reg_buffer.numel() * reg_buffer.element_size();
  TORCH_CHECK(payload_bytes <= staging_bytes,
              "registered buffer is too small to contain the input");

  // Stage the input in reg_buffer; the copy is asynchronous and ordered on
  // the same stream as the reduction that follows.
  AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(),
                                payload_bytes, cudaMemcpyDeviceToDevice,
                                current_stream));
  _all_reduce(_fa, reg_buffer, out, current_stream);
}

/**
 * Destroy the vllm::CustomAllreduce object that `_fa` refers to.
 * The handle must not be used again after this call.
 */
void dispose(fptr_t _fa) {
  delete reinterpret_cast<vllm::CustomAllreduce *>(_fa);
}

129
// Size in bytes of vllm::Signal, returned as an int.
int meta_size() {
  return static_cast<int>(sizeof(vllm::Signal));
}
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148

/**
 * Forward `handles`, `offsets` and the data pointer of `t` to
 * vllm::CustomAllreduce::register_buffer on the object behind `_fa`.
 */
void register_buffer(fptr_t _fa, torch::Tensor &t,
                     const std::vector<std::string> &handles,
                     const std::vector<int64_t> &offsets) {
  auto *reducer = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
  void *buffer = t.data_ptr();
  reducer->register_buffer(handles, offsets, buffer);
}

/**
 * Return the result of vllm::CustomAllreduce::get_graph_buffer_ipc_meta for
 * the object behind `_fa`: a pair of a byte vector and an int64 vector.
 */
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
    fptr_t _fa) {
  return reinterpret_cast<vllm::CustomAllreduce *>(_fa)
      ->get_graph_buffer_ipc_meta();
}

/**
 * Forward `handles` and the per-entry `offsets` lists to
 * vllm::CustomAllreduce::register_graph_buffers on the object behind `_fa`.
 */
void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
                            const std::vector<std::vector<int64_t>> &offsets) {
  auto *reducer = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
  reducer->register_graph_buffers(handles, offsets);
}