Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f721096d
Unverified
Commit
f721096d
authored
Mar 21, 2024
by
Hanzhi Zhou
Committed by
GitHub
Mar 21, 2024
Browse files
[BugFix] Some fixes for custom allreduce kernels (#2760)
parent
e90fc21f
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
232 additions
and
250 deletions
+232
-250
csrc/custom_all_reduce.cu
csrc/custom_all_reduce.cu
+5
-5
csrc/custom_all_reduce.cuh
csrc/custom_all_reduce.cuh
+75
-152
csrc/custom_all_reduce_test.cu
csrc/custom_all_reduce_test.cu
+108
-76
vllm/config.py
vllm/config.py
+0
-9
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+1
-1
vllm/model_executor/parallel_utils/custom_all_reduce.py
vllm/model_executor/parallel_utils/custom_all_reduce.py
+43
-7
No files found.
csrc/custom_all_reduce.cu
View file @
f721096d
...
...
@@ -29,7 +29,7 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
std
::
memcpy
(
&
ipc_handles
[
i
],
handles
[
i
].
data
(),
sizeof
(
cudaIpcMemHandle_t
));
}
return
(
fptr_t
)
new
vllm
::
CustomAllreduce
(
reinterpret_cast
<
vllm
::
Metadata
*>
(
meta
.
data_ptr
()),
rank_data
.
data_ptr
(),
reinterpret_cast
<
vllm
::
Signal
*>
(
meta
.
data_ptr
()),
rank_data
.
data_ptr
(),
rank_data
.
numel
(),
ipc_handles
,
offsets
,
rank
,
full_nvlink
);
}
...
...
@@ -62,9 +62,9 @@ bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
if
(
inp_size
%
16
!=
0
)
return
false
;
if
(
!
_is_weak_contiguous
(
inp
))
return
false
;
if
(
world_size
==
2
||
full_nvlink
)
return
inp_size
<=
max_size
;
//
4 PCIE GPUs use 2 stage allreduce, and is only faster than NCCL when siz
e
//
<= 512k
return
world_size
<=
4
&&
inp_size
<=
512
*
1024
;
//
for 4 or more non NVLink-capable GPUs, custom allreduce provides littl
e
//
performance improvement over NCCL.
return
false
;
}
void
_all_reduce
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
,
...
...
@@ -126,7 +126,7 @@ void dispose(fptr_t _fa) {
delete
fa
;
}
int
meta_size
()
{
return
sizeof
(
vllm
::
Metadata
);
}
int
meta_size
()
{
return
sizeof
(
vllm
::
Signal
);
}
void
register_buffer
(
fptr_t
_fa
,
torch
::
Tensor
&
t
,
const
std
::
vector
<
std
::
string
>
&
handles
,
...
...
csrc/custom_all_reduce.cuh
View file @
f721096d
...
...
@@ -23,29 +23,17 @@
namespace
vllm
{
constexpr
int
kMaxBlocks
=
64
;
// note: we don't want to use atomics for signals because peer atomics are no
// supported on PCIe links
struct
Signal
{
alignas
(
64
)
union
{
uint64_t
flag
;
unsigned
char
data
[
8
];
}
start
;
alignas
(
64
)
union
{
uint64_t
flag
;
unsigned
char
data
[
8
];
}
end
;
alignas
(
128
)
uint32_t
start
[
kMaxBlocks
][
8
];
alignas
(
128
)
uint32_t
end
[
kMaxBlocks
][
8
];
};
struct
Metadata
{
alignas
(
128
)
Signal
sg
;
alignas
(
128
)
int
counter
;
};
static_assert
(
offsetof
(
Metadata
,
counter
)
==
128
);
static_assert
(
sizeof
(
Metadata
)
==
256
);
struct
__align__
(
16
)
RankData
{
const
void
*
__restrict__
ptrs
[
8
];
};
struct
RankSignals
{
volatile
Signal
*
signals
[
8
];
};
struct
__align__
(
16
)
RankSignals
{
volatile
Signal
*
signals
[
8
];
};
// like std::array, but aligned
template
<
typename
T
,
int
sz
>
...
...
@@ -135,70 +123,49 @@ DINLINE O downcast(array_t<float, O::size> val) {
}
}
// compute flag at compile time
__host__
__device__
constexpr
uint64_t
compute_flag
(
int
ngpus
)
{
auto
m
=
std
::
numeric_limits
<
uint64_t
>::
max
();
return
m
>>
((
8
-
ngpus
)
*
8
);
}
// This function is meant to be used as the first synchronization in the all
// reduce kernel. Thus, it doesn't need to make any visibility guarantees for
// prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes.
template
<
int
ngpus
>
DINLINE
void
start_sync
(
const
RankSignals
&
sg
,
volatile
Metadata
*
meta
,
DINLINE
void
start_sync
(
const
RankSignals
&
sg
,
volatile
Signal
*
self_sg
,
int
rank
)
{
constexpr
auto
FLAG
=
compute_flag
(
ngpus
);
if
(
blockIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
<
ngpus
)
// simultaneously write to the corresponding byte to all other ranks.
// Latency = 1 p2p write
sg
.
signals
[
threadIdx
.
x
]
->
start
.
data
[
rank
]
=
255
;
else
if
(
threadIdx
.
x
==
32
)
// reset
meta
->
sg
.
end
.
flag
=
0
;
}
if
(
threadIdx
.
x
==
0
)
{
while
(
meta
->
sg
.
start
.
flag
!=
FLAG
)
if
(
threadIdx
.
x
<
ngpus
)
{
// reset flag for next time
self_sg
->
end
[
blockIdx
.
x
][
threadIdx
.
x
]
=
0
;
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
sg
.
signals
[
threadIdx
.
x
]
->
start
[
blockIdx
.
x
][
rank
]
=
1
;
// wait until we got true from all ranks
while
(
!
self_sg
->
start
[
blockIdx
.
x
][
threadIdx
.
x
])
;
}
__syncthreads
();
}
// This function is meant to be used as the second or the final synchronization
// barrier in the all reduce kernel. If it's the final synchronization barrier,
// we don't need to make any visibility guarantees for prior memory accesses.
template
<
int
ngpus
,
bool
final_sync
=
false
>
DINLINE
void
end_sync
(
const
RankSignals
&
sg
,
volatile
Metadata
*
meta
,
DINLINE
void
end_sync
(
const
RankSignals
&
sg
,
volatile
Signal
*
self_sg
,
int
rank
)
{
constexpr
auto
FLAG
=
compute_flag
(
ngpus
);
__syncthreads
();
__shared__
int
num
;
if
(
threadIdx
.
x
==
0
)
num
=
atomicAdd
((
int
*
)
&
meta
->
counter
,
1
);
__syncthreads
();
// Only the last completing block can perform the end synchronization
// This can ensures when the final busy wait ends, all ranks must have
// finished reading each other's buffer.
if
(
num
==
gridDim
.
x
-
1
)
{
if
(
threadIdx
.
x
==
32
)
{
// reset in a different warp
meta
->
counter
=
0
;
meta
->
sg
.
start
.
flag
=
0
;
}
else
if
(
threadIdx
.
x
<
ngpus
)
{
// simultaneously write to the corresponding byte to all other ranks.
// Latency = 1 p2p write
sg
.
signals
[
threadIdx
.
x
]
->
end
.
data
[
rank
]
=
255
;
}
// if this is the final sync, only one block needs it
// because kernel exit can serve as sync
if
constexpr
(
final_sync
)
{
if
(
threadIdx
.
x
==
0
)
{
while
(
meta
->
sg
.
end
.
flag
!=
FLAG
)
;
}
}
}
if
constexpr
(
!
final_sync
)
{
if
(
threadIdx
.
x
==
0
)
{
while
(
meta
->
sg
.
end
.
flag
!=
FLAG
)
;
}
__syncthreads
();
// eliminate the case that prior writes are not visible after signals become
// visible. Note that I did not managed to make this happen through a lot of
// testing. Might be the case that hardware provides stronger guarantee than
// the memory model.
if
constexpr
(
!
final_sync
)
__threadfence_system
();
if
(
threadIdx
.
x
<
ngpus
)
{
// reset flag for next time
self_sg
->
start
[
blockIdx
.
x
][
threadIdx
.
x
]
=
0
;
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
sg
.
signals
[
threadIdx
.
x
]
->
end
[
blockIdx
.
x
][
rank
]
=
1
;
// wait until we got true from all ranks
while
(
!
self_sg
->
end
[
blockIdx
.
x
][
threadIdx
.
x
])
;
}
if
constexpr
(
!
final_sync
)
__syncthreads
();
}
template
<
typename
P
,
int
ngpus
,
typename
A
>
...
...
@@ -214,32 +181,32 @@ DINLINE P packed_reduce(const P *ptrs[], int idx) {
template
<
typename
T
,
int
ngpus
>
__global__
void
__launch_bounds__
(
512
,
1
)
cross_device_reduce_1stage
(
RankData
*
_dp
,
RankSignals
sg
,
volatile
Metadata
*
meta
,
T
*
__restrict__
result
,
volatile
Signal
*
self_sg
,
T
*
__restrict__
result
,
int
rank
,
int
size
)
{
using
P
=
typename
packed_t
<
T
>::
P
;
using
A
=
typename
packed_t
<
T
>::
A
;
// note: we don't reorder the address so the accumulation order is the same
// for all ranks, ensuring bitwise identical results
auto
dp
=
*
_dp
;
start_sync
<
ngpus
>
(
sg
,
meta
,
rank
);
start_sync
<
ngpus
>
(
sg
,
self_sg
,
rank
);
// do the actual reduction
for
(
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
idx
<
size
;
idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
((
P
*
)
result
)[
idx
]
=
packed_reduce
<
P
,
ngpus
,
A
>
((
const
P
**
)
&
dp
.
ptrs
[
0
],
idx
);
}
end_sync
<
ngpus
,
true
>
(
sg
,
meta
,
rank
);
end_sync
<
ngpus
,
true
>
(
sg
,
self_sg
,
rank
);
}
template
<
typename
P
>
DINLINE
P
*
get_tmp_buf
(
volatile
Signal
*
sg
)
{
return
(
P
*
)(((
Metadata
*
)
sg
)
+
1
);
return
(
P
*
)(((
Signal
*
)
sg
)
+
1
);
}
template
<
typename
T
,
int
ngpus
>
__global__
void
__launch_bounds__
(
512
,
1
)
cross_device_reduce_2stage
(
RankData
*
_dp
,
RankSignals
sg
,
volatile
Metadata
*
meta
,
T
*
__restrict__
result
,
volatile
Signal
*
self_sg
,
T
*
__restrict__
result
,
int
rank
,
int
size
)
{
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
stride
=
gridDim
.
x
*
blockDim
.
x
;
...
...
@@ -248,6 +215,7 @@ __global__ void __launch_bounds__(512, 1)
int
part
=
size
/
ngpus
;
int
start
=
rank
*
part
;
int
end
=
rank
==
ngpus
-
1
?
size
:
start
+
part
;
int
largest_part
=
part
+
size
%
ngpus
;
const
P
*
ptrs
[
ngpus
];
P
*
tmps
[
ngpus
];
#pragma unroll
...
...
@@ -257,75 +225,28 @@ __global__ void __launch_bounds__(512, 1)
tmps
[
i
]
=
get_tmp_buf
<
P
>
(
sg
.
signals
[
target
]);
}
auto
tmp_out
=
tmps
[
0
];
start_sync
<
ngpus
>
(
sg
,
meta
,
rank
);
start_sync
<
ngpus
>
(
sg
,
self_sg
,
rank
);
// stage 1: reduce scatter
for
(
int
idx
=
start
+
tid
;
idx
<
end
;
idx
+=
stride
)
{
tmp_out
[
idx
-
start
]
=
packed_reduce
<
P
,
ngpus
,
A
>
(
ptrs
,
idx
);
}
// Maybe TODO: replace this with per-block release-acquire
// can save about 1-2us (not a lot though)
end_sync
<
ngpus
>
(
sg
,
meta
,
rank
);
// stage 2: allgather
for
(
int
idx
=
tid
;
idx
<
part
;
idx
+=
stride
)
{
end_sync
<
ngpus
>
(
sg
,
self_sg
,
rank
);
// stage 2: allgather. Note: it's important to match the tid between
// the two stages, because visibility across devices is only guaranteed
// between threads that have the same tid. If thread i computes the sum of
// start + i in the first stage, then thread i also gathers start + i from all
// ranks.
for
(
int
idx
=
tid
;
idx
<
largest_part
;
idx
+=
stride
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
ngpus
;
i
++
)
{
int
dst_idx
=
((
rank
+
i
)
%
ngpus
)
*
part
+
idx
;
((
P
*
)
result
)[
dst_idx
]
=
tmps
[
i
][
idx
];
}
}
// process the last larger partition
int
remaining
=
size
-
part
*
ngpus
;
if
(
tid
<
remaining
)
{
int
dst_idx
=
tid
+
part
*
ngpus
;
((
P
*
)
result
)[
dst_idx
]
=
get_tmp_buf
<
P
>
(
sg
.
signals
[
ngpus
-
1
])[
part
+
tid
];
}
// faster than this
// for (int idx = tid; idx < size; idx += stride) {
// int target_rank = idx / part;
// if (target_rank == ngpus) target_rank -= 1;
// ((P *)result)[idx] = tmps[target_rank][idx - target_rank * part];
// }
}
template
<
typename
T
,
int
ngpus
>
__global__
void
__launch_bounds__
(
512
,
1
)
cross_device_reduce_half_butterfly
(
RankData
*
_dp
,
RankSignals
sg
,
volatile
Metadata
*
meta
,
T
*
__restrict__
result
,
int
rank
,
int
size
)
{
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
stride
=
gridDim
.
x
*
blockDim
.
x
;
using
P
=
typename
packed_t
<
T
>::
P
;
using
A
=
typename
packed_t
<
T
>::
A
;
auto
tmp_out
=
get_tmp_buf
<
P
>
(
sg
.
signals
[
rank
]);
constexpr
int
hg
=
ngpus
/
2
;
// Actually not quite half butterfly.
// This is an all-to-all within each group containing half of the ranks
// followed by cross-group add. Equivalent to half butterfly when there
// are 4 GPUs, a common case for PCIe cards like T4 and A10.
const
P
*
ptrs
[
hg
];
{
int
start
=
rank
-
rank
%
hg
;
#pragma unroll
for
(
int
i
=
0
;
i
<
hg
;
i
++
)
{
ptrs
[
i
]
=
(
const
P
*
)
_dp
->
ptrs
[
i
+
start
];
int
gather_from_rank
=
((
rank
+
i
)
%
ngpus
);
if
(
gather_from_rank
==
ngpus
-
1
||
idx
<
part
)
{
int
dst_idx
=
gather_from_rank
*
part
+
idx
;
((
P
*
)
result
)[
dst_idx
]
=
tmps
[
i
][
idx
];
}
}
}
start_sync
<
ngpus
>
(
sg
,
meta
,
rank
);
for
(
int
idx
=
tid
;
idx
<
size
;
idx
+=
stride
)
{
tmp_out
[
idx
]
=
packed_reduce
<
P
,
hg
,
A
>
(
ptrs
,
idx
);
}
end_sync
<
ngpus
>
(
sg
,
meta
,
rank
);
auto
src
=
get_tmp_buf
<
P
>
(
sg
.
signals
[(
ngpus
-
1
)
-
rank
%
ngpus
]);
// do the cross group reduction
for
(
int
idx
=
tid
;
idx
<
size
;
idx
+=
stride
)
{
auto
tmp
=
tmp_out
[
idx
];
packed_assign_add
(
tmp
,
src
[
idx
]);
((
P
*
)
result
)[
idx
]
=
tmp
;
}
}
using
IPC_KEY
=
std
::
array
<
uint8_t
,
sizeof
(
cudaIpcMemHandle_t
)
>
;
...
...
@@ -341,7 +262,7 @@ class CustomAllreduce {
// below are device pointers
RankSignals
sg_
;
std
::
unordered_map
<
void
*
,
RankData
*>
buffers_
;
Metadata
*
meta
_
;
Signal
*
self_sg
_
;
// stores the registered device pointers from all ranks
RankData
*
d_rank_data_base_
,
*
d_rank_data_end_
;
...
...
@@ -352,32 +273,32 @@ class CustomAllreduce {
/**
* meta is a pointer to device metadata and temporary buffer for allreduce.
*
* There's a total of sizeof(
Metadata
) of prefix before the actual data,
* There's a total of sizeof(
Signal
) of prefix before the actual data,
* so meta + 1 points to actual temporary buffer.
*
* note: this class does not own any device memory. Any required buffers
* are passed in from the constructor
*/
CustomAllreduce
(
Metadata
*
meta
,
void
*
rank_data
,
size_t
rank_data_sz
,
CustomAllreduce
(
Signal
*
meta
,
void
*
rank_data
,
size_t
rank_data_sz
,
const
cudaIpcMemHandle_t
*
handles
,
const
std
::
vector
<
int64_t
>
&
offsets
,
int
rank
,
bool
full_nvlink
=
true
)
:
rank_
(
rank
),
world_size_
(
offsets
.
size
()),
full_nvlink_
(
full_nvlink
),
meta
_
(
meta
),
self_sg
_
(
meta
),
d_rank_data_base_
(
reinterpret_cast
<
RankData
*>
(
rank_data
)),
d_rank_data_end_
(
d_rank_data_base_
+
rank_data_sz
/
sizeof
(
RankData
))
{
for
(
int
i
=
0
;
i
<
world_size_
;
i
++
)
{
Metadata
*
rank_
meta
;
Signal
*
rank_
sg
;
if
(
i
!=
rank_
)
{
char
*
handle
=
open_ipc_handle
(
&
handles
[
i
]);
handle
+=
offsets
[
i
];
rank_
meta
=
(
Metadata
*
)
handle
;
rank_
sg
=
(
Signal
*
)
handle
;
}
else
{
rank_
meta
=
meta
_
;
rank_
sg
=
self_sg
_
;
}
sg_
.
signals
[
i
]
=
&
rank_
meta
->
sg
;
sg_
.
signals
[
i
]
=
rank_sg
;
}
}
...
...
@@ -492,6 +413,10 @@ class CustomAllreduce {
"custom allreduce currently requires input length to be multiple "
"of "
+
std
::
to_string
(
d
));
if
(
block_limit
>
kMaxBlocks
)
throw
std
::
runtime_error
(
"max supported block limit is "
+
std
::
to_string
(
kMaxBlocks
)
+
". Got "
+
std
::
to_string
(
block_limit
));
RankData
*
ptrs
;
cudaStreamCaptureStatus
status
;
...
...
@@ -512,9 +437,9 @@ class CustomAllreduce {
size
/=
d
;
auto
bytes
=
size
*
sizeof
(
typename
packed_t
<
T
>::
P
);
int
blocks
=
std
::
min
(
block_limit
,
(
size
+
threads
-
1
)
/
threads
);
#define KL(ngpus, name) \
name<T, ngpus>
\
<<<blocks, threads, 0, stream>>>(ptrs, sg_, meta_, output,
rank_, size);
#define KL(ngpus, name)
\
name<T, ngpus>
<<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output,
\
rank_, size);
#define REDUCE_CASE(ngpus) \
case ngpus: { \
if (world_size_ == 2) { \
...
...
@@ -526,8 +451,6 @@ class CustomAllreduce {
} else { \
KL(ngpus, cross_device_reduce_2stage); \
} \
} else { \
KL(ngpus, cross_device_reduce_half_butterfly); \
} \
break; \
}
...
...
@@ -556,7 +479,7 @@ class CustomAllreduce {
/**
* To inspect PTX/SASS, copy paste this header file to compiler explorer and add
a template instantiation:
* template void CustomAllreduce::allreduce<half>(cudaStream_t, half *,
half *,
int, int, int);
* template void
vllm::
CustomAllreduce::allreduce<half>(cudaStream_t, half *,
half *,
int, int, int);
*/
}
// namespace vllm
csrc/custom_all_reduce_test.cu
View file @
f721096d
...
...
@@ -92,7 +92,7 @@ __global__ void gen_data(curandState_t *state, T *data, double *ground_truth,
template
<
typename
T
>
void
run
(
int
myRank
,
int
nRanks
,
ncclComm_t
&
comm
,
int
threads
,
int
block_limit
,
int
data_size
)
{
int
data_size
,
bool
performance_test
)
{
T
*
result
;
cudaStream_t
stream
;
CUDACHECK
(
cudaStreamCreateWithFlags
(
&
stream
,
cudaStreamNonBlocking
));
...
...
@@ -101,7 +101,7 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
cudaIpcMemHandle_t
self_data_handle
;
cudaIpcMemHandle_t
data_handles
[
8
];
vllm
::
Metadata
*
buffer
;
vllm
::
Signal
*
buffer
;
T
*
self_data_copy
;
/**
* Allocate IPC buffer
...
...
@@ -115,9 +115,9 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
* convenience.
*/
CUDACHECK
(
cudaMalloc
(
&
buffer
,
2
*
data_size
*
sizeof
(
T
)
+
sizeof
(
vllm
::
Metadata
)));
CUDACHECK
(
cudaMemset
(
buffer
,
0
,
2
*
data_size
*
sizeof
(
T
)
+
sizeof
(
vllm
::
Metadata
)));
cudaMalloc
(
&
buffer
,
2
*
data_size
*
sizeof
(
T
)
+
sizeof
(
vllm
::
Signal
)));
CUDACHECK
(
cudaMemset
(
buffer
,
0
,
2
*
data_size
*
sizeof
(
T
)
+
sizeof
(
vllm
::
Signal
)));
CUDACHECK
(
cudaMalloc
(
&
self_data_copy
,
data_size
*
sizeof
(
T
)));
CUDACHECK
(
cudaIpcGetMemHandle
(
&
self_data_handle
,
buffer
));
...
...
@@ -133,7 +133,7 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
offsets
,
myRank
);
auto
*
self_data
=
reinterpret_cast
<
T
*>
(
reinterpret_cast
<
char
*>
(
buffer
)
+
sizeof
(
vllm
::
Metadata
)
+
data_size
*
sizeof
(
T
));
sizeof
(
vllm
::
Signal
)
+
data_size
*
sizeof
(
T
));
// hack buffer registration
{
std
::
vector
<
std
::
string
>
handles
;
...
...
@@ -143,8 +143,8 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
char
*
end
=
(
char
*
)
&
data_handles
[
i
+
1
];
handles
.
emplace_back
(
begin
,
end
);
}
std
::
vector
<
int64_t
>
offsets
(
nRanks
,
sizeof
(
vllm
::
Metadata
)
+
data_size
*
sizeof
(
T
));
std
::
vector
<
int64_t
>
offsets
(
nRanks
,
sizeof
(
vllm
::
Signal
)
+
data_size
*
sizeof
(
T
));
fa
.
register_buffer
(
handles
,
offsets
,
self_data
);
}
...
...
@@ -169,81 +169,112 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
}
else
{
ncclDtype
=
ncclFloat
;
}
double
*
nccl_result
,
*
my_result
;
CUDACHECK
(
cudaMallocHost
(
&
nccl_result
,
data_size
*
sizeof
(
double
)));
CUDACHECK
(
cudaMallocHost
(
&
my_result
,
data_size
*
sizeof
(
double
)));
if
(
performance_test
)
{
dummy_kernel
<<<
1
,
1
,
0
,
stream
>>>
();
constexpr
int
warmup_iters
=
5
;
constexpr
int
num_iters
=
100
;
// warmup
for
(
int
i
=
0
;
i
<
warmup_iters
;
i
++
)
{
NCCLCHECK
(
ncclAllReduce
(
result
,
result
,
data_size
,
ncclDtype
,
ncclSum
,
comm
,
stream
));
}
CUDACHECK
(
cudaEventRecord
(
start
,
stream
));
for
(
int
i
=
0
;
i
<
num_iters
;
i
++
)
{
NCCLCHECK
(
ncclAllReduce
(
result
,
result
,
data_size
,
ncclDtype
,
ncclSum
,
comm
,
stream
));
}
CUDACHECK
(
cudaEventRecord
(
stop
,
stream
));
CUDACHECK
(
cudaStreamSynchronize
(
stream
));
float
allreduce_ms
=
0
;
cudaEventElapsedTime
(
&
allreduce_ms
,
start
,
stop
);
dummy_kernel
<<<
1
,
1
,
0
,
stream
>>>
();
constexpr
int
warmup_iters
=
5
;
constexpr
int
num_iters
=
25
;
// warmup
for
(
int
i
=
0
;
i
<
warmup_iters
;
i
++
)
{
NCCLCHECK
(
ncclAllReduce
(
result
,
result
,
data_size
,
ncclDtype
,
ncclSum
,
comm
,
stream
));
}
CUDACHECK
(
cudaEventRecord
(
start
,
stream
));
for
(
int
i
=
0
;
i
<
num_iters
;
i
++
)
{
NCCLCHECK
(
ncclAllReduce
(
result
,
result
,
data_size
,
ncclDtype
,
ncclSum
,
comm
,
stream
));
}
CUDACHECK
(
cudaEventRecord
(
stop
,
stream
));
CUDACHECK
(
cudaStreamSynchronize
(
stream
));
float
allreduce_ms
=
0
;
cudaEventElapsedTime
(
&
allreduce_ms
,
start
,
stop
);
// if (myRank == 1) dummy_kernel<<<1, 1, 0, stream>>>();
// set_data<T><<<16, 1024, 0, stream>>>(self_data, data_size, myRank);
dummy_kernel
<<<
1
,
1
,
0
,
stream
>>>
();
// warm up
for
(
int
i
=
0
;
i
<
warmup_iters
;
i
++
)
{
fa
.
allreduce
<
T
>
(
stream
,
self_data
,
result
,
data_size
,
threads
,
block_limit
);
}
CUDACHECK
(
cudaEventRecord
(
start
,
stream
));
for
(
int
i
=
0
;
i
<
num_iters
;
i
++
)
{
fa
.
allreduce
<
T
>
(
stream
,
self_data
,
result
,
data_size
,
threads
,
block_limit
);
}
CUDACHECK
(
cudaEventRecord
(
stop
,
stream
));
CUDACHECK
(
cudaStreamSynchronize
(
stream
));
float
duration_ms
=
0
;
cudaEventElapsedTime
(
&
duration_ms
,
start
,
stop
);
if
(
myRank
==
0
)
printf
(
"Rank %d done, nGPUs:%d, sz (kb): %d, %d, %d, my time:%.2fus, nccl "
"time:%.2fus
\n
"
,
myRank
,
nRanks
,
data_size
*
sizeof
(
T
)
/
1024
,
threads
,
block_limit
,
duration_ms
*
1e3
/
num_iters
,
allreduce_ms
*
1e3
/
num_iters
);
dummy_kernel
<<<
1
,
1
,
0
,
stream
>>>
();
// warm up
for
(
int
i
=
0
;
i
<
warmup_iters
;
i
++
)
{
fa
.
allreduce
<
T
>
(
stream
,
self_data
,
result
,
data_size
,
threads
,
block_limit
);
}
CUDACHECK
(
cudaEventRecord
(
start
,
stream
));
for
(
int
i
=
0
;
i
<
num_iters
;
i
++
)
{
fa
.
allreduce
<
T
>
(
stream
,
self_data
,
result
,
data_size
,
threads
,
block_limit
);
}
CUDACHECK
(
cudaEventRecord
(
stop
,
stream
));
CUDACHECK
(
cudaStreamSynchronize
(
stream
));
// And wait for all the queued up work to complete
CUDACHECK
(
cudaStreamSynchronize
(
stream
));
float
duration_ms
=
0
;
cudaEventElapsedTime
(
&
duration_ms
,
start
,
stop
);
if
(
myRank
==
0
)
printf
(
"Rank %d done, nGPUs:%d, sz (kb): %d, %d, %d, my time:%.2fus, nccl "
"time:%.2fus
\n
"
,
myRank
,
nRanks
,
data_size
*
sizeof
(
T
)
/
1024
,
threads
,
block_limit
,
duration_ms
*
1e3
/
num_iters
,
allreduce_ms
*
1e3
/
num_iters
);
NCCLCHECK
(
ncclAllReduce
(
self_data_copy
,
self_data
,
data_size
,
ncclDtype
,
ncclSum
,
comm
,
stream
));
// And wait for all the queued up work to complete
CUDACHECK
(
cudaStreamSynchronize
(
stream
));
double
*
nccl_result
,
*
my_result
;
CUDACHECK
(
cudaMallocHost
(
&
nccl_result
,
data_size
*
sizeof
(
double
)));
CUDACHECK
(
cudaMallocHost
(
&
my_result
,
data_size
*
sizeof
(
double
)));
NCCLCHECK
(
ncclAllReduce
(
self_data_copy
,
self_data
,
data_size
,
ncclDtype
,
ncclSum
,
comm
,
stream
));
convert_data
<
T
><<<
108
,
1024
,
0
,
stream
>>>
(
self_data
,
result
,
nccl_result
,
my_result
,
data_size
);
CUDACHECK
(
cudaStreamSynchronize
(
stream
));
convert_data
<
T
><<<
108
,
1024
,
0
,
stream
>>>
(
self_data
,
result
,
nccl_result
,
my_result
,
data_size
);
CUDACHECK
(
cudaStreamSynchronize
(
stream
));
for
(
unsigned
long
j
=
0
;
j
<
data_size
;
j
++
)
{
auto
diff
=
abs
(
nccl_result
[
j
]
-
my_result
[
j
]);
if
(
diff
>=
1e-2
)
{
printf
(
"Rank %d: Verification mismatch at %lld: %f != (my) %f, gt=%f
\n
"
,
myRank
,
j
,
nccl_result
[
j
],
my_result
[
j
],
ground_truth
[
j
]);
break
;
for
(
unsigned
long
j
=
0
;
j
<
data_size
;
j
++
)
{
auto
diff
=
abs
(
nccl_result
[
j
]
-
my_result
[
j
]);
if
(
diff
>=
4e-2
)
{
printf
(
"Rank %d: Verification mismatch at %lld: %f != (my) %f, gt=%f
\n
"
,
myRank
,
j
,
nccl_result
[
j
],
my_result
[
j
],
ground_truth
[
j
]);
break
;
}
}
}
long
double
nccl_diffs
=
0.0
;
long
double
my_diffs
=
0.0
;
for
(
int
j
=
0
;
j
<
data_size
;
j
++
)
{
nccl_diffs
+=
abs
(
nccl_result
[
j
]
-
ground_truth
[
j
]);
my_diffs
+=
abs
(
my_result
[
j
]
-
ground_truth
[
j
]);
}
if
(
myRank
==
0
)
std
::
cout
<<
"average abs diffs: nccl: "
<<
nccl_diffs
/
data_size
<<
" me: "
<<
my_diffs
/
data_size
<<
std
::
endl
;
}
else
{
for
(
int
i
=
0
;
i
<
100
;
i
++
)
{
fa
.
allreduce
<
T
>
(
stream
,
self_data
,
result
,
data_size
,
threads
,
block_limit
);
CUDACHECK
(
cudaStreamSynchronize
(
stream
));
NCCLCHECK
(
ncclAllReduce
(
self_data
,
self_data_copy
,
data_size
,
ncclDtype
,
ncclSum
,
comm
,
stream
));
convert_data
<
T
><<<
108
,
1024
,
0
,
stream
>>>
(
self_data_copy
,
result
,
nccl_result
,
my_result
,
data_size
);
CUDACHECK
(
cudaStreamSynchronize
(
stream
));
long
double
nccl_diffs
=
0.0
;
long
double
my_diffs
=
0.0
;
for
(
int
j
=
0
;
j
<
data_size
;
j
++
)
{
nccl_diffs
+=
abs
(
nccl_result
[
j
]
-
ground_truth
[
j
]);
my_diffs
+=
abs
(
my_result
[
j
]
-
ground_truth
[
j
]);
for
(
unsigned
long
j
=
0
;
j
<
data_size
;
j
++
)
{
auto
diff
=
abs
(
nccl_result
[
j
]
-
my_result
[
j
]);
if
(
diff
>=
4e-2
)
{
printf
(
"Rank %d: Verification mismatch at %lld: %f != (my) %f, gt=%f
\n
"
,
myRank
,
j
,
nccl_result
[
j
],
my_result
[
j
],
ground_truth
[
j
]);
break
;
}
}
}
if
(
myRank
==
0
)
printf
(
"Test passed: nGPUs:%d, sz (kb): %d, %d, %d
\n
"
,
nRanks
,
data_size
*
sizeof
(
T
)
/
1024
,
threads
,
block_limit
);
// long double nccl_diffs = 0.0;
// long double my_diffs = 0.0;
// for (int j = 0; j < data_size; j++) {
// nccl_diffs += abs(nccl_result[j] - ground_truth[j]);
// my_diffs += abs(my_result[j] - ground_truth[j]);
// }
// if (myRank == 0)
// std::cout << "average abs diffs: nccl: " << nccl_diffs / data_size
// << " me: " << my_diffs / data_size << std::endl;
}
if
(
myRank
==
0
)
std
::
cout
<<
"average abs diffs: nccl: "
<<
nccl_diffs
/
data_size
<<
" me: "
<<
my_diffs
/
data_size
<<
std
::
endl
;
CUDACHECK
(
cudaFree
(
result
));
CUDACHECK
(
cudaFree
(
self_data_copy
));
...
...
@@ -269,14 +300,15 @@ int main(int argc, char **argv) {
MPI_COMM_WORLD
));
NCCLCHECK
(
ncclCommInitRank
(
&
comm
,
nRanks
,
id
,
myRank
));
bool
performance_test
=
true
;
cudaProfilerStart
();
// for (int threads : {256, 512}) {
// for (int block_limit = 16; block_limit < 112; block_limit += 4) {
// run<half>(myRank, nRanks, comm, threads, block_limit, 4096 * 1024);
// }
// }
for
(
int
sz
=
512
;
sz
<=
(
32
<<
20
);
sz
*=
2
)
{
run
<
half
>
(
myRank
,
nRanks
,
comm
,
512
,
36
,
sz
+
8
*
50
);
for
(
int
sz
=
512
;
sz
<=
(
8
<<
20
);
sz
*=
2
)
{
run
<
half
>
(
myRank
,
nRanks
,
comm
,
512
,
36
,
sz
+
8
*
47
,
performance_test
);
}
cudaProfilerStop
();
...
...
vllm/config.py
View file @
f721096d
...
...
@@ -506,15 +506,6 @@ class ParallelConfig:
raise
ValueError
(
"Unable to use nsight profiling unless workers "
"run with Ray."
)
# FIXME(woosuk): Fix the stability issues and re-enable the custom
# all-reduce kernel.
if
not
self
.
disable_custom_all_reduce
and
self
.
world_size
>
1
:
self
.
disable_custom_all_reduce
=
True
logger
.
info
(
"Custom all-reduce kernels are temporarily disabled due to "
"stability issues. We will re-enable them once the issues are "
"resolved."
)
class
SchedulerConfig
:
"""Scheduler configuration.
...
...
vllm/entrypoints/llm.py
View file @
f721096d
...
...
@@ -83,7 +83,7 @@ class LLM:
swap_space
:
int
=
4
,
enforce_eager
:
bool
=
False
,
max_context_len_to_capture
:
int
=
8192
,
disable_custom_all_reduce
:
bool
=
Fals
e
,
disable_custom_all_reduce
:
bool
=
Tru
e
,
**
kwargs
,
)
->
None
:
if
"disable_log_stats"
not
in
kwargs
:
...
...
vllm/model_executor/parallel_utils/custom_all_reduce.py
View file @
f721096d
...
...
@@ -37,16 +37,23 @@ def init_custom_ar() -> None:
logger
.
warn
(
"Custom allreduce is disabled due to an unsupported world size: "
"%d. Supported world sizes: %s. To silence this warning, specify"
"disable_custom_all_reduce=True explicitly."
,
world_size
,
"
disable_custom_all_reduce=True explicitly."
,
world_size
,
str
(
_SUPPORTED_WORLD_SIZES
))
return
if
not
_can_p2p
(
rank
,
world_size
):
logger
.
warn
(
"Custom allreduce is disabled because your platform lacks GPU P2P"
" capability. To silence this warning, specify"
"disable_custom_all_reduce=True explicitly."
)
" capability
or P2P test failed
. To silence this warning, specify"
"
disable_custom_all_reduce=True explicitly."
)
return
_CA_HANDLE
=
CustomAllreduce
(
rank
,
world_size
)
full_nvlink
=
_is_full_nvlink
(
rank
,
world_size
)
if
world_size
>
2
and
not
full_nvlink
:
logger
.
warn
(
"Custom allreduce is disabled because it's not supported on more"
" than two PCIe-only GPUs. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly."
)
return
_CA_HANDLE
=
CustomAllreduce
(
rank
,
world_size
,
full_nvlink
)
def
begin_capture
()
->
None
:
...
...
@@ -134,18 +141,48 @@ def _is_full_nvlink(rank, world_size):
def
_can_p2p
(
rank
:
int
,
world_size
:
int
)
->
bool
:
num_dev
=
torch
.
cuda
.
device_count
()
# note: num dev can be larger than world_size if we're only using
# first few GPUs
if
num_dev
<
world_size
:
logger
.
warn
(
"Cannot test GPU P2P because not all GPUs are visible to the "
"current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
" is set."
)
return
False
for
i
in
range
(
world_size
):
if
i
==
rank
:
continue
if
not
torch
.
cuda
.
can_device_access_peer
(
rank
,
i
):
return
False
# on some platforms, P2P support might be buggy and we need
# additional checks. See also:
# https://github.com/vllm-project/vllm/issues/2728
if
not
_can_actually_p2p
(
rank
,
i
):
return
False
return
True
# code partly borrowed from
# https://github.com/turboderp/exllamav2/blob/1c67f97f3d2a968605a9c31ab791a05c85bb7879/exllamav2/compat.py#L10
# License: MIT
def
_can_actually_p2p
(
idx_a
,
idx_b
):
dev_i
=
f
"cuda:
{
idx_a
}
"
dev_j
=
f
"cuda:
{
idx_b
}
"
a
=
torch
.
randn
(
5
,
device
=
dev_i
)
+
123.0
b
=
a
.
to
(
dev_j
)
c
=
b
.
to
(
dev_i
)
return
torch
.
all
(
a
==
c
)
class
CustomAllreduce
:
# max_size: max supported allreduce size
def
__init__
(
self
,
rank
,
world_size
,
max_size
=
8192
*
1024
)
->
None
:
def
__init__
(
self
,
rank
,
world_size
,
full_nvlink
,
max_size
=
8192
*
1024
)
->
None
:
# buffers memory are owned by this Python class and passed to C++
# meta data composes of two parts: meta data for synchronization
# (256 bytes) and a temporary buffer for storing intermediate
...
...
@@ -167,11 +204,10 @@ class CustomAllreduce:
self
.
max_size
=
max_size
self
.
world_size
=
world_size
handles
,
offsets
=
self
.
_get_ipc_meta
(
self
.
meta
)
self
.
full_nvlink
=
_is_
full_nvlink
(
rank
,
world_size
)
self
.
full_nvlink
=
full_nvlink
self
.
_ptr
=
custom_ar
.
init_custom_ar
(
self
.
meta
,
self
.
rank_data
,
handles
,
offsets
,
rank
,
self
.
full_nvlink
)
self
.
fast_cond
=
self
.
full_nvlink
or
world_size
<=
2
self
.
register_buffer
(
self
.
buffer
)
def
_get_ipc_meta
(
self
,
inp
:
torch
.
Tensor
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment