// mscclpp_allreduce.cu
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/all.h>
#include <torch/library.h>

#include "mscclpp_allreduce.cuh"

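// Selects which mscclpp allreduce implementation a context drives: the
// single-node LL (low-latency) kernel or the two-node LL kernel.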
enum MscclContextSelection {
  MSCCL1NODELL = 1,
  MSCCL2NODELL = 2,
};

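// Owns exactly one of the two context types and dispatches allreduce calls to
// whichever one was selected at construction time.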
class MscclContext {
 public:
  MscclContextSelection selection_;
  std::shared_ptr<sglang::Msccl1NodeLLcontext> msccl_1nodeLL_context;
  std::shared_ptr<sglang::Msccl2NodeLLcontext> msccl_2nodeLL_context;
  MscclContext(MscclContextSelection selection) : selection_(selection) {}
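  // Forward the call to whichever context is active; threads / block_limit go
  // straight through to the underlying kernel launch.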
  template <typename T>
  void allreduce(
      cudaStream_t stream, T* input, T* output, const size_t input_numel, int threads = 512, int block_limit = 21) {
    if (selection_ == MSCCL1NODELL) {
      msccl_1nodeLL_context->allreduce<T>(stream, input, output, input_numel, threads, block_limit);
    } else if (selection_ == MSCCL2NODELL) {
      msccl_2nodeLL_context->allreduce<T>(stream, input, output, input_numel, threads, block_limit);
    }
  }
};

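// Contexts are handed to the caller as opaque int64 handles, so a pointer
// must fit into an int64_t.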
using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));

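// Copy an mscclpp::UniqueId into a CPU byte tensor so it can be exchanged
// between processes.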
torch::Tensor _unique_id2tensor(const mscclpp::UniqueId& unique_id) {
  auto options = torch::TensorOptions().dtype(torch::kByte).device(torch::kCPU);
  auto tensor = torch::empty({static_cast<int64_t>(unique_id.size())}, options);
  std::memcpy(tensor.data_ptr<uint8_t>(), unique_id.data(), unique_id.size());
  return tensor;
}

// Deserialize a byte tensor produced by _unique_id2tensor back into an mscclpp::UniqueId.
mscclpp::UniqueId _tensor2unique_id(const torch::Tensor& tensor) {
  mscclpp::UniqueId unique_id;
  std::memcpy(unique_id.data(), tensor.data_ptr<uint8_t>(), unique_id.size());
  return unique_id;
}

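// Generate a fresh TcpBootstrap unique id (typically on a single rank) and
// return it serialized as a byte tensor so it can be shared with the others.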
torch::Tensor mscclpp_generate_unique_id() {
  mscclpp::UniqueId unique_id = mscclpp::TcpBootstrap::createUniqueId();
  return _unique_id2tensor(unique_id);
}

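// Construct the selected context from the shared unique id and the
// caller-provided scratch / put buffers, and return it as an opaque int64
// handle. The put buffer is only needed by the two-node variant.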
fptr_t mscclpp_init_context(
    const torch::Tensor& unique_id,
    const int64_t rank,
    const int64_t world_size,
    torch::Tensor& scratch,
    torch::Tensor& put_buffer,
    const int64_t nranks_per_node,
    const std::vector<int64_t>& rank_to_node,
    const std::vector<int64_t>& rank_to_ib,
    const int64_t context_selection) {
  MscclContext* context_ptr = new MscclContext(static_cast<MscclContextSelection>(context_selection));
  mscclpp::UniqueId uid = _tensor2unique_id(unique_id);
  void* scratch_ptr = reinterpret_cast<void*>(scratch.data_ptr());
  const size_t scratch_bytes = scratch.numel() * scratch.element_size();
  if (context_selection == MSCCL1NODELL) {
    context_ptr->msccl_1nodeLL_context = std::make_shared<sglang::Msccl1NodeLLcontext>(
        uid, rank, world_size, scratch_ptr, scratch_bytes, nranks_per_node, rank_to_node, rank_to_ib);
  } else if (context_selection == MSCCL2NODELL) {
    void* put_buffer_ptr = reinterpret_cast<void*>(put_buffer.data_ptr());
    const size_t put_buffer_bytes = put_buffer.numel() * put_buffer.element_size();
    context_ptr->msccl_2nodeLL_context = std::make_shared<sglang::Msccl2NodeLLcontext>(
        uid,
        rank,
        world_size,
        scratch_ptr,
        scratch_bytes,
        put_buffer_ptr,
        put_buffer_bytes,
        nranks_per_node,
        rank_to_node,
        rank_to_ib);
  } else {
    throw std::runtime_error("invalid context selection");
  }
  return reinterpret_cast<fptr_t>(context_ptr);
}

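// "Weakly" contiguous: either truly contiguous, or the tensor's data runs
// from its storage offset to the end of the storage with no trailing slack.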
bool _mscclpp_is_weak_contiguous(torch::Tensor& t) {
  return t.is_contiguous() ||
         (t.storage().nbytes() - t.storage_offset() * t.element_size() == t.numel() * t.element_size());
}
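
// Run an allreduce on the current CUDA stream for the given context handle,
// dispatching on the tensor dtype; nthreads / nblocks tune the kernel launch.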
void mscclpp_allreduce(fptr_t _context, torch::Tensor& inp, torch::Tensor& out, int64_t nthreads, int64_t nblocks) {
  MscclContext* context = reinterpret_cast<MscclContext*>(_context);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();

  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  TORCH_CHECK(_mscclpp_is_weak_contiguous(out));
  TORCH_CHECK(_mscclpp_is_weak_contiguous(inp));
  switch (out.scalar_type()) {
    case at::ScalarType::Float: {
      context->allreduce<float>(
          stream,
          reinterpret_cast<float*>(inp.data_ptr()),
          reinterpret_cast<float*>(out.data_ptr()),
          inp.numel(),
          nthreads,
          nblocks);
      break;
    }
    case at::ScalarType::Half: {
      context->allreduce<half>(
          stream,
          reinterpret_cast<half*>(inp.data_ptr()),
          reinterpret_cast<half*>(out.data_ptr()),
          inp.numel(),
          nthreads,
          nblocks);
      break;
    }
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
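    // bfloat16 support is only compiled in for SM80+ device passes; the host
    // pass (no __CUDA_ARCH__) always keeps this branch.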
    case at::ScalarType::BFloat16: {
      context->allreduce<__nv_bfloat16>(
          stream,
          reinterpret_cast<__nv_bfloat16*>(inp.data_ptr()),
          reinterpret_cast<__nv_bfloat16*>(out.data_ptr()),
          inp.numel(),
          nthreads,
          nblocks);
      break;
    }
#endif
    default:
      throw std::runtime_error("custom allreduce only supports float32, float16 and bfloat16");
  }
}