Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b0eacb5b
Commit
b0eacb5b
authored
Apr 18, 2025
by
zhuwenwen
Browse files
add K100AI custom allreduce
parent
58fc3e31
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
230 additions
and
38 deletions
+230
-38
csrc/custom_all_reduce.cu
csrc/custom_all_reduce.cu
+51
-24
csrc/custom_all_reduce.cuh
csrc/custom_all_reduce.cuh
+161
-6
vllm/distributed/device_communicators/custom_all_reduce.py
vllm/distributed/device_communicators/custom_all_reduce.py
+13
-8
vllm/envs.py
vllm/envs.py
+5
-0
No files found.
csrc/custom_all_reduce.cu
View file @
b0eacb5b
...
...
@@ -14,14 +14,14 @@ fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
torch
::
Tensor
&
rank_data
,
int64_t
rank
,
bool
fully_connected
)
{
int
world_size
=
fake_ipc_ptrs
.
size
();
if
(
world_size
>
8
)
if
(
world_size
>
16
)
throw
std
::
invalid_argument
(
"world size > 8 is not supported"
);
if
(
world_size
%
2
!=
0
)
throw
std
::
invalid_argument
(
"Odd num gpus is not supported for now"
);
if
(
rank
<
0
||
rank
>=
world_size
)
throw
std
::
invalid_argument
(
"invalid rank passed in"
);
vllm
::
Signal
*
ipc_ptrs
[
8
];
vllm
::
Signal
*
ipc_ptrs
[
16
];
for
(
int
i
=
0
;
i
<
world_size
;
i
++
)
{
ipc_ptrs
[
i
]
=
reinterpret_cast
<
vllm
::
Signal
*>
(
fake_ipc_ptrs
[
i
]);
}
...
...
@@ -78,29 +78,56 @@ void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
}
else
{
reg_buffer
=
inp
.
data_ptr
();
}
switch
(
out
.
scalar_type
())
{
case
at
::
ScalarType
::
Float
:
{
fa
->
allreduce
<
float
>
(
stream
,
reinterpret_cast
<
float
*>
(
reg_buffer
),
reinterpret_cast
<
float
*>
(
out
.
data_ptr
()),
out
.
numel
());
break
;
if
(
fa
->
fully_connected_
)
{
switch
(
out
.
scalar_type
())
{
case
at
::
ScalarType
::
Float
:
{
fa
->
allreduce
<
float
>
(
stream
,
reinterpret_cast
<
float
*>
(
reg_buffer
),
reinterpret_cast
<
float
*>
(
out
.
data_ptr
()),
out
.
numel
());
break
;
}
case
at
::
ScalarType
::
Half
:
{
fa
->
allreduce
<
half
>
(
stream
,
reinterpret_cast
<
half
*>
(
reg_buffer
),
reinterpret_cast
<
half
*>
(
out
.
data_ptr
()),
out
.
numel
());
break
;
}
// #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case
at
::
ScalarType
::
BFloat16
:
{
fa
->
allreduce
<
nv_bfloat16
>
(
stream
,
reinterpret_cast
<
nv_bfloat16
*>
(
reg_buffer
),
reinterpret_cast
<
nv_bfloat16
*>
(
out
.
data_ptr
()),
out
.
numel
());
break
;
}
// #endif
default:
throw
std
::
runtime_error
(
"custom allreduce only supports float32, float16 and bfloat16"
);
}
case
at
::
ScalarType
::
Half
:
{
fa
->
allreduce
<
half
>
(
stream
,
reinterpret_cast
<
half
*>
(
reg_buffer
),
reinterpret_cast
<
half
*>
(
out
.
data_ptr
()),
out
.
numel
());
break
;
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case
at
::
ScalarType
::
BFloat16
:
{
fa
->
allreduce
<
nv_bfloat16
>
(
stream
,
reinterpret_cast
<
nv_bfloat16
*>
(
reg_buffer
),
reinterpret_cast
<
nv_bfloat16
*>
(
out
.
data_ptr
()),
out
.
numel
());
break
;
}
else
{
switch
(
out
.
scalar_type
())
{
case
at
::
ScalarType
::
Float
:
{
fa
->
allreduce_pcie
<
float
>
(
stream
,
reinterpret_cast
<
float
*>
(
reg_buffer
),
reinterpret_cast
<
float
*>
(
out
.
data_ptr
()),
out
.
numel
());
break
;
}
case
at
::
ScalarType
::
Half
:
{
fa
->
allreduce_pcie
<
half
>
(
stream
,
reinterpret_cast
<
half
*>
(
reg_buffer
),
reinterpret_cast
<
half
*>
(
out
.
data_ptr
()),
out
.
numel
());
break
;
}
// #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case
at
::
ScalarType
::
BFloat16
:
{
fa
->
allreduce_pcie
<
nv_bfloat16
>
(
stream
,
reinterpret_cast
<
nv_bfloat16
*>
(
reg_buffer
),
reinterpret_cast
<
nv_bfloat16
*>
(
out
.
data_ptr
()),
out
.
numel
());
break
;
}
// #endif
default:
throw
std
::
runtime_error
(
"custom allreduce only supports float32, float16 and bfloat16"
);
}
#endif
default:
throw
std
::
runtime_error
(
"custom allreduce only supports float32, float16 and bfloat16"
);
}
}
...
...
@@ -113,7 +140,7 @@ int64_t meta_size() { return sizeof(vllm::Signal); }
void
register_buffer
(
fptr_t
_fa
,
const
std
::
vector
<
fptr_t
>&
fake_ipc_ptrs
)
{
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
TORCH_CHECK
(
fake_ipc_ptrs
.
size
()
==
fa
->
world_size_
);
void
*
ipc_ptrs
[
8
];
void
*
ipc_ptrs
[
16
];
for
(
int
i
=
0
;
i
<
fake_ipc_ptrs
.
size
();
i
++
)
{
ipc_ptrs
[
i
]
=
reinterpret_cast
<
void
*>
(
fake_ipc_ptrs
[
i
]);
}
...
...
csrc/custom_all_reduce.cuh
View file @
b0eacb5b
...
...
@@ -51,17 +51,17 @@ using FlagType = uint32_t;
// waiting for counter. We use alternating counter array to avoid this
// possibility.
struct
Signal
{
alignas
(
128
)
FlagType
start
[
kMaxBlocks
][
8
];
alignas
(
128
)
FlagType
end
[
kMaxBlocks
][
8
];
alignas
(
128
)
FlagType
start
[
kMaxBlocks
][
16
];
alignas
(
128
)
FlagType
end
[
kMaxBlocks
][
16
];
alignas
(
128
)
FlagType
_flag
[
kMaxBlocks
];
// incremental flags for each rank
};
struct
__align__
(
16
)
RankData
{
const
void
*
ptrs
[
8
];
const
void
*
ptrs
[
16
];
};
struct
__align__
(
16
)
RankSignals
{
Signal
*
signals
[
8
];
Signal
*
signals
[
16
];
};
// like std::array, but aligned
...
...
@@ -103,7 +103,7 @@ DINLINE half& assign_add(half& a, half b) {
}
DINLINE
float
&
assign_add
(
float
&
a
,
float
b
)
{
return
a
+=
b
;
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
//
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
DINLINE
float
upcast_s
(
nv_bfloat16
val
)
{
return
__bfloat162float
(
val
);
}
template
<
>
DINLINE
nv_bfloat16
downcast_s
(
float
val
)
{
...
...
@@ -113,7 +113,7 @@ DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
a
=
__hadd
(
a
,
b
);
return
a
;
}
#endif
//
#endif
template
<
typename
T
,
int
N
>
DINLINE
array_t
<
T
,
N
>&
packed_assign_add
(
array_t
<
T
,
N
>&
a
,
array_t
<
T
,
N
>
b
)
{
...
...
@@ -372,6 +372,84 @@ __global__ void __launch_bounds__(512, 1)
}
}
template
<
typename
T
,
int
ngpus
>
__global__
void
__launch_bounds__
(
512
,
1
)
cross_device_reduce_1stage_pcie
(
RankData
*
_dp
,
RankSignals
sg
,
Signal
*
self_sg
,
T
*
__restrict__
result
,
int
rank
,
int
size
,
uint32_t
**
curr_hdp_reg
,
int
world_size
)
{
using
P
=
typename
packed_t
<
T
>::
P
;
using
A
=
typename
packed_t
<
T
>::
A
;
// note: we don't reorder the address so the accumulation order is the same
// for all ranks, ensuring bitwise identical results
auto
dp
=
*
_dp
;
if
(
threadIdx
.
x
==
1
)
{
for
(
int
i
=
0
;
i
<
world_size
;
i
++
)
{
__atomic_store_n
(
curr_hdp_reg
[
i
],
0x1
,
__ATOMIC_RELAXED
);
}
}
start_sync
<
ngpus
>
(
sg
,
self_sg
,
rank
);
// do the actual reduction
for
(
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
idx
<
size
;
idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
((
P
*
)
result
)[
idx
]
=
packed_reduce
<
P
,
ngpus
,
A
>
((
const
P
**
)
&
dp
.
ptrs
[
0
],
idx
);
}
end_sync
<
ngpus
,
true
>
(
sg
,
self_sg
,
rank
);
}
template
<
typename
T
,
int
ngpus
>
__global__
void
__launch_bounds__
(
512
,
1
)
cross_device_reduce_2stage_pcie
(
RankData
*
_dp
,
RankSignals
sg
,
Signal
*
self_sg
,
T
*
__restrict__
result
,
int
rank
,
int
size
,
uint32_t
**
curr_hdp_reg
,
int
world_size
)
{
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
stride
=
gridDim
.
x
*
blockDim
.
x
;
using
P
=
typename
packed_t
<
T
>::
P
;
using
A
=
typename
packed_t
<
T
>::
A
;
int
part
=
size
/
ngpus
;
int
start
=
rank
*
part
;
int
end
=
rank
==
ngpus
-
1
?
size
:
start
+
part
;
int
largest_part
=
part
+
size
%
ngpus
;
const
P
*
ptrs
[
ngpus
];
P
*
tmps
[
ngpus
];
if
(
threadIdx
.
x
==
1
)
{
for
(
int
i
=
0
;
i
<
world_size
;
i
++
)
{
__atomic_store_n
(
curr_hdp_reg
[
i
],
0x1
,
__ATOMIC_RELAXED
);
}
}
#pragma unroll
for
(
int
i
=
0
;
i
<
ngpus
;
i
++
)
{
int
target
=
(
rank
+
i
)
%
ngpus
;
ptrs
[
i
]
=
(
const
P
*
)
_dp
->
ptrs
[
target
];
tmps
[
i
]
=
get_tmp_buf
<
P
>
(
sg
.
signals
[
target
]);
}
auto
tmp_out
=
tmps
[
0
];
start_sync
<
ngpus
>
(
sg
,
self_sg
,
rank
);
// stage 1: reduce scatter
for
(
int
idx
=
start
+
tid
;
idx
<
end
;
idx
+=
stride
)
{
tmp_out
[
idx
-
start
]
=
packed_reduce
<
P
,
ngpus
,
A
>
(
ptrs
,
idx
);
}
end_sync
<
ngpus
>
(
sg
,
self_sg
,
rank
);
// stage 2: allgather. Note: it's important to match the tid between
// the two stages, because visibility across devices is only guaranteed
// between threads that have the same tid. If thread i computes the sum of
// start + i in the first stage, then thread i also gathers start + i from
// all ranks.
for
(
int
idx
=
tid
;
idx
<
largest_part
;
idx
+=
stride
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
ngpus
;
i
++
)
{
int
gather_from_rank
=
((
rank
+
i
)
%
ngpus
);
if
(
gather_from_rank
==
ngpus
-
1
||
idx
<
part
)
{
int
dst_idx
=
gather_from_rank
*
part
+
idx
;
((
P
*
)
result
)[
dst_idx
]
=
tmps
[
i
][
idx
];
}
}
}
}
using
IPC_KEY
=
std
::
array
<
uint8_t
,
sizeof
(
cudaIpcMemHandle_t
)
>
;
static_assert
(
sizeof
(
IPC_KEY
)
==
sizeof
(
cudaIpcMemHandle_t
));
static_assert
(
alignof
(
IPC_KEY
)
==
alignof
(
cudaIpcMemHandle_t
));
...
...
@@ -409,6 +487,7 @@ class CustomAllreduce {
// a map from IPC handles to opened IPC pointers
std
::
map
<
IPC_KEY
,
char
*>
ipc_handles_
;
uint32_t
**
dev_curr_hdp_reg
;
/**
* Signals are an array of ipc-enabled buffers from all ranks.
* For each of the buffer, the layout is as follows:
...
...
@@ -431,6 +510,12 @@ class CustomAllreduce {
for
(
int
i
=
0
;
i
<
world_size_
;
i
++
)
{
sg_
.
signals
[
i
]
=
signals
[
i
];
}
if
(
!
fully_connected
)
{
cudaMalloc
((
void
**
)
&
dev_curr_hdp_reg
,
world_size_
*
sizeof
(
uint32_t
*
));
for
(
int
i
=
0
;
i
<
world_size_
;
++
i
)
{
hipDeviceGetAttribute
((
int
*
)
&
dev_curr_hdp_reg
[
i
],
hipDeviceAttributeHdpMemFlushCntl
,
i
);
}
}
}
char
*
open_ipc_handle
(
const
void
*
ipc_handle
)
{
...
...
@@ -522,6 +607,75 @@ class CustomAllreduce {
graph_unreg_buffers_
.
clear
();
}
template
<
typename
T
>
void
allreduce_pcie
(
cudaStream_t
stream
,
T
*
input
,
T
*
output
,
int
size
,
int
threads
=
512
,
int
block_limit
=
defaultBlockLimit
)
{
auto
d
=
packed_t
<
T
>::
P
::
size
;
if
(
size
%
d
!=
0
)
throw
std
::
runtime_error
(
"custom allreduce currently requires input length to be multiple "
"of "
+
std
::
to_string
(
d
));
if
(
block_limit
>
kMaxBlocks
)
throw
std
::
runtime_error
(
"max supported block limit is "
+
std
::
to_string
(
kMaxBlocks
)
+
". Got "
+
std
::
to_string
(
block_limit
));
RankData
*
ptrs
;
cudaStreamCaptureStatus
status
;
CUDACHECK
(
cudaStreamIsCapturing
(
stream
,
&
status
));
if
(
status
==
cudaStreamCaptureStatusActive
)
{
ptrs
=
d_rank_data_base_
+
graph_unreg_buffers_
.
size
();
graph_unreg_buffers_
.
push_back
(
input
);
}
else
{
auto
it
=
buffers_
.
find
(
input
);
if
(
it
==
buffers_
.
end
())
throw
std
::
runtime_error
(
"buffer address "
+
std
::
to_string
(
reinterpret_cast
<
uint64_t
>
(
input
))
+
" is not registered!"
);
ptrs
=
it
->
second
;
}
size
/=
d
;
auto
bytes
=
size
*
sizeof
(
typename
packed_t
<
T
>::
P
);
int
blocks
=
std
::
min
(
block_limit
,
(
size
+
threads
-
1
)
/
threads
);
#define KL(ngpus, name) \
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
rank_, size, dev_curr_hdp_reg, world_size_) ;
#define REDUCE_CASE(ngpus) \
case ngpus: { \
if (world_size_ == 2) { \
KL(ngpus, cross_device_reduce_1stage_pcie); \
} else { \
if ((world_size_ <= 4 && bytes < 128 * 8192) || \
(world_size_ <= 8 && bytes < 8 * 8192)) { \
KL(ngpus, cross_device_reduce_1stage_pcie); \
} else { \
KL(ngpus, cross_device_reduce_2stage_pcie); \
} \
} \
break; \
}
switch
(
world_size_
)
{
REDUCE_CASE
(
2
)
REDUCE_CASE
(
4
)
REDUCE_CASE
(
6
)
REDUCE_CASE
(
8
)
REDUCE_CASE
(
16
)
default:
throw
std
::
runtime_error
(
"custom allreduce only supports num gpus in (2,4,6,8,16). Actual "
"num "
"gpus = "
+
std
::
to_string
(
world_size_
));
}
#undef REDUCE_CASE
#undef KL
}
/**
* Performs allreduce, assuming input has already been registered.
*
...
...
@@ -602,6 +756,7 @@ class CustomAllreduce {
for
(
auto
[
_
,
ptr
]
:
ipc_handles_
)
{
CUDACHECK
(
cudaIpcCloseMemHandle
(
ptr
));
}
cudaFree
(
dev_curr_hdp_reg
);
}
};
...
...
vllm/distributed/device_communicators/custom_all_reduce.py
View file @
b0eacb5b
...
...
@@ -3,6 +3,7 @@
from
contextlib
import
contextmanager
from
typing
import
List
,
Optional
,
Union
import
os
import
torch
import
torch.distributed
as
dist
from
torch.distributed
import
ProcessGroup
...
...
@@ -47,7 +48,7 @@ def is_weak_contiguous(inp: torch.Tensor):
class
CustomAllreduce
:
_SUPPORTED_WORLD_SIZES
=
[
2
,
4
,
6
,
8
]
_SUPPORTED_WORLD_SIZES
=
[
2
,
4
,
6
,
8
,
16
]
# max_size: max supported allreduce size
def
__init__
(
self
,
...
...
@@ -135,11 +136,17 @@ class CustomAllreduce:
# if world_size > 2 and not fully_connected:
if
not
fully_connected
:
max_size
=
32
*
8192
*
2
if
not
envs
.
VLLM_PCIE_USE_CUSTOM_ALLREDUCE
:
logger
.
warning
(
"Custom allreduce is disabled because it's not supported on"
" more than two PCIe-only GPUs. To silence this warning, "
"specify disable_custom_all_reduce=True explicitly."
)
return
logger
.
warning
(
"Custom allreduce is disabled because it's not supported on"
" more than two PCIe-only GPUs. To silence this warning, "
"specify disable_custom_all_reduce=True explicitly."
)
return
"We are using PCIe's custom allreduce."
"If the performance is poor, we can add "
"--disable-custom-all-reduce in the instruction."
)
# test P2P capability, this checks software/cudaruntime support
# this is expensive to compute at the first time
# then we cache the result
...
...
@@ -223,9 +230,7 @@ class CustomAllreduce:
return
False
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
# little performance improvement over NCCL.
if
self
.
world_size
==
2
or
self
.
fully_connected
:
return
inp_size
<
self
.
max_size
return
False
return
inp_size
<
self
.
max_size
def
all_reduce
(
self
,
inp
:
torch
.
Tensor
,
...
...
vllm/envs.py
View file @
b0eacb5b
...
...
@@ -23,6 +23,7 @@ if TYPE_CHECKING:
VLLM_USE_TC_PAGED_ATTN
:
bool
=
False
VLLM_USE_PA_PRINT_PARAM
:
bool
=
False
VLLM_SPEC_DECODE_EAGER
:
bool
=
False
VLLM_PCIE_USE_CUSTOM_ALLREDUCE
:
bool
=
False
VLLM_ENFORCE_EAGER_BS_THRESHOLD
:
Optional
[
int
]
=
None
VLLM_FLASH_ATTN_VERSION
:
Optional
[
int
]
=
None
LOCAL_RANK
:
int
=
0
...
...
@@ -286,6 +287,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
# If set, vLLM will disable the draft model in cudagraph mode.
"VLLM_SPEC_DECODE_EAGER"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_SPEC_DECODE_EAGER"
,
"0"
))),
# flag to control vllm to use optimized kernels
"VLLM_PCIE_USE_CUSTOM_ALLREDUCE"
:
lambda
:
bool
(
int
(
os
.
environ
.
get
(
"VLLM_PCIE_USE_CUSTOM_ALLREDUCE"
,
"0"
))),
# Force vllm to use a specific flash-attention version (2 or 3), only valid
# when using the flash-attention backend.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment