Unverified commit d25398cb, authored by Xiaoyu Zhang, committed by GitHub

fix custom_allreduce namespace (#6039)

parent 8a828666
```diff
@@ -18,11 +18,11 @@ init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs, torch::Tensor& rank_dat
   if (world_size % 2 != 0) throw std::invalid_argument("Odd num gpus is not supported for now");
   if (rank < 0 || rank >= world_size) throw std::invalid_argument("invalid rank passed in");
-  vllm::Signal* ipc_ptrs[8];
+  sglang::Signal* ipc_ptrs[8];
   for (int i = 0; i < world_size; i++) {
-    ipc_ptrs[i] = reinterpret_cast<vllm::Signal*>(fake_ipc_ptrs[i]);
+    ipc_ptrs[i] = reinterpret_cast<sglang::Signal*>(fake_ipc_ptrs[i]);
   }
-  return (fptr_t) new vllm::CustomAllreduce(
+  return (fptr_t) new sglang::CustomAllreduce(
       ipc_ptrs, rank_data.data_ptr(), rank_data.numel(), rank, world_size, full_nvlink);
 }
```
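For readers unfamiliar with the `fptr_t` idiom above: the constructed `CustomAllreduce` is handed back to Python as a plain integer handle. A minimal, self-contained sketch of this opaque-handle pattern (all names here are illustrative, not from the patch):

```cpp
#include <cstdint>

// The C++ object lives behind an integer (fptr_t) so it can cross the
// Python binding boundary as a plain int64.
using fptr_t = int64_t;

struct Widget {
  int state = 0;
};

fptr_t create_widget() {
  // Ownership transfers to the caller as an integer handle.
  return reinterpret_cast<fptr_t>(new Widget());
}

void use_widget(fptr_t handle) {
  auto* w = reinterpret_cast<Widget*>(handle);
  w->state += 1;
}

void destroy_widget(fptr_t handle) {
  // Mirrors dispose(): the handle is cast back and deleted.
  delete reinterpret_cast<Widget*>(handle);
}
```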
```diff
@@ -55,7 +55,7 @@ bool _is_weak_contiguous(torch::Tensor& t) {
  * copied into _reg_buffer.
  */
 void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
-  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
+  auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
   const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
   auto stream = c10::cuda::getCurrentCUDAStream().stream();
```
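The two context lines above are the standard guard-then-stream idiom in PyTorch extensions: pin the current device to the input tensor's device, then fetch the stream PyTorch is currently enqueuing work on. A hedged sketch of the same pattern in isolation (the helper name is hypothetical):

```cpp
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/torch.h>

// Hypothetical helper: the guard restores the previous device when it
// goes out of scope; using the current stream keeps our kernels ordered
// with the rest of the model's GPU work.
void launch_on_inputs_device(torch::Tensor& inp) {
  const at::cuda::OptionalCUDAGuard device_guard(at::device_of(inp));
  cudaStream_t stream = c10::cuda::getCurrentCUDAStream().stream();
  (void)stream;  // ... enqueue kernels on `stream` here ...
}
```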
```diff
@@ -98,15 +98,15 @@ void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, fptr_t _reg_
 }
 void dispose(fptr_t _fa) {
-  delete reinterpret_cast<vllm::CustomAllreduce*>(_fa);
+  delete reinterpret_cast<sglang::CustomAllreduce*>(_fa);
 }
 int64_t meta_size() {
-  return sizeof(vllm::Signal);
+  return sizeof(sglang::Signal);
 }
 void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs) {
-  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
+  auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
   TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_);
   void* ipc_ptrs[8];
   for (int i = 0; i < fake_ipc_ptrs.size(); i++) {
```
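A note on `meta_size()`: it exists because the Python side cannot see the C++ `Signal` type, so it asks the extension how many bytes to allocate for each rank's signal block. A minimal sketch of the idea, with a stand-in `Signal` layout (the real layout lives in custom_all_reduce.cuh and differs):

```cpp
#include <cstdint>

// Stand-in Signal layout, for illustration only.
struct Signal {
  uint32_t start[64][8];
  uint32_t end[64][8];
};

// Python allocates an IPC-shared buffer of exactly this many bytes per
// rank and passes its pointer back through init_custom_ar.
int64_t meta_size() { return sizeof(Signal); }
```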
```diff
@@ -117,7 +117,7 @@ void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs) {
 // Use vector<int64_t> to represent byte data for python binding compatibility.
 std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa) {
-  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
+  auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
   auto [handle, offsets] = fa->get_graph_buffer_ipc_meta();
   std::vector<int64_t> bytes(handle.begin(), handle.end());
   return std::make_tuple(bytes, offsets);
```
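The `std::vector<int64_t> bytes(handle.begin(), handle.end())` line widens each raw byte of the IPC handle to an int64, because a vector of int64 crosses the Python binding as a plain list of ints. The same step in isolation:

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Widen raw handle bytes to int64 for binding-friendly transport.
// (char may be signed, so bytes above 0x7F come out negative; narrowing
// back to char on the receiving side restores the original bytes.)
std::vector<int64_t> bytes_to_i64(const std::string& raw) {
  return std::vector<int64_t>(raw.begin(), raw.end());
}
```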
```diff
@@ -126,7 +126,7 @@ std::tuple<std::vector<int64_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta
 // Use vector<int64_t> to represent byte data for python binding compatibility.
 void register_graph_buffers(
     fptr_t _fa, const std::vector<std::vector<int64_t>>& handles, const std::vector<std::vector<int64_t>>& offsets) {
-  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
+  auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
   std::vector<std::string> bytes;
   bytes.reserve(handles.size());
   for (int i = 0; i < handles.size(); i++) {
```
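The loop body is collapsed in this diff, but given the declarations above it most plausibly performs the inverse narrowing, rebuilding each handle string from its int64 form. A hedged sketch of that inverse step, not the file's verbatim code:

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Hedged reconstruction: narrow each widened int64 back to a raw byte
// to recover the handle string expected by register_graph_buffers.
std::vector<std::string> rebuild_handles(
    const std::vector<std::vector<int64_t>>& handles) {
  std::vector<std::string> bytes;
  bytes.reserve(handles.size());
  for (const auto& h : handles) {
    std::string b;
    b.reserve(h.size());
    for (int64_t c : h) b.push_back(static_cast<char>(c));
    bytes.emplace_back(std::move(b));
  }
  return bytes;
}
```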
```diff
@@ -15,7 +15,7 @@
 #include "utils.h"
-namespace vllm {
+namespace sglang {
 constexpr int kMaxBlocks = 36;
 // Counter may overflow, but it's fine since unsigned int overflow is
```
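The truncated comment refers to the fact that unsigned overflow is well-defined in C++ (it wraps modulo 2^N), so a signal counter that eventually wraps cannot invoke undefined behavior. A one-file demonstration:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  uint32_t counter = 0xFFFFFFFFu;
  counter += 2;             // wraps to 1; defined behavior for unsigned types
  printf("%u\n", counter);  // prints 1
  return 0;
}
```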
```diff
@@ -483,7 +483,7 @@ class CustomAllreduce {
 /**
  * To inspect PTX/SASS, copy paste this header file to compiler explorer and add
  a template instantiation:
- * template void vllm::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
+ * template void sglang::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
  half *, int, int, int);
  */
-} // namespace vllm
+} // namespace sglang
```
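Spelled out, the instantiation that comment asks for looks like the snippet below. It is pasted after the header's contents in Compiler Explorer and only forces code generation; nothing runs:

```cpp
// Explicit instantiation so nvcc emits the templated kernel path,
// making its PTX/SASS visible in Compiler Explorer's output pane.
template void sglang::CustomAllreduce::allreduce<half>(
    cudaStream_t, half*, half*, int, int, int);
```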
```diff
@@ -29,8 +29,8 @@ fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
   for (int i = 0; i < world_size; i++) {
     std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(hipIpcMemHandle_t));
   }
-  return (fptr_t) new vllm::CustomAllreduce(
-      reinterpret_cast<vllm::Signal*>(meta.data_ptr()), rank_data.data_ptr(),
+  return (fptr_t) new sglang::CustomAllreduce(
+      reinterpret_cast<sglang::Signal*>(meta.data_ptr()), rank_data.data_ptr(),
       rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
 }
```
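The `memcpy` above only deserializes the handle bytes; the counterpart step (elsewhere in this file) is opening each handle to obtain a device pointer into the peer rank's allocation. A hedged sketch of that step using the standard HIP IPC API (the helper name is hypothetical):

```cpp
#include <hip/hip_runtime.h>

#include <cstring>
#include <string>

// Hedged sketch, not the file's actual code: turn serialized handle
// bytes back into a device pointer. hipIpcOpenMemHandle maps another
// process's allocation into this process's address space.
void* open_peer_buffer(const std::string& handle_bytes) {
  hipIpcMemHandle_t handle;
  std::memcpy(&handle, handle_bytes.data(), sizeof(hipIpcMemHandle_t));
  void* ptr = nullptr;
  // hipIpcMemLazyEnablePeerAccess enables peer access on first use.
  (void)hipIpcOpenMemHandle(&ptr, handle, hipIpcMemLazyEnablePeerAccess);
  return ptr;
}
```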
```diff
@@ -58,7 +58,7 @@ bool _is_weak_contiguous(torch::Tensor& t) {
 void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
                  hipStream_t stream) {
-  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
+  auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
   TORCH_CHECK(_is_weak_contiguous(out));
   switch (out.scalar_type()) {
     case at::ScalarType::Float: {
```
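The switch is truncated here, but it follows the standard dtype-dispatch pattern: map the runtime scalar type onto the templated `allreduce` entry point. A hedged sketch of the shape of that dispatch; the real file's cases and error handling may differ, and `allreduce`'s trailing thread/block parameters are assumed to be defaulted:

```cpp
// Illustrative dispatch skeleton, not verbatim from the file.
switch (out.scalar_type()) {
  case at::ScalarType::Float: {
    fa->allreduce<float>(stream, reinterpret_cast<float*>(inp.data_ptr()),
                         reinterpret_cast<float*>(out.data_ptr()), out.numel());
    break;
  }
  case at::ScalarType::Half: {
    fa->allreduce<half>(stream, reinterpret_cast<half*>(inp.data_ptr()),
                        reinterpret_cast<half*>(out.data_ptr()), out.numel());
    break;
  }
  default:
    throw std::runtime_error("custom allreduce only supports float-family dtypes");
}
```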
```diff
@@ -110,22 +110,22 @@ void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
 }
 void dispose(fptr_t _fa) {
-  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
+  auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
   delete fa;
 }
-int64_t meta_size() { return sizeof(vllm::Signal); }
+int64_t meta_size() { return sizeof(sglang::Signal); }
 void register_buffer(fptr_t _fa, torch::Tensor& t,
                      const std::vector<std::string>& handles,
                      const std::vector<int64_t>& offsets) {
-  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
+  auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
   fa->register_buffer(handles, offsets, t.data_ptr());
 }
 std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
     fptr_t _fa) {
-  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
+  auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
   auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta();
   auto options =
       torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
```
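The hunk ends just before the copy-out, which presumably packs the raw handle bytes into a CPU uint8 tensor built from those `options`, so Python can ship them between ranks. A hedged sketch of the likely continuation, not verbatim from the file:

```cpp
// Hedged continuation sketch: allocate the byte tensor and copy the
// serialized handles into it before returning them with the offsets.
auto handles_tensor =
    torch::empty({static_cast<int64_t>(handle_bytes.size())}, options);
std::memcpy(handles_tensor.data_ptr(), handle_bytes.data(), handle_bytes.size());
return std::make_tuple(handles_tensor, offsets);
```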
```diff
@@ -137,7 +137,7 @@ std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
 void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
                             const std::vector<std::vector<int64_t>>& offsets) {
-  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
+  auto fa = reinterpret_cast<sglang::CustomAllreduce*>(_fa);
   fa->register_graph_buffers(handles, offsets);
 }
```
```diff
@@ -26,7 +26,7 @@ typedef __hip_bfloat16 nv_bfloat16;
   }              \
   } while (0)
-namespace vllm {
+namespace sglang {
 constexpr int kMaxBlocks = 64;
 // note: we don't want to use atomics for signals because peer atomics are no
```
```diff
@@ -572,11 +572,11 @@ class CustomAllreduce {
       CUDACHECK(hipIpcCloseMemHandle(ptr));
     }
   }
-}; // namespace vllm
+}; // namespace sglang
 /**
  * To inspect PTX/SASS, copy paste this header file to compiler explorer and add
  a template instantiation:
- * template void vllm::CustomAllreduce::allreduce<half>(hipStream_t, half *,
+ * template void sglang::CustomAllreduce::allreduce<half>(hipStream_t, half *,
  half *, int, int, int);
  */
-} // namespace vllm
+} // namespace sglang
```