Commit f5f65d24 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.7.2-custom-eco' into 'v0.7.2-dev'

add K100AI custom allreduce

See merge request dcutoolkit/deeplearing/vllm!96
parents 645e9ec4 ec078ef1
...@@ -14,14 +14,14 @@ fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs, ...@@ -14,14 +14,14 @@ fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
torch::Tensor& rank_data, int64_t rank, torch::Tensor& rank_data, int64_t rank,
bool full_nvlink) { bool full_nvlink) {
int world_size = fake_ipc_ptrs.size(); int world_size = fake_ipc_ptrs.size();
if (world_size > 8) if (world_size > 16)
throw std::invalid_argument("world size > 8 is not supported"); throw std::invalid_argument("world size > 8 is not supported");
if (world_size % 2 != 0) if (world_size % 2 != 0)
throw std::invalid_argument("Odd num gpus is not supported for now"); throw std::invalid_argument("Odd num gpus is not supported for now");
if (rank < 0 || rank >= world_size) if (rank < 0 || rank >= world_size)
throw std::invalid_argument("invalid rank passed in"); throw std::invalid_argument("invalid rank passed in");
vllm::Signal* ipc_ptrs[8]; vllm::Signal* ipc_ptrs[16];
for (int i = 0; i < world_size; i++) { for (int i = 0; i < world_size; i++) {
ipc_ptrs[i] = reinterpret_cast<vllm::Signal*>(fake_ipc_ptrs[i]); ipc_ptrs[i] = reinterpret_cast<vllm::Signal*>(fake_ipc_ptrs[i]);
} }
...@@ -78,29 +78,56 @@ void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, ...@@ -78,29 +78,56 @@ void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
} else { } else {
reg_buffer = inp.data_ptr(); reg_buffer = inp.data_ptr();
} }
switch (out.scalar_type()) { if (fa->full_nvlink_) {
case at::ScalarType::Float: { switch (out.scalar_type()) {
fa->allreduce<float>(stream, reinterpret_cast<float*>(reg_buffer), case at::ScalarType::Float: {
reinterpret_cast<float*>(out.data_ptr()), fa->allreduce<float>(stream, reinterpret_cast<float*>(reg_buffer),
out.numel()); reinterpret_cast<float*>(out.data_ptr()),
break; out.numel());
break;
}
case at::ScalarType::Half: {
fa->allreduce<half>(stream, reinterpret_cast<half*>(reg_buffer),
reinterpret_cast<half*>(out.data_ptr()), out.numel());
break;
}
// #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case at::ScalarType::BFloat16: {
fa->allreduce<nv_bfloat16>(
stream, reinterpret_cast<nv_bfloat16*>(reg_buffer),
reinterpret_cast<nv_bfloat16*>(out.data_ptr()), out.numel());
break;
}
// #endif
default:
throw std::runtime_error(
"custom allreduce only supports float32, float16 and bfloat16");
} }
case at::ScalarType::Half: { } else {
fa->allreduce<half>(stream, reinterpret_cast<half*>(reg_buffer), switch (out.scalar_type()) {
reinterpret_cast<half*>(out.data_ptr()), out.numel()); case at::ScalarType::Float: {
break; fa->allreduce_pcie<float>(stream, reinterpret_cast<float*>(reg_buffer),
} reinterpret_cast<float*>(out.data_ptr()),
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) out.numel());
case at::ScalarType::BFloat16: { break;
fa->allreduce<nv_bfloat16>( }
stream, reinterpret_cast<nv_bfloat16*>(reg_buffer), case at::ScalarType::Half: {
reinterpret_cast<nv_bfloat16*>(out.data_ptr()), out.numel()); fa->allreduce_pcie<half>(stream, reinterpret_cast<half*>(reg_buffer),
break; reinterpret_cast<half*>(out.data_ptr()), out.numel());
break;
}
// #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case at::ScalarType::BFloat16: {
fa->allreduce_pcie<nv_bfloat16>(
stream, reinterpret_cast<nv_bfloat16*>(reg_buffer),
reinterpret_cast<nv_bfloat16*>(out.data_ptr()), out.numel());
break;
}
// #endif
default:
throw std::runtime_error(
"custom allreduce only supports float32, float16 and bfloat16");
} }
#endif
default:
throw std::runtime_error(
"custom allreduce only supports float32, float16 and bfloat16");
} }
} }
...@@ -113,7 +140,7 @@ int64_t meta_size() { return sizeof(vllm::Signal); } ...@@ -113,7 +140,7 @@ int64_t meta_size() { return sizeof(vllm::Signal); }
void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs) { void register_buffer(fptr_t _fa, const std::vector<fptr_t>& fake_ipc_ptrs) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa); auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_); TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_);
void* ipc_ptrs[8]; void* ipc_ptrs[16];
for (int i = 0; i < fake_ipc_ptrs.size(); i++) { for (int i = 0; i < fake_ipc_ptrs.size(); i++) {
ipc_ptrs[i] = reinterpret_cast<void*>(fake_ipc_ptrs[i]); ipc_ptrs[i] = reinterpret_cast<void*>(fake_ipc_ptrs[i]);
} }
......
...@@ -52,17 +52,17 @@ using FlagType = uint32_t; ...@@ -52,17 +52,17 @@ using FlagType = uint32_t;
// waiting for counter. We use alternating counter array to avoid this // waiting for counter. We use alternating counter array to avoid this
// possibility. // possibility.
struct Signal { struct Signal {
alignas(128) FlagType start[kMaxBlocks][8]; alignas(128) FlagType start[kMaxBlocks][16];
alignas(128) FlagType end[kMaxBlocks][8]; alignas(128) FlagType end[kMaxBlocks][16];
alignas(128) FlagType _flag[kMaxBlocks]; // incremental flags for each rank alignas(128) FlagType _flag[kMaxBlocks]; // incremental flags for each rank
}; };
struct __align__(16) RankData { struct __align__(16) RankData {
const void* ptrs[8]; const void* ptrs[16];
}; };
struct __align__(16) RankSignals { struct __align__(16) RankSignals {
Signal* signals[8]; Signal* signals[16];
}; };
// like std::array, but aligned // like std::array, but aligned
...@@ -104,7 +104,7 @@ DINLINE half& assign_add(half& a, half b) { ...@@ -104,7 +104,7 @@ DINLINE half& assign_add(half& a, half b) {
} }
DINLINE float& assign_add(float& a, float b) { return a += b; } DINLINE float& assign_add(float& a, float b) { return a += b; }
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) // #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); } DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }
template <> template <>
DINLINE nv_bfloat16 downcast_s(float val) { DINLINE nv_bfloat16 downcast_s(float val) {
...@@ -114,7 +114,7 @@ DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) { ...@@ -114,7 +114,7 @@ DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
a = __hadd(a, b); a = __hadd(a, b);
return a; return a;
} }
#endif // #endif
template <typename T, int N> template <typename T, int N>
DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) { DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
...@@ -373,6 +373,84 @@ __global__ void __launch_bounds__(512, 1) ...@@ -373,6 +373,84 @@ __global__ void __launch_bounds__(512, 1)
} }
} }
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
cross_device_reduce_1stage_pcie(RankData* _dp, RankSignals sg, Signal* self_sg,
T* __restrict__ result, int rank, int size,
uint32_t** curr_hdp_reg, int world_size) {
using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A;
// note: we don't reorder the address so the accumulation order is the same
// for all ranks, ensuring bitwise identical results
auto dp = *_dp;
if (threadIdx.x == 1) {
for(int i = 0; i < world_size; i++) {
__atomic_store_n(curr_hdp_reg[i], 0x1, __ATOMIC_RELAXED);
}
}
start_sync<ngpus>(sg, self_sg, rank);
// do the actual reduction
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) {
((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
}
end_sync<ngpus, true>(sg, self_sg, rank);
}
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
cross_device_reduce_2stage_pcie(RankData* _dp, RankSignals sg, Signal* self_sg,
T* __restrict__ result, int rank, int size,
uint32_t** curr_hdp_reg, int world_size) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = gridDim.x * blockDim.x;
using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A;
int part = size / ngpus;
int start = rank * part;
int end = rank == ngpus - 1 ? size : start + part;
int largest_part = part + size % ngpus;
const P* ptrs[ngpus];
P* tmps[ngpus];
if (threadIdx.x == 1) {
for(int i = 0; i < world_size; i++) {
__atomic_store_n(curr_hdp_reg[i], 0x1, __ATOMIC_RELAXED);
}
}
#pragma unroll
for (int i = 0; i < ngpus; i++) {
int target = (rank + i) % ngpus;
ptrs[i] = (const P*)_dp->ptrs[target];
tmps[i] = get_tmp_buf<P>(sg.signals[target]);
}
auto tmp_out = tmps[0];
start_sync<ngpus>(sg, self_sg, rank);
// stage 1: reduce scatter
for (int idx = start + tid; idx < end; idx += stride) {
tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
}
end_sync<ngpus>(sg, self_sg, rank);
// stage 2: allgather. Note: it's important to match the tid between
// the two stages, because visibility across devices is only guaranteed
// between threads that have the same tid. If thread i computes the sum of
// start + i in the first stage, then thread i also gathers start + i from
// all ranks.
for (int idx = tid; idx < largest_part; idx += stride) {
#pragma unroll
for (int i = 0; i < ngpus; i++) {
int gather_from_rank = ((rank + i) % ngpus);
if (gather_from_rank == ngpus - 1 || idx < part) {
int dst_idx = gather_from_rank * part + idx;
((P*)result)[dst_idx] = tmps[i][idx];
}
}
}
}
using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>; using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t)); static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t));
static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t)); static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t));
...@@ -409,6 +487,7 @@ class CustomAllreduce { ...@@ -409,6 +487,7 @@ class CustomAllreduce {
// a map from IPC handles to opened IPC pointers // a map from IPC handles to opened IPC pointers
std::map<IPC_KEY, char*> ipc_handles_; std::map<IPC_KEY, char*> ipc_handles_;
uint32_t** dev_curr_hdp_reg;
/** /**
* Signals are an array of ipc-enabled buffers from all ranks. * Signals are an array of ipc-enabled buffers from all ranks.
* For each of the buffer, the layout is as follows: * For each of the buffer, the layout is as follows:
...@@ -431,6 +510,12 @@ class CustomAllreduce { ...@@ -431,6 +510,12 @@ class CustomAllreduce {
for (int i = 0; i < world_size_; i++) { for (int i = 0; i < world_size_; i++) {
sg_.signals[i] = signals[i]; sg_.signals[i] = signals[i];
} }
if (!full_nvlink) {
cudaMalloc((void**)&dev_curr_hdp_reg, world_size_ * sizeof(uint32_t*));
for (int i = 0; i < world_size_; ++i) {
hipDeviceGetAttribute((int*)&dev_curr_hdp_reg[i], hipDeviceAttributeHdpMemFlushCntl, i);
}
}
} }
char* open_ipc_handle(const void* ipc_handle) { char* open_ipc_handle(const void* ipc_handle) {
...@@ -522,6 +607,75 @@ class CustomAllreduce { ...@@ -522,6 +607,75 @@ class CustomAllreduce {
graph_unreg_buffers_.clear(); graph_unreg_buffers_.clear();
} }
template <typename T>
void allreduce_pcie(cudaStream_t stream, T* input, T* output, int size,
int threads = 512, int block_limit = defaultBlockLimit) {
auto d = packed_t<T>::P::size;
if (size % d != 0)
throw std::runtime_error(
"custom allreduce currently requires input length to be multiple "
"of " +
std::to_string(d));
if (block_limit > kMaxBlocks)
throw std::runtime_error("max supported block limit is " +
std::to_string(kMaxBlocks) + ". Got " +
std::to_string(block_limit));
RankData* ptrs;
cudaStreamCaptureStatus status;
CUDACHECK(cudaStreamIsCapturing(stream, &status));
if (status == cudaStreamCaptureStatusActive) {
ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
graph_unreg_buffers_.push_back(input);
} else {
auto it = buffers_.find(input);
if (it == buffers_.end())
throw std::runtime_error(
"buffer address " +
std::to_string(reinterpret_cast<uint64_t>(input)) +
" is not registered!");
ptrs = it->second;
}
size /= d;
auto bytes = size * sizeof(typename packed_t<T>::P);
int blocks = std::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name) \
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
rank_, size, dev_curr_hdp_reg, world_size_) ;
#define REDUCE_CASE(ngpus) \
case ngpus: { \
if (world_size_ == 2) { \
KL(ngpus, cross_device_reduce_1stage_pcie); \
} else { \
if ((world_size_ <= 4 && bytes < 128 * 8192) || \
(world_size_ <= 8 && bytes < 8 * 8192)) { \
KL(ngpus, cross_device_reduce_1stage_pcie); \
} else { \
KL(ngpus, cross_device_reduce_2stage_pcie); \
} \
} \
break; \
}
switch (world_size_) {
REDUCE_CASE(2)
REDUCE_CASE(4)
REDUCE_CASE(6)
REDUCE_CASE(8)
REDUCE_CASE(16)
default:
throw std::runtime_error(
"custom allreduce only supports num gpus in (2,4,6,8,16). Actual "
"num "
"gpus = " +
std::to_string(world_size_));
}
#undef REDUCE_CASE
#undef KL
}
/** /**
* Performs allreduce, assuming input has already been registered. * Performs allreduce, assuming input has already been registered.
* *
...@@ -587,9 +741,10 @@ class CustomAllreduce { ...@@ -587,9 +741,10 @@ class CustomAllreduce {
REDUCE_CASE(4) REDUCE_CASE(4)
REDUCE_CASE(6) REDUCE_CASE(6)
REDUCE_CASE(8) REDUCE_CASE(8)
REDUCE_CASE(16)
default: default:
throw std::runtime_error( throw std::runtime_error(
"custom allreduce only supports num gpus in (2,4,6,8). Actual " "custom allreduce only supports num gpus in (2,4,6,8,16). Actual "
"num " "num "
"gpus = " + "gpus = " +
std::to_string(world_size_)); std::to_string(world_size_));
...@@ -602,6 +757,7 @@ class CustomAllreduce { ...@@ -602,6 +757,7 @@ class CustomAllreduce {
for (auto [_, ptr] : ipc_handles_) { for (auto [_, ptr] : ipc_handles_) {
CUDACHECK(cudaIpcCloseMemHandle(ptr)); CUDACHECK(cudaIpcCloseMemHandle(ptr));
} }
cudaFree(dev_curr_hdp_reg);
} }
}; };
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import ctypes import ctypes
from contextlib import contextmanager from contextlib import contextmanager
from typing import List, Optional, Union from typing import List, Optional, Union
import os
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
...@@ -18,6 +18,7 @@ from vllm.logger import init_logger ...@@ -18,6 +18,7 @@ from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import cuda_device_count_stateless from vllm.utils import cuda_device_count_stateless
from vllm import envs
try: try:
ops.meta_size() ops.meta_size()
custom_ar = True custom_ar = True
...@@ -50,13 +51,13 @@ def is_weak_contiguous(inp: torch.Tensor): ...@@ -50,13 +51,13 @@ def is_weak_contiguous(inp: torch.Tensor):
class CustomAllreduce: class CustomAllreduce:
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8, 16]
# max_size: max supported allreduce size # max_size: max supported allreduce size
def __init__(self, def __init__(self,
group: ProcessGroup, group: ProcessGroup,
device: Union[int, str, torch.device], device: Union[int, str, torch.device],
max_size=8192 * 1024 * 2) -> None: max_size=8192 * 1024) -> None:
""" """
Args: Args:
group: the process group to work on. If None, it will use the group: the process group to work on. If None, it will use the
...@@ -137,11 +138,18 @@ class CustomAllreduce: ...@@ -137,11 +138,18 @@ class CustomAllreduce:
full_nvlink = current_platform.is_fully_connected_nvlink_or_xgmi( full_nvlink = current_platform.is_fully_connected_nvlink_or_xgmi(
physical_device_ids) physical_device_ids)
if not full_nvlink: if not full_nvlink:
max_size = 32 * 8192 * 2
if not envs.VLLM_PCIE_USE_CUSTOM_ALLREDUCE:
logger.warning(
"Custom allreduce is disabled because it's not supported on"
" more than two PCIe-only GPUs. To silence this warning, "
"specify disable_custom_all_reduce=True explicitly.")
return
logger.warning( logger.warning(
"Custom allreduce is disabled because it's not supported on" "We are using PCIe's custom allreduce."
" more than two PCIe-only GPUs. To silence this warning, " "If the performance is poor, we can add "
"specify disable_custom_all_reduce=True explicitly.") "--disable-custom-all-reduce in the instruction.")
return
# test P2P capability, this checks software/cudaruntime support # test P2P capability, this checks software/cudaruntime support
# this is expensive to compute at the first time # this is expensive to compute at the first time
# then we cache the result # then we cache the result
...@@ -259,9 +267,7 @@ class CustomAllreduce: ...@@ -259,9 +267,7 @@ class CustomAllreduce:
return False return False
# for 4 or more non NVLink-capable GPUs, custom allreduce provides # for 4 or more non NVLink-capable GPUs, custom allreduce provides
# little performance improvement over NCCL. # little performance improvement over NCCL.
if self.world_size == 2 or self.full_nvlink: return inp_size < self.max_size
return inp_size < self.max_size
return False
def all_reduce(self, def all_reduce(self,
inp: torch.Tensor, inp: torch.Tensor,
......
...@@ -18,6 +18,7 @@ if TYPE_CHECKING: ...@@ -18,6 +18,7 @@ if TYPE_CHECKING:
VLLM_USE_TRITON_OPT_MLA: bool = False VLLM_USE_TRITON_OPT_MLA: bool = False
VLLM_USE_OPT_OP: bool = False VLLM_USE_OPT_OP: bool = False
VLLM_USE_TC_PAGED_ATTN: bool = False VLLM_USE_TC_PAGED_ATTN: bool = False
VLLM_PCIE_USE_CUSTOM_ALLREDUCE:bool = False
VLLM_USE_PA_PRINT_PARAM: bool = False VLLM_USE_PA_PRINT_PARAM: bool = False
VLLM_FLASH_ATTN_VERSION: Optional[int] = None VLLM_FLASH_ATTN_VERSION: Optional[int] = None
LOCAL_RANK: int = 0 LOCAL_RANK: int = 0
...@@ -246,6 +247,10 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -246,6 +247,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_USE_OPT_OP", "True").lower() in lambda: (os.environ.get("VLLM_USE_OPT_OP", "True").lower() in
("true", "1")), ("true", "1")),
# flag to control vllm to use optimized kernels
"VLLM_PCIE_USE_CUSTOM_ALLREDUCE":
lambda: bool(int(os.environ.get("VLLM_PCIE_USE_CUSTOM_ALLREDUCE", "0"))),
# flag to control vllm to use optimized tc paged attn kernels # flag to control vllm to use optimized tc paged attn kernels
"VLLM_USE_TC_PAGED_ATTN": "VLLM_USE_TC_PAGED_ATTN":
lambda: (os.environ.get("VLLM_USE_TC_PAGED_ATTN", "True").lower() in lambda: (os.environ.get("VLLM_USE_TC_PAGED_ATTN", "True").lower() in
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment