#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/all.h>

#include "custom_all_reduce.cuh"

// Fake pointer type, must match fptr_t type in ops.h.
// We use this type alias to indicate when pointers are passed in as int64_t.
using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));

/**
 * Creates a vllm::CustomAllreduce and returns it as an opaque fptr_t handle.
 *
 * @param fake_ipc_ptrs   One vllm::Signal pointer per rank, passed as fptr_t;
 *                        its length is the world size.
 * @param rank_data       Tensor whose data_ptr() and numel() are handed to the
 *                        CustomAllreduce constructor.
 * @param rank            This process's rank; must lie in [0, world_size).
 * @param fully_connected Forwarded unchanged to the CustomAllreduce constructor.
 * @return The object allocated with `new`, cast to fptr_t. The caller owns it
 *         (presumably released through dispose()).
 * @throws std::invalid_argument if the world size exceeds 16, is odd, or the
 *         rank is out of range.
 */
fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
                      torch::Tensor& rank_data, int64_t rank,
                      bool fully_connected) {
  // Largest supported world size. It is also the capacity of ipc_ptrs below,
  // so the check, its message and the array cannot drift apart again.
  constexpr int kMaxWorldSize = 16;
  // Check the size before narrowing it to int.
  if (fake_ipc_ptrs.size() > static_cast<size_t>(kMaxWorldSize))
    throw std::invalid_argument("world size > 16 is not supported");
  int world_size = static_cast<int>(fake_ipc_ptrs.size());
  if (world_size % 2 != 0)
    throw std::invalid_argument("Odd num gpus is not supported for now");
  if (rank < 0 || rank >= world_size)
    throw std::invalid_argument("invalid rank passed in");

  vllm::Signal* ipc_ptrs[kMaxWorldSize];
  for (int i = 0; i < world_size; i++) {
    ipc_ptrs[i] = reinterpret_cast<vllm::Signal*>(fake_ipc_ptrs[i]);
  }
  return (fptr_t) new vllm::CustomAllreduce(ipc_ptrs, rank_data.data_ptr(),
                                            rank_data.numel(), rank, world_size,
                                            fully_connected);
}

/**
 * Make sure tensor t's data lies completely within ((char)t.data_ptr()) +
 * t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous()
 * because it allows transpose of contiguous slice (i.e. slicing the first
 * dimension). Currently, we require this because stride information is not
 * passed into the kernels and we treat input tensors as flat.
 *
 * Examples
 * A = torch.zeros(3, 3, 3)
 * 1. A: OK
 * 2. A[1:]: OK
 * 3. A.permute(2, 0, 1): OK
 * 4. A[1:].permute(2, 0, 1): OK
 * 5. A[None].expand(2, -1, -1, -1): Not OK
 * 6. A[:, 1:, 1:]: Not OK
 */
bool _is_weak_contiguous(torch::Tensor& t) {
  return t.is_contiguous() ||
         (t.storage().nbytes() - t.storage_offset() * t.element_size() ==
          t.numel() * t.element_size());
}

/**
 * Performs an out-of-place allreduce and stores result in out.
 *
 * If _reg_buffer is null, assumes inp.data_ptr() is already IPC-registered.
 * Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
 * copied into _reg_buffer.
*/
// NOTE(review): the `<...>` template argument lists appear to be missing from
// this copy of the file. Examples: `reinterpret_cast(_fa)`, `data_ptr()`,
// `std::vector&`, `std::optional residual`, `std::tuple, std::vector>`, and the
// bare `template` before allreduce_fuse_norm_quant_dispath. As a result the
// code below does not compile as shown. They are deliberately left untouched
// here; restore them from version control.

// all_reduce_fuse_norm follows the reg_buffer contract described above.
// Besides the input and output buffers, it forwards the following to
// fa->allreduce_fuse_norm:
//   - out.numel()
//   - token_num = inp.numel() / hidden_size
//   - hidden_size
//   - residual's and rms_weight's data pointers
//   - eps, narrowed to float
// The call dispatches on out.scalar_type().
// Checked preconditions: inp and out have equal dtype and numel, and inp, out,
// residual and rms_weight are weakly contiguous (see _is_weak_contiguous).
// NOTE(review): hidden_size is not validated. A value of 0 divides by zero,
// and a numel that is not a multiple of hidden_size silently truncates
// token_num, which is also narrowed to int.
// NOTE(review): unlike all_reduce and allreduce_fuse_norm_quant_dispath,
// fa->fully_connected_ is not consulted. The dtypes and sizes of residual and
// rms_weight are not checked against out.
void all_reduce_fuse_norm(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
                          int64_t hidden_size, torch::Tensor& residual,
                          torch::Tensor& rms_weight, double eps,
                          fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
  auto fa = reinterpret_cast(_fa);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();

  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  TORCH_CHECK(_is_weak_contiguous(out));
  TORCH_CHECK(_is_weak_contiguous(inp));
  TORCH_CHECK(_is_weak_contiguous(residual));
  TORCH_CHECK(_is_weak_contiguous(rms_weight));
  int token_num = inp.numel() / hidden_size;
  auto input_size = inp.numel() * inp.element_size();
  auto reg_buffer = reinterpret_cast(_reg_buffer);
  if (reg_buffer) {
    // Stage inp into the caller-provided registered buffer on this stream.
    TORCH_CHECK_LE(input_size, reg_buffer_sz_bytes);
    AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.data_ptr(), input_size,
                                  cudaMemcpyDeviceToDevice, stream));
  } else {
    // No staging buffer given: inp's own storage is used directly.
    reg_buffer = inp.data_ptr();
  }
  switch (out.scalar_type()) {
    case at::ScalarType::Float: {
      fa->allreduce_fuse_norm(stream, reinterpret_cast(reg_buffer),
                              reinterpret_cast(out.data_ptr()),out.numel(),
                              token_num, hidden_size,
                              reinterpret_cast(residual.data_ptr()),
                              reinterpret_cast(rms_weight.data_ptr()),
                              (float)eps);
      break;
    }
    case at::ScalarType::Half: {
      fa->allreduce_fuse_norm(stream, reinterpret_cast(reg_buffer),
                              reinterpret_cast(out.data_ptr()),out.numel(),
                              token_num, hidden_size,
                              reinterpret_cast(residual.data_ptr()),
                              reinterpret_cast(rms_weight.data_ptr()),
                              (float)eps);
      break;
    }
    // NOTE(review): the arch guard below is commented out, so the bfloat16
    // case is always compiled.
    // #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
    case at::ScalarType::BFloat16: {
      fa->allreduce_fuse_norm(stream, reinterpret_cast(reg_buffer),
                              reinterpret_cast(out.data_ptr()),out.numel(),
                              token_num, hidden_size,
                              reinterpret_cast(residual.data_ptr()),
                              reinterpret_cast(rms_weight.data_ptr()),
                              (float)eps);
      break;
    }
    // #endif
    default:
      throw std::runtime_error(
          "custom allreduce only supports float32, float16 and bfloat16");
  }
}

// Shared implementation behind all_reduce_fuse_norm_quant.
// Its template parameters are not visible in this copy. The caller
// instantiates it per input floating dtype (scalar_t) and, presumably, per
// update_input value.
// Steps:
//   - Stages inp into _reg_buffer when that is non-null; otherwise inp's
//     storage is used as-is.
//   - Requires rms_weight.data_ptr() to be 16-byte aligned.
//   - Dispatches on out.scalar_type() via VLLM_DISPATCH_QUANT_TYPES and calls
//     fa->allreduce_fuse_norm_quant. It passes residual's data pointer, or
//     nullptr when no residual is given.
// Throws std::runtime_error on a misaligned weight pointer or when fa is not
// fully connected.
// NOTE(review): hidden_size is taken as int, although the public entry point
// receives int64_t.
// NOTE(review): the staging copy is enqueued before the alignment and
// fully_connected_ checks. If either check throws, that copy is already queued.
template
void allreduce_fuse_norm_quant_dispath(fptr_t _fa, torch::Tensor& inp,
                                       torch::Tensor& out, int hidden_size,torch::Tensor& rms_weight,
                                       double eps, torch::Tensor& scales,
                                       torch::Tensor& norm_out,
                                       fptr_t _reg_buffer,
                                       int64_t reg_buffer_sz_bytes,
                                       std::optional residual) {
  auto fa = reinterpret_cast(_fa);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();
  TORCH_CHECK(_is_weak_contiguous(inp));
  int token_num = inp.numel() / hidden_size;
  auto input_size = inp.numel() * inp.element_size();
  auto reg_buffer = reinterpret_cast(_reg_buffer);
  if (reg_buffer) {
    TORCH_CHECK_LE(input_size, reg_buffer_sz_bytes);
    AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.data_ptr(), input_size,
                                  cudaMemcpyDeviceToDevice, stream));
  } else {
    reg_buffer = inp.data_ptr();
  }
  // The rms_weight address is checked for 16-byte alignment before launching.
  auto wt_ptr = reinterpret_cast(rms_weight.data_ptr());
  if (wt_ptr % 16 != 0) {
    throw std::runtime_error(
        "custom allreduce currently requires wt_ptr % 16 "
        "of " +
        std::to_string(wt_ptr % 16));
  }
  if (fa->fully_connected_) {
    if (residual.has_value()) {
      VLLM_DISPATCH_QUANT_TYPES(
          out.scalar_type(), "fa->allreduce_fuse_norm_quant", [&] {
            fa->allreduce_fuse_norm_quant
            (stream, reinterpret_cast(reg_buffer), out.data_ptr(),
             out.numel(), token_num, hidden_size, residual->data_ptr(),
             rms_weight.data_ptr(), norm_out.data_ptr(), eps,
             scales.data_ptr());
          });
    } else {
      // No residual: nullptr is passed in the residual slot.
      VLLM_DISPATCH_QUANT_TYPES(
          out.scalar_type(), "fa->allreduce_fuse_norm_quant", [&] {
            fa->allreduce_fuse_norm_quant
            (stream, reinterpret_cast(reg_buffer), out.data_ptr(),
             out.numel(), token_num, hidden_size, nullptr,
             rms_weight.data_ptr(), norm_out.data_ptr(), eps,
             scales.data_ptr());
          });
    }
  } else {
    throw std::runtime_error(
        "custom allreduce only supports fully_connected");
  }
}

// Public entry point for the fused allreduce/norm/quantize path, as the name
// and arguments suggest.
// Requirements: out is Float8_e4m3fn or int8, scales is float32, and inp and
// out have equal numel and are both contiguous.
// It then dispatches on inp's floating dtype to
// allreduce_fuse_norm_quant_dispath.
// NOTE(review): the two branches of the update_input test look identical here
// because their template arguments are missing. Presumably they differ only
// in the instantiation's update_input flag.
void all_reduce_fuse_norm_quant(fptr_t fa, torch::Tensor& inp,
                                torch::Tensor& out, int64_t hidden_size,torch::Tensor& rms_weight,
                                double eps, torch::Tensor& scales,
                                torch::Tensor& norm_out, fptr_t reg_buffer,
                                int64_t reg_buffer_sz_bytes,
                                std::optional residual, bool update_input) {
  static c10::ScalarType kFp8Type = c10::ScalarType::Float8_e4m3fn;
  TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
  TORCH_CHECK(scales.dtype() == torch::kFloat32);
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  TORCH_CHECK(out.is_contiguous() && inp.is_contiguous());
  VLLM_DISPATCH_FLOATING_TYPES(
      inp.scalar_type(), "allreduce_fuse_norm_quant_dispath", [&] {
        if (update_input)
          allreduce_fuse_norm_quant_dispath(
              fa, inp, out, hidden_size, rms_weight, eps, scales, norm_out,
              reg_buffer, reg_buffer_sz_bytes, residual);
        else
          allreduce_fuse_norm_quant_dispath(
              fa, inp, out, hidden_size, rms_weight, eps, scales, norm_out,
              reg_buffer, reg_buffer_sz_bytes, residual);
      });
}

// Out-of-place allreduce of inp into out. The reg_buffer contract is given in
// the comment block that precedes all_reduce_fuse_norm.
// Uses fa->allreduce when fa->fully_connected_ is set, and
// fa->allreduce_pcie otherwise. Supports float32, float16 and bfloat16 and
// throws std::runtime_error for other dtypes.
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
                fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) {
  auto fa = reinterpret_cast(_fa);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();

  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  TORCH_CHECK(_is_weak_contiguous(out));
  TORCH_CHECK(_is_weak_contiguous(inp));
  auto input_size = inp.numel() * inp.element_size();
  auto reg_buffer = reinterpret_cast(_reg_buffer);
  if (reg_buffer) {
    // Stage inp into the caller-provided registered buffer on this stream.
    TORCH_CHECK_LE(input_size, reg_buffer_sz_bytes);
    AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.data_ptr(), input_size,
                                  cudaMemcpyDeviceToDevice, stream));
  } else {
    reg_buffer = inp.data_ptr();
  }
  if (fa->fully_connected_) {
    switch (out.scalar_type()) {
      case at::ScalarType::Float: {
        fa->allreduce(stream, reinterpret_cast(reg_buffer),
                      reinterpret_cast(out.data_ptr()), out.numel());
        break;
      }
      case at::ScalarType::Half: {
        fa->allreduce(stream, reinterpret_cast(reg_buffer),
                      reinterpret_cast(out.data_ptr()), out.numel());
        break;
      }
      // NOTE(review): arch guard commented out; bfloat16 is always compiled.
      // #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
      case at::ScalarType::BFloat16: {
        fa->allreduce(
            stream, reinterpret_cast(reg_buffer),
            reinterpret_cast(out.data_ptr()), out.numel());
        break;
      }
      // #endif
      default:
        throw std::runtime_error(
            "custom allreduce only supports float32, float16 and bfloat16");
    }
  } else {
    // Not fully connected: use the allreduce_pcie path instead.
    switch (out.scalar_type()) {
      case at::ScalarType::Float: {
        fa->allreduce_pcie(stream, reinterpret_cast(reg_buffer),
                           reinterpret_cast(out.data_ptr()), out.numel());
        break;
      }
      case at::ScalarType::Half: {
        fa->allreduce_pcie(stream, reinterpret_cast(reg_buffer),
                           reinterpret_cast(out.data_ptr()), out.numel());
        break;
      }
      // #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
      case at::ScalarType::BFloat16: {
        fa->allreduce_pcie(
            stream, reinterpret_cast(reg_buffer),
            reinterpret_cast(out.data_ptr()), out.numel());
        break;
      }
      // #endif
      default:
        throw std::runtime_error(
            "custom allreduce only supports float32, float16 and bfloat16");
    }
  }
}

// Deletes the object behind an fptr_t handle. Presumably this is the handle
// init_custom_ar allocated with `new`.
void dispose(fptr_t _fa) { delete reinterpret_cast(_fa); }

// Size in bytes of vllm::Signal.
// Presumably used by callers to size the per-rank signal buffers; TODO confirm.
int64_t meta_size() { return sizeof(vllm::Signal); }

// Registers one IPC pointer per rank with fa.
// Exactly fa->world_size_ pointers must be passed. The local array holds at
// most 16 entries, matching init_custom_ar's world-size limit.
void register_buffer(fptr_t _fa, const std::vector& fake_ipc_ptrs) {
  auto fa = reinterpret_cast(_fa);
  TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_);
  void* ipc_ptrs[16];
  for (int i = 0; i < fake_ipc_ptrs.size(); i++) {
    ipc_ptrs[i] = reinterpret_cast(fake_ipc_ptrs[i]);
  }
  fa->register_buffer(ipc_ptrs);
}

// Returns (handle, offsets) from fa->get_graph_buffer_ipc_meta(). The handle
// is copied element by element into an integer vector.
// Use vector to represent byte data for python binding compatibility.
std::tuple, std::vector> get_graph_buffer_ipc_meta(fptr_t _fa) {
  auto fa = reinterpret_cast(_fa);
  auto [handle, offsets] = fa->get_graph_buffer_ipc_meta();
  std::vector bytes(handle.begin(), handle.end());
  return std::make_tuple(bytes, offsets);
}

// Use vector to represent byte data for python binding compatibility.
void register_graph_buffers(fptr_t _fa,
                            const std::vector<std::vector<int64_t>>& handles,
                            const std::vector<std::vector<int64_t>>& offsets) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  std::vector<std::string> bytes;
  bytes.reserve(handles.size());  // one entry per rank; reserving once suffices
  for (const auto& handle : handles) {
    // Each element holds one byte value of the peer's IPC handle.
    bytes.emplace_back(handle.begin(), handle.end());
  }
  fa->register_graph_buffers(bytes, offsets);
}

/**
 * Allocates a zero-initialised device buffer of `size` bytes on the current
 * CUDA device and creates an IPC memory handle for it.
 *
 * @return (device pointer as fptr_t, CPU uint8 tensor holding the raw
 *         cudaIpcMemHandle_t bytes). The handle bytes are what
 *         open_mem_handle() expects.
 * @throws via AT_CUDA_CHECK if any CUDA call fails. On failure the buffer is
 *         freed and the thread's stream-capture mode is restored before the
 *         error propagates.
 */
std::tuple<fptr_t, torch::Tensor> allocate_shared_buffer_and_handle(
    int64_t size) {
  auto device_index = c10::cuda::current_device();
  at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index));
  void* buffer = nullptr;
  cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
  auto stream = c10::cuda::getCurrentCUDAStream().stream();
  // Swap this thread into relaxed capture mode for the allocation below.
  // `mode` receives the previous mode, so a second swap restores it.
  AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));

  try {
    // Allocate buffer
#if defined(USE_ROCM)
    // data buffers need to be "uncached" for signal on MI200
    AT_CUDA_CHECK(
        hipExtMallocWithFlags((void**)&buffer, size, hipDeviceMallocUncached));
#else
    AT_CUDA_CHECK(cudaMalloc((void**)&buffer, size));
#endif
    AT_CUDA_CHECK(cudaMemsetAsync(buffer, 0, size, stream));
    AT_CUDA_CHECK(cudaStreamSynchronize(stream));
  } catch (...) {
    // Best-effort cleanup before rethrowing. Without it, a failure here would
    // leak the buffer and leave the thread in relaxed capture mode. Return
    // codes are ignored so that the original error is the one reported.
    if (buffer != nullptr) (void)cudaFree(buffer);
    (void)cudaThreadExchangeStreamCaptureMode(&mode);
    throw;
  }
  AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));

  // Create IPC memhandle for the allocated buffer.
  // Will use it in open_mem_handle.
  auto options =
      torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
  auto handle =
      torch::empty({static_cast<int64_t>(sizeof(cudaIpcMemHandle_t))}, options);
  cudaError_t ipc_err =
      cudaIpcGetMemHandle((cudaIpcMemHandle_t*)handle.data_ptr(), buffer);
  if (ipc_err != cudaSuccess) {
    (void)cudaFree(buffer);  // do not leak the allocation on failure
    AT_CUDA_CHECK(ipc_err);
  }

  return std::make_tuple(reinterpret_cast<fptr_t>(buffer), handle);
}

// Opens an IPC memory handle and returns the mapped device pointer as fptr_t.
// The handle is presumably the byte tensor produced by a peer's
// allocate_shared_buffer_and_handle. Peer access is enabled lazily.
fptr_t open_mem_handle(torch::Tensor& mem_handle) {
  void* ipc_ptr;
  AT_CUDA_CHECK(cudaIpcOpenMemHandle(
      (void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)mem_handle.data_ptr()),
      cudaIpcMemLazyEnablePeerAccess));
  return reinterpret_cast<fptr_t>(ipc_ptr);
}

// Frees a device buffer passed as fptr_t, e.g. one returned by
// allocate_shared_buffer_and_handle.
void free_shared_buffer(fptr_t buffer) {
  AT_CUDA_CHECK(cudaFree(reinterpret_cast<void*>(buffer)));
}