Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6192e9b8
Unverified
Commit
6192e9b8
authored
Nov 06, 2024
by
Hanzhi Zhou
Committed by
GitHub
Nov 06, 2024
Browse files
[Core][Distributed] Refactor ipc buffer init in CustomAllreduce (#10030)
Signed-off-by:
Hanzhi Zhou
<
hanzhi713@gmail.com
>
parent
d7263a1b
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
218 additions
and
260 deletions
+218
-260
csrc/custom_all_reduce.cu
csrc/custom_all_reduce.cu
+61
-58
csrc/custom_all_reduce.cuh
csrc/custom_all_reduce.cuh
+45
-42
csrc/custom_all_reduce_test.cu
csrc/custom_all_reduce_test.cu
+13
-11
csrc/ops.h
csrc/ops.h
+9
-13
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+6
-15
tests/distributed/test_custom_all_reduce.py
tests/distributed/test_custom_all_reduce.py
+2
-2
tools/profiler/visualize_layerwise_profile.py
tools/profiler/visualize_layerwise_profile.py
+16
-16
vllm/_custom_ops.py
vllm/_custom_ops.py
+12
-17
vllm/distributed/device_communicators/custom_all_reduce.py
vllm/distributed/device_communicators/custom_all_reduce.py
+54
-86
No files found.
csrc/custom_all_reduce.cu
View file @
6192e9b8
...
...
@@ -5,32 +5,29 @@
#include "custom_all_reduce.cuh"
// fake pointer type, must match fptr_t type in ops.h
// Fake pointer type, must match fptr_t type in ops.h.
// We use this type alias to indicate when pointers are passed in as int64_t.
using
fptr_t
=
int64_t
;
static_assert
(
sizeof
(
void
*
)
==
sizeof
(
fptr_t
));
fptr_t
init_custom_ar
(
torch
::
Tensor
&
meta
,
torch
::
Tensor
&
rank_data
,
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
int64_t
>&
offsets
,
int64_t
rank
,
fptr_t
init_custom_ar
(
const
std
::
vector
<
fptr_t
>&
fake_ipc_ptrs
,
torch
::
Tensor
&
rank_data
,
int64_t
rank
,
bool
full_nvlink
)
{
int
world_size
=
offset
s
.
size
();
int
world_size
=
fake_ipc_ptr
s
.
size
();
if
(
world_size
>
8
)
throw
std
::
invalid_argument
(
"world size > 8 is not supported"
);
if
(
world_size
%
2
!=
0
)
throw
std
::
invalid_argument
(
"Odd num gpus is not supported for now"
);
if
(
world_size
!=
handles
.
size
())
throw
std
::
invalid_argument
(
"handles length should equal to offsets length"
);
if
(
rank
<
0
||
rank
>=
world_size
)
throw
std
::
invalid_argument
(
"invalid rank passed in"
);
cudaIpcMemHandle_t
ipc_handle
s
[
8
];
vllm
::
Signal
*
ipc_ptr
s
[
8
];
for
(
int
i
=
0
;
i
<
world_size
;
i
++
)
{
std
::
memcpy
(
&
ipc_handles
[
i
],
handles
[
i
].
data
(),
sizeof
(
cudaIpcMemHandle_t
)
);
ipc_ptrs
[
i
]
=
reinterpret_cast
<
vllm
::
Signal
*>
(
fake_ipc_ptrs
[
i
]
);
}
return
(
fptr_t
)
new
vllm
::
CustomAllreduce
(
reinterpret_cast
<
vllm
::
Signal
*>
(
meta
.
data_ptr
()),
rank_data
.
data_ptr
()
,
rank_data
.
numel
(),
ipc_handles
,
offsets
,
rank
,
full_nvlink
);
return
(
fptr_t
)
new
vllm
::
CustomAllreduce
(
ipc_ptrs
,
rank_data
.
data_ptr
(),
rank_data
.
numel
(),
rank
,
world_size
,
full_nvlink
);
}
/**
...
...
@@ -55,26 +52,48 @@ bool _is_weak_contiguous(torch::Tensor& t) {
t
.
numel
()
*
t
.
element_size
());
}
void
_all_reduce
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
,
cudaStream_t
stream
)
{
/**
* Performs an out-of-place allreduce and stores result in out.
*
* If _reg_buffer is null, assumes inp.data_ptr() is already IPC-registered.
* Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first
* copied into _reg_buffer.
*/
void
all_reduce
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
,
fptr_t
_reg_buffer
,
int64_t
reg_buffer_sz_bytes
)
{
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
inp
));
auto
stream
=
c10
::
cuda
::
getCurrentCUDAStream
().
stream
();
TORCH_CHECK_EQ
(
inp
.
scalar_type
(),
out
.
scalar_type
());
TORCH_CHECK_EQ
(
inp
.
numel
(),
out
.
numel
());
TORCH_CHECK
(
_is_weak_contiguous
(
out
));
TORCH_CHECK
(
_is_weak_contiguous
(
inp
));
auto
input_size
=
inp
.
numel
()
*
inp
.
element_size
();
auto
reg_buffer
=
reinterpret_cast
<
void
*>
(
_reg_buffer
);
if
(
reg_buffer
)
{
TORCH_CHECK_LE
(
input_size
,
reg_buffer_sz_bytes
);
AT_CUDA_CHECK
(
cudaMemcpyAsync
(
reg_buffer
,
inp
.
data_ptr
(),
input_size
,
cudaMemcpyDeviceToDevice
,
stream
));
}
else
{
reg_buffer
=
inp
.
data_ptr
();
}
switch
(
out
.
scalar_type
())
{
case
at
::
ScalarType
::
Float
:
{
fa
->
allreduce
<
float
>
(
stream
,
reinterpret_cast
<
float
*>
(
inp
.
data_ptr
()
),
fa
->
allreduce
<
float
>
(
stream
,
reinterpret_cast
<
float
*>
(
reg_buffer
),
reinterpret_cast
<
float
*>
(
out
.
data_ptr
()),
out
.
numel
());
break
;
}
case
at
::
ScalarType
::
Half
:
{
fa
->
allreduce
<
half
>
(
stream
,
reinterpret_cast
<
half
*>
(
inp
.
data_ptr
()
),
fa
->
allreduce
<
half
>
(
stream
,
reinterpret_cast
<
half
*>
(
reg_buffer
),
reinterpret_cast
<
half
*>
(
out
.
data_ptr
()),
out
.
numel
());
break
;
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case
at
::
ScalarType
::
BFloat16
:
{
fa
->
allreduce
<
nv_bfloat16
>
(
stream
,
reinterpret_cast
<
nv_bfloat16
*>
(
inp
.
data_ptr
()
),
stream
,
reinterpret_cast
<
nv_bfloat16
*>
(
reg_buffer
),
reinterpret_cast
<
nv_bfloat16
*>
(
out
.
data_ptr
()),
out
.
numel
());
break
;
}
...
...
@@ -85,57 +104,41 @@ void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
}
}
void
all_reduce_reg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
)
{
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
inp
));
auto
stream
=
c10
::
cuda
::
getCurrentCUDAStream
().
stream
();
TORCH_CHECK_EQ
(
inp
.
scalar_type
(),
out
.
scalar_type
());
TORCH_CHECK_EQ
(
inp
.
numel
(),
out
.
numel
());
_all_reduce
(
_fa
,
inp
,
out
,
stream
);
}
void
all_reduce_unreg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
reg_buffer
,
torch
::
Tensor
&
out
)
{
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
inp
));
auto
stream
=
c10
::
cuda
::
getCurrentCUDAStream
().
stream
();
auto
input_size
=
inp
.
numel
()
*
inp
.
element_size
();
TORCH_CHECK_EQ
(
inp
.
scalar_type
(),
out
.
scalar_type
());
TORCH_CHECK_EQ
(
inp
.
numel
(),
out
.
numel
());
TORCH_CHECK
(
input_size
<=
reg_buffer
.
numel
()
*
reg_buffer
.
element_size
(),
"registered buffer is too small to contain the input"
);
AT_CUDA_CHECK
(
cudaMemcpyAsync
(
reg_buffer
.
data_ptr
(),
inp
.
data_ptr
(),
input_size
,
cudaMemcpyDeviceToDevice
,
stream
));
_all_reduce
(
_fa
,
reg_buffer
,
out
,
stream
);
}
void
dispose
(
fptr_t
_fa
)
{
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
delete
fa
;
delete
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
}
int64_t
meta_size
()
{
return
sizeof
(
vllm
::
Signal
);
}
void
register_buffer
(
fptr_t
_fa
,
torch
::
Tensor
&
t
,
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
int64_t
>&
offsets
)
{
void
register_buffer
(
fptr_t
_fa
,
const
std
::
vector
<
fptr_t
>&
fake_ipc_ptrs
)
{
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
fa
->
register_buffer
(
handles
,
offsets
,
t
.
data_ptr
());
TORCH_CHECK
(
fake_ipc_ptrs
.
size
()
==
fa
->
world_size_
);
void
*
ipc_ptrs
[
8
];
for
(
int
i
=
0
;
i
<
fake_ipc_ptrs
.
size
();
i
++
)
{
ipc_ptrs
[
i
]
=
reinterpret_cast
<
void
*>
(
fake_ipc_ptrs
[
i
]);
}
fa
->
register_buffer
(
ipc_ptrs
);
}
std
::
tuple
<
torch
::
Tensor
,
std
::
vector
<
int64_t
>>
get_graph_buffer_ipc_meta
(
fptr_t
_fa
)
{
// Use vector<int64_t> to represent byte data for python binding compatibility.
std
::
tuple
<
std
::
vector
<
int64_t
>
,
std
::
vector
<
int64_t
>>
get_graph_buffer_ipc_meta
(
fptr_t
_fa
)
{
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
auto
[
handle_bytes
,
offsets
]
=
fa
->
get_graph_buffer_ipc_meta
();
auto
options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
torch
::
kCPU
);
auto
handles
=
torch
::
empty
({
static_cast
<
int64_t
>
(
handle_bytes
.
size
())},
options
);
std
::
memcpy
(
handles
.
data_ptr
(),
handle_bytes
.
data
(),
handle_bytes
.
size
());
return
{
handles
,
std
::
move
(
offsets
)};
auto
[
handle
,
offsets
]
=
fa
->
get_graph_buffer_ipc_meta
();
std
::
vector
<
int64_t
>
bytes
(
handle
.
begin
(),
handle
.
end
());
return
std
::
make_tuple
(
bytes
,
offsets
);
}
void
register_graph_buffers
(
fptr_t
_fa
,
const
std
::
vector
<
std
::
string
>&
handles
,
// Use vector<int64_t> to represent byte data for python binding compatibility.
void
register_graph_buffers
(
fptr_t
_fa
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
handles
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
offsets
)
{
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
fa
->
register_graph_buffers
(
handles
,
offsets
);
std
::
vector
<
std
::
string
>
bytes
;
bytes
.
reserve
(
handles
.
size
());
for
(
int
i
=
0
;
i
<
handles
.
size
();
i
++
)
{
bytes
.
emplace_back
(
handles
[
i
].
begin
(),
handles
[
i
].
end
());
}
bytes
.
reserve
(
handles
.
size
());
fa
->
register_graph_buffers
(
bytes
,
offsets
);
}
csrc/custom_all_reduce.cuh
View file @
6192e9b8
...
...
@@ -285,46 +285,52 @@ class CustomAllreduce {
int
world_size_
;
bool
full_nvlink_
;
// below are device pointers
RankSignals
sg_
;
// Stores a map from a pointer to its peer pointers from all ranks.
std
::
unordered_map
<
void
*
,
RankData
*>
buffers_
;
Signal
*
self_sg_
;
// stores the registered device pointers from all ranks
// Stores rank data from all ranks. This is mainly for cuda graph purposes.
// For cuda graph to work, all kernel arguments must be fixed during graph
// capture time. However, the peer pointers are not known during graph capture
// time. Therefore, during capture, we increment the rank data pointer and use
// that as the argument to the kernel. The kernel arguments are stored in
// graph_unreg_buffers_. The actual peer pointers will be filled in at the
// memory pointed to by the pointers in graph_unreg_buffers_ when
// the IPC handles are exchanged between ranks.
//
// The overall process looks like this:
// 1. Graph capture.
// 2. Each rank obtains the IPC handles for each address used during cuda
// graph capture using get_graph_buffer_ipc_meta.
// 3. (In Python) all gather the IPC handles.
// 4. Obtain the peer pointers by opening the IPC handles, and store them in
// the rank data array at corresponding positions.
RankData
*
d_rank_data_base_
,
*
d_rank_data_end_
;
std
::
vector
<
void
*>
graph_unreg_buffers_
;
// a map from IPC handles to opened IPC pointers
std
::
map
<
IPC_KEY
,
char
*>
ipc_handles_
;
/**
* meta is a pointer to device metadata and temporary buffer for allreduce.
* Signals are an array of ipc-enabled buffers from all ranks.
* For each of the buffer, the layout is as follows:
* | -- sizeof(Signal) -- | ------ a few MB ----- |
* The first section is for allreduce synchronization, and the second section
* is for storing the intermediate results required by some allreduce algos.
*
* There's a total of sizeof(Signal) of prefix before the actual data,
* so meta + 1 points to actual temporary buffer.
*
* note: this class does not own any device memory. Any required buffers
* are passed in from the constructor
* Note: this class does not own any device memory. Any required buffers
* are passed in from the constructor.
*/
CustomAllreduce
(
Signal
*
meta
,
void
*
rank_data
,
size_t
rank_data_sz
,
const
cudaIpcMemHandle_t
*
handles
,
const
std
::
vector
<
int64_t
>&
offsets
,
int
rank
,
bool
full_nvlink
=
true
)
CustomAllreduce
(
Signal
**
signals
,
void
*
rank_data
,
size_t
rank_data_sz
,
int
rank
,
int
world_size
,
bool
full_nvlink
=
true
)
:
rank_
(
rank
),
world_size_
(
offsets
.
size
()
),
world_size_
(
world_
size
),
full_nvlink_
(
full_nvlink
),
self_sg_
(
meta
),
self_sg_
(
signals
[
rank
]
),
d_rank_data_base_
(
reinterpret_cast
<
RankData
*>
(
rank_data
)),
d_rank_data_end_
(
d_rank_data_base_
+
rank_data_sz
/
sizeof
(
RankData
))
{
for
(
int
i
=
0
;
i
<
world_size_
;
i
++
)
{
Signal
*
rank_sg
;
if
(
i
!=
rank_
)
{
char
*
handle
=
open_ipc_handle
(
&
handles
[
i
]);
handle
+=
offsets
[
i
];
rank_sg
=
(
Signal
*
)
handle
;
}
else
{
rank_sg
=
self_sg_
;
}
sg_
.
signals
[
i
]
=
rank_sg
;
sg_
.
signals
[
i
]
=
signals
[
i
];
}
}
...
...
@@ -341,11 +347,10 @@ class CustomAllreduce {
return
it
->
second
;
}
std
::
pair
<
std
::
vector
<
uint8_t
>
,
std
::
vector
<
int64_t
>>
get_graph_buffer_ipc_meta
()
{
std
::
pair
<
std
::
string
,
std
::
vector
<
int64_t
>>
get_graph_buffer_ipc_meta
()
{
auto
num_buffers
=
graph_unreg_buffers_
.
size
();
auto
handle_sz
=
sizeof
(
cudaIpcMemHandle_t
);
std
::
vector
<
uint8_t
>
handles
(
handle_sz
*
num_buffers
,
0
);
std
::
string
handles
(
handle_sz
*
num_buffers
,
static_cast
<
char
>
(
0
)
);
std
::
vector
<
int64_t
>
offsets
(
num_buffers
);
for
(
int
i
=
0
;
i
<
num_buffers
;
i
++
)
{
auto
ptr
=
graph_unreg_buffers_
[
i
];
...
...
@@ -370,26 +375,22 @@ class CustomAllreduce {
std
::
to_string
(
d_rank_data_base_
+
num
-
d_rank_data_end_
));
}
void
register_buffer
(
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
int64_t
>&
offsets
,
void
*
self
)
{
/**
* Register already-shared IPC pointers.
*/
void
register_buffer
(
void
**
ptrs
)
{
check_rank_data_capacity
();
RankData
data
;
for
(
int
i
=
0
;
i
<
world_size_
;
i
++
)
{
if
(
i
!=
rank_
)
{
char
*
handle
=
open_ipc_handle
(
handles
[
i
].
data
());
handle
+=
offsets
[
i
];
data
.
ptrs
[
i
]
=
handle
;
}
else
{
data
.
ptrs
[
i
]
=
self
;
}
data
.
ptrs
[
i
]
=
ptrs
[
i
];
}
auto
d_data
=
d_rank_data_base_
++
;
CUDACHECK
(
cudaMemcpy
(
d_data
,
&
data
,
sizeof
(
RankData
),
cudaMemcpyHostToDevice
));
buffers_
[
self
]
=
d_data
;
buffers_
[
ptrs
[
rank_
]
]
=
d_data
;
}
// note: when registering graph buffers, we intentionally choose to not
// Note: when registering graph buffers, we intentionally choose to not
// deduplicate the addresses. That means if the allocator reuses some
// addresses, they will be registered again. This is to account for the remote
// possibility of different allocation patterns between ranks. For example,
...
...
@@ -424,11 +425,13 @@ class CustomAllreduce {
}
/**
* This is the result after careful grid search. Using 36 blocks give the best
* or close to the best runtime on the devices I tried: A100, A10, A30, T4,
* V100. You'll notice that NCCL kernels also only take a small amount of SMs.
* Not quite sure the underlying reason, but my guess is that too many SMs
* will cause contention on NVLink bus.
* Performs allreduce, assuming input has already been registered.
*
* Block and grid default configs are results after careful grid search. Using
* 36 blocks gives the best or close to the best runtime on the devices I
* tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also only
* take a small amount of SMs. Not quite sure the underlying reason, but my
* guess is that too many SMs will cause contention on NVLink bus.
*/
template
<
typename
T
>
void
allreduce
(
cudaStream_t
stream
,
T
*
input
,
T
*
output
,
int
size
,
...
...
csrc/custom_all_reduce_test.cu
View file @
6192e9b8
...
...
@@ -135,24 +135,26 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
void
*
rank_data
;
size_t
rank_data_sz
=
16
*
1024
*
1024
;
CUDACHECK
(
cudaMalloc
(
&
rank_data
,
rank_data_sz
));
std
::
vector
<
int64_t
>
offsets
(
nRanks
,
0
);
vllm
::
CustomAllreduce
fa
(
buffer
,
rank_data
,
rank_data_sz
,
data_handles
,
offsets
,
myRank
);
vllm
::
Signal
*
ipc_ptrs
[
8
];
for
(
int
i
=
0
;
i
<
nRanks
;
i
++
)
{
if
(
i
==
myRank
)
ipc_ptrs
[
i
]
=
buffer
;
else
CUDACHECK
(
cudaIpcOpenMemHandle
((
void
**
)
&
ipc_ptrs
[
i
],
data_handles
[
i
],
cudaIpcMemLazyEnablePeerAccess
));
}
vllm
::
CustomAllreduce
fa
(
ipc_ptrs
,
rank_data
,
rank_data_sz
,
myRank
,
nRanks
);
auto
*
self_data
=
reinterpret_cast
<
T
*>
(
reinterpret_cast
<
char
*>
(
buffer
)
+
sizeof
(
vllm
::
Signal
)
+
data_size
*
sizeof
(
T
));
// hack buffer registration
{
std
::
vector
<
std
::
string
>
handles
;
handles
.
reserve
(
nRanks
);
void
*
data
[
8
];
for
(
int
i
=
0
;
i
<
nRanks
;
i
++
)
{
char
*
begin
=
(
char
*
)
&
data_handles
[
i
];
char
*
end
=
(
char
*
)
&
data_handles
[
i
+
1
];
handles
.
emplace_back
(
begin
,
end
);
data
[
i
]
=
((
char
*
)
ipc_ptrs
[
i
])
+
sizeof
(
vllm
::
Signal
)
+
data_size
*
sizeof
(
T
);
}
std
::
vector
<
int64_t
>
offsets
(
nRanks
,
sizeof
(
vllm
::
Signal
)
+
data_size
*
sizeof
(
T
));
fa
.
register_buffer
(
handles
,
offsets
,
self_data
);
fa
.
register_buffer
(
data
);
}
double
*
ground_truth
;
...
...
csrc/ops.h
View file @
6192e9b8
...
...
@@ -199,20 +199,16 @@ void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
#ifndef USE_ROCM
using
fptr_t
=
int64_t
;
fptr_t
init_custom_ar
(
torch
::
Tensor
&
meta
,
torch
::
Tensor
&
rank_data
,
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
int64_t
>&
offsets
,
int64_t
rank
,
bool
full_nvlink
);
void
all_reduce_reg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
);
void
all_reduce_unreg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
reg_buffer
,
torch
::
Tensor
&
out
);
fptr_t
init_custom_ar
(
const
std
::
vector
<
int64_t
>&
fake_ipc_ptrs
,
torch
::
Tensor
&
rank_data
,
int64_t
rank
,
bool
full_nvlink
);
void
all_reduce
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
,
fptr_t
reg_buffer
,
int64_t
reg_buffer_sz_bytes
);
void
dispose
(
fptr_t
_fa
);
int64_t
meta_size
();
void
register_buffer
(
fptr_t
_fa
,
torch
::
Tensor
&
t
,
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
int64_t
>&
offsets
);
std
::
tuple
<
torch
::
Tensor
,
std
::
vector
<
int64_t
>>
get_graph_buffer_ipc_meta
(
fptr_t
_fa
);
void
register_graph_buffers
(
fptr_t
_fa
,
const
std
::
vector
<
std
::
string
>&
handles
,
void
register_buffer
(
fptr_t
_fa
,
const
std
::
vector
<
int64_t
>&
fake_ipc_ptrs
);
std
::
tuple
<
std
::
vector
<
int64_t
>
,
std
::
vector
<
int64_t
>>
get_graph_buffer_ipc_meta
(
fptr_t
_fa
);
void
register_graph_buffers
(
fptr_t
_fa
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
handles
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
offsets
);
#endif
csrc/torch_bindings.cpp
View file @
6192e9b8
...
...
@@ -411,27 +411,18 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
TORCH_LIBRARY_EXPAND
(
CONCAT
(
TORCH_EXTENSION_NAME
,
_custom_ar
),
custom_ar
)
{
// Custom all-reduce kernels
custom_ar
.
def
(
"init_custom_ar(Tensor meta, Tensor rank_data, "
"str[] handles, int[] offsets, int rank, "
"bool full_nvlink) -> int"
);
"init_custom_ar(int[] ipc_tensors, Tensor rank_data, "
"int rank, bool full_nvlink) -> int"
);
custom_ar
.
impl
(
"init_custom_ar"
,
torch
::
kCUDA
,
&
init_custom_ar
);
custom_ar
.
def
(
"all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()"
);
custom_ar
.
impl
(
"all_reduce_reg"
,
torch
::
kCUDA
,
&
all_reduce_reg
);
custom_ar
.
def
(
"all_reduce
_unreg
(int fa, Tensor inp, Tensor reg_buffer,
Tensor! out) ->
"
"()"
);
custom_ar
.
impl
(
"all_reduce
_unreg
"
,
torch
::
kCUDA
,
&
all_reduce
_unreg
);
"all_reduce(int fa, Tensor inp, Tensor
! out, int
reg_buffer, "
"
int reg_buffer_sz_bytes) ->
()"
);
custom_ar
.
impl
(
"all_reduce"
,
torch
::
kCUDA
,
&
all_reduce
);
custom_ar
.
def
(
"dispose"
,
&
dispose
);
custom_ar
.
def
(
"meta_size"
,
&
meta_size
);
custom_ar
.
def
(
"register_buffer(int fa, Tensor t, str[] handles, "
"int[] offsets) -> ()"
);
custom_ar
.
impl
(
"register_buffer"
,
torch
::
kCUDA
,
&
register_buffer
);
custom_ar
.
def
(
"register_buffer"
,
&
register_buffer
);
custom_ar
.
def
(
"get_graph_buffer_ipc_meta"
,
&
get_graph_buffer_ipc_meta
);
custom_ar
.
def
(
"register_graph_buffers"
,
&
register_graph_buffers
);
}
...
...
tests/distributed/test_custom_all_reduce.py
View file @
6192e9b8
...
...
@@ -95,13 +95,13 @@ def eager_allreduce(tp_size, pp_size, rank, distributed_init_port):
inp
=
torch
.
ones
(
sz
,
dtype
=
torch
.
float32
,
device
=
device
)
out
=
inp
for
_
in
range
(
num_communication
):
out
=
fa
.
all_reduce
_unreg
(
out
)
out
=
fa
.
all_reduce
(
out
,
registered
=
False
)
torch
.
testing
.
assert_close
(
out
,
inp
*
(
tp_size
**
num_communication
))
inp
=
torch
.
ones
(
sz
*
4
,
dtype
=
torch
.
bfloat16
,
device
=
device
)
out
=
inp
for
_
in
range
(
num_communication
):
out
=
fa
.
all_reduce
_unreg
(
out
)
out
=
fa
.
all_reduce
(
out
,
registered
=
False
)
torch
.
testing
.
assert_close
(
out
,
inp
*
(
tp_size
**
num_communication
))
...
...
tools/profiler/visualize_layerwise_profile.py
View file @
6192e9b8
...
...
@@ -196,8 +196,8 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
def
is_cross_device_reduce_2stage
(
op_name
:
str
):
return
"cross_device_reduce_2stage"
in
op_name
def
is_custom_ar_all_reduce
_unreg
(
op_name
:
str
):
return
"_C_custom_ar::all_reduce
_unreg
"
in
op_name
def
is_custom_ar_all_reduce
(
op_name
:
str
):
return
"_C_custom_ar::all_reduce"
in
op_name
def
is_reduce_kernel
(
op_name
:
str
):
return
"reduce_kernel"
in
op_name
...
...
@@ -246,9 +246,9 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
filter
(
lambda
x
:
is_cross_device_reduce_2stage
(
x
),
ops
))
ops
=
list
(
filter
(
lambda
x
:
x
not
in
cross_device_reduce_2stage_ops
,
ops
))
custom_ar_all_reduce_
unreg_
ops
=
list
(
filter
(
lambda
x
:
is_custom_ar_all_reduce
_unreg
(
x
),
ops
))
ops
=
list
(
filter
(
lambda
x
:
x
not
in
custom_ar_all_reduce_
unreg_
ops
,
ops
))
custom_ar_all_reduce_ops
=
list
(
filter
(
lambda
x
:
is_custom_ar_all_reduce
(
x
),
ops
))
ops
=
list
(
filter
(
lambda
x
:
x
not
in
custom_ar_all_reduce_ops
,
ops
))
reduce_kernel_ops
=
list
(
filter
(
lambda
x
:
is_reduce_kernel
(
x
),
ops
))
ops
=
list
(
filter
(
lambda
x
:
x
not
in
reduce_kernel_ops
,
ops
))
...
...
@@ -289,21 +289,21 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
if
len
(
cross_device_reduce_2stage_ops
):
trace_df
[
'cross_device_reduce_2stage_ops'
]
=
trace_df
[
cross_device_reduce_2stage_ops
].
agg
(
"sum"
,
axis
=
1
)
if
len
(
custom_ar_all_reduce_
unreg_
ops
):
trace_df
[
'custom_ar_all_reduce_
unreg_
ops'
]
=
trace_df
[
custom_ar_all_reduce_
unreg_
ops
].
agg
(
"sum"
,
axis
=
1
)
if
len
(
custom_ar_all_reduce_ops
):
trace_df
[
'custom_ar_all_reduce_ops'
]
=
trace_df
[
custom_ar_all_reduce_ops
].
agg
(
"sum"
,
axis
=
1
)
if
len
(
reduce_kernel_ops
):
trace_df
[
'reduce_kernel_ops'
]
=
trace_df
[
reduce_kernel_ops
].
agg
(
"sum"
,
axis
=
1
)
trace_df
.
drop
(
attention_ops
+
quant_ops
+
gemm_ops
+
rms_norm_ops
+
vocab_embed
_ops
+
mem_ops
+
elementwise_ops
+
nccl_all_reduce_ops
+
nccl_gather_ops
+
nccl_broadcast_ops
+
nccl_other_ops
+
cross_device_reduce_1stage_ops
+
cross_device_reduce_2stage_ops
+
custom_ar_all_reduce_
unreg_
ops
+
reduce_kernel_ops
,
axis
=
1
,
inplace
=
True
)
trace_df
.
drop
(
attention_ops
+
quant_ops
+
gemm_ops
+
rms_norm_ops
+
vocab_embed_ops
+
mem_ops
+
elementwise
_ops
+
nccl_all_reduce_ops
+
nccl_gather_ops
+
nccl_broadcast_ops
+
nccl_other_ops
+
cross_device_reduce_1stage_ops
+
cross_device_reduce_2stage_ops
+
custom_ar_all_reduce_ops
+
reduce_kernel_ops
,
axis
=
1
,
inplace
=
True
)
return
trace_df
...
...
vllm/_custom_ops.py
View file @
6192e9b8
...
...
@@ -912,20 +912,16 @@ def get_max_shared_memory_per_block_device_attribute(device: int) -> int:
# custom ar
def
init_custom_ar
(
meta
:
torch
.
Tensor
,
rank_data
:
torch
.
Tensor
,
handles
:
List
[
str
],
offsets
:
List
[
int
],
rank
:
int
,
full_nvlink
:
bool
)
->
int
:
return
torch
.
ops
.
_C_custom_ar
.
init_custom_ar
(
meta
,
rank_data
,
handles
,
offsets
,
rank
,
full_nvlink
)
def
init_custom_ar
(
ipc_tensors
:
List
[
torch
.
Tensor
],
rank_data
:
torch
.
Tensor
,
rank
:
int
,
full_nvlink
:
bool
)
->
int
:
return
torch
.
ops
.
_C_custom_ar
.
init_custom_ar
(
ipc_tensors
,
rank_data
,
rank
,
full_nvlink
)
def
all_reduce_reg
(
fa
:
int
,
inp
:
torch
.
Tensor
,
out
:
torch
.
Tensor
)
->
None
:
torch
.
ops
.
_C_custom_ar
.
all_reduce_reg
(
fa
,
inp
,
out
)
def
all_reduce_unreg
(
fa
:
int
,
inp
:
torch
.
Tensor
,
reg_buffer
:
torch
.
Tensor
,
out
:
torch
.
Tensor
)
->
None
:
torch
.
ops
.
_C_custom_ar
.
all_reduce_unreg
(
fa
,
inp
,
reg_buffer
,
out
)
def
all_reduce
(
fa
:
int
,
inp
:
torch
.
Tensor
,
out
:
torch
.
Tensor
,
reg_buffer
:
int
,
reg_buffer_sz_bytes
:
int
)
->
None
:
torch
.
ops
.
_C_custom_ar
.
all_reduce
(
fa
,
inp
,
out
,
reg_buffer
,
reg_buffer_sz_bytes
)
def
dispose
(
fa
:
int
)
->
None
:
...
...
@@ -936,16 +932,15 @@ def meta_size() -> int:
return
torch
.
ops
.
_C_custom_ar
.
meta_size
()
def
register_buffer
(
fa
:
int
,
t
:
torch
.
Tensor
,
handles
:
List
[
str
],
offsets
:
List
[
int
])
->
None
:
return
torch
.
ops
.
_C_custom_ar
.
register_buffer
(
fa
,
t
,
handles
,
offsets
)
def
register_buffer
(
fa
:
int
,
ipc_tensors
:
List
[
int
])
->
None
:
return
torch
.
ops
.
_C_custom_ar
.
register_buffer
(
fa
,
ipc_tensors
)
def
get_graph_buffer_ipc_meta
(
fa
:
int
)
->
Tuple
[
List
[
str
],
List
[
int
]]:
def
get_graph_buffer_ipc_meta
(
fa
:
int
)
->
Tuple
[
List
[
int
],
List
[
int
]]:
return
torch
.
ops
.
_C_custom_ar
.
get_graph_buffer_ipc_meta
(
fa
)
def
register_graph_buffers
(
fa
:
int
,
handles
:
List
[
str
],
def
register_graph_buffers
(
fa
:
int
,
handles
:
List
[
List
[
int
]
],
offsets
:
List
[
List
[
int
]])
->
None
:
torch
.
ops
.
_C_custom_ar
.
register_graph_buffers
(
fa
,
handles
,
offsets
)
...
...
vllm/distributed/device_communicators/custom_all_reduce.py
View file @
6192e9b8
import
ctypes
from
contextlib
import
contextmanager
from
typing
import
Any
,
List
,
Optional
,
Union
from
typing
import
List
,
Optional
,
Union
import
torch
import
torch.distributed
as
dist
...
...
@@ -147,18 +147,14 @@ class CustomAllreduce:
return
self
.
disabled
=
False
# buffers memory are owned by this Python class and passed to C++
# meta data composes of two parts: meta data for synchronization
# (256 bytes) and a temporary buffer for storing intermediate
# allreduce results.
self
.
meta
=
torch
.
zeros
(
ops
.
meta_size
()
+
max_size
,
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
# Buffer memory is owned by this Python class and passed to C++.
# The metadata is composed of two parts: metadata for synchronization and a
# temporary buffer for storing intermediate allreduce results.
self
.
meta_ptrs
=
self
.
create_shared_buffer
(
ops
.
meta_size
()
+
max_size
,
group
=
group
)
# This is a pre-registered IPC buffer. In eager mode, input tensors
# are first copied into this buffer before allreduce is performed
self
.
buffer
=
torch
.
empty
(
max_size
,
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
self
.
buffer_ptrs
=
self
.
create_shared_buffer
(
max_size
,
group
=
group
)
# This is a buffer for storing the tuples of pointers pointing to
# IPC buffers from all ranks. Each registered tuple has size of
# 8*world_size bytes where world_size is at most 8. Allocating 8MB
...
...
@@ -170,16 +166,19 @@ class CustomAllreduce:
self
.
max_size
=
max_size
self
.
rank
=
rank
self
.
world_size
=
world_size
handles
,
offsets
=
self
.
_get_ipc_meta
(
self
.
meta
)
self
.
full_nvlink
=
full_nvlink
self
.
_ptr
=
ops
.
init_custom_ar
(
self
.
meta
,
self
.
rank_data
,
h
an
dles
,
offsets
,
rank
,
self
.
full_nvlink
)
self
.
register_buffer
(
self
.
buffer
)
self
.
_ptr
=
ops
.
init_custom_ar
(
self
.
meta
_ptrs
,
self
.
rank_data
,
r
an
k
,
self
.
full_nvlink
)
ops
.
register_buffer
(
self
.
_ptr
,
self
.
buffer
_ptrs
)
@
staticmethod
def
create_shared_buffer
(
size_in_bytes
:
int
,
group
:
Optional
[
ProcessGroup
]
=
None
)
->
List
[
int
]:
"""
Creates a shared buffer and returns a list of pointers
representing the buffer on all processes in the group.
"""
lib
=
CudaRTLibrary
()
pointer
=
lib
.
cudaMalloc
(
size_in_bytes
)
handle
=
lib
.
cudaIpcGetMemHandle
(
pointer
)
...
...
@@ -220,60 +219,24 @@ class CustomAllreduce:
if
not
self
.
disabled
:
self
.
register_graph_buffers
()
def
_get_ipc_meta
(
self
,
inp
:
torch
.
Tensor
):
data
=
inp
.
untyped_storage
().
_share_cuda_
()
handle
=
data
[
1
]
# https://github.com/pytorch/pytorch/pull/130890 changes
# the binary format of the ipc handle
# it starts from pytorch 2.5
if
len
(
handle
)
>
64
:
assert
len
(
handle
)
==
66
# only support SHAREABLE_HANDLE_VERSION = 1
assert
int
(
handle
[
0
])
==
1
# only support SHAREABLE_CUDA_MALLOC = 'c'
assert
handle
[
1
]
==
ord
(
"c"
)
handle
=
handle
[
2
:]
# TODO: support expandable segment
shard_data
=
(
handle
,
# ipc handle to base ptr
data
[
3
],
# offset of base ptr
)
return
self
.
_gather_ipc_meta
(
shard_data
)
def
_gather_ipc_meta
(
self
,
shard_data
):
# Note: don't use `[[None]] * self.world_size` here
# because it will create a list of the same reference
all_data
:
List
[
Optional
[
Any
]]
=
[[
None
]
for
i
in
range
(
self
.
world_size
)]
all_data
[
self
.
rank
][
0
]
=
shard_data
ranks
=
dist
.
get_process_group_ranks
(
group
=
self
.
group
)
ranks
.
sort
()
def
register_graph_buffers
(
self
):
handle
,
offset
=
ops
.
get_graph_buffer_ipc_meta
(
self
.
_ptr
)
logger
.
info
(
"Registering %d cuda graph addresses"
,
len
(
offset
))
# We cannot directly use `dist.all_gather_object` here
# because it is incompatible with `gloo` backend under inference mode.
# see https://github.com/pytorch/pytorch/issues/126032 for details.
all_data
=
[[
None
,
None
]
for
_
in
range
(
dist
.
get_world_size
(
group
=
self
.
group
))]
all_data
[
self
.
rank
]
=
[
handle
,
offset
]
ranks
=
sorted
(
dist
.
get_process_group_ranks
(
group
=
self
.
group
))
for
i
,
rank
in
enumerate
(
ranks
):
dist
.
broadcast_object_list
(
all_data
[
i
],
src
=
rank
,
group
=
self
.
group
,
device
=
"cpu"
)
# we cannot directly use `dist.all_gather_object` here
# because it is incompatible with `gloo` backend under inference mode.
# see https://github.com/pytorch/pytorch/issues/126032 for details.
handles
=
[]
offsets
=
[]
for
i
in
range
(
len
(
all_data
)):
handles
.
append
(
all_data
[
i
][
0
][
0
])
# type: ignore
offsets
.
append
(
all_data
[
i
][
0
][
1
])
# type: ignore
return
handles
,
offsets
def
register_buffer
(
self
,
inp
:
torch
.
Tensor
):
handles
,
offsets
=
self
.
_get_ipc_meta
(
inp
)
ops
.
register_buffer
(
self
.
_ptr
,
inp
,
handles
,
offsets
)
def
register_graph_buffers
(
self
):
handle
,
offset
=
ops
.
get_graph_buffer_ipc_meta
(
self
.
_ptr
)
handles
,
offsets
=
self
.
_gather_ipc_meta
((
bytes
(
handle
),
offset
))
logger
.
info
(
"Registering %d cuda graph addresses"
,
len
(
offset
))
# Unpack list of tuples to tuple of lists.
handles
=
[
d
[
0
]
for
d
in
all_data
]
# type: ignore
offsets
=
[
d
[
1
]
for
d
in
all_data
]
# type: ignore
ops
.
register_graph_buffers
(
self
.
_ptr
,
handles
,
offsets
)
def
should_custom_ar
(
self
,
inp
:
torch
.
Tensor
):
...
...
@@ -291,45 +254,50 @@ class CustomAllreduce:
return
inp_size
<
self
.
max_size
return
False
# all reduce, assuming inp tensor is IPC registered with register_buffer,
# or, in the context of cuda graphs, register_graph_buffers
def
all_reduce_reg
(
self
,
inp
:
torch
.
Tensor
,
out
:
torch
.
Tensor
=
None
):
if
out
is
None
:
out
=
torch
.
empty_like
(
inp
)
ops
.
all_reduce_reg
(
self
.
_ptr
,
inp
,
out
)
return
out
# all reduce, assuming inp tensor is NOT IPC registered
def
all_reduce_unreg
(
self
,
inp
:
torch
.
Tensor
,
out
:
torch
.
Tensor
=
None
):
def
all_reduce
(
self
,
inp
:
torch
.
Tensor
,
*
,
out
:
torch
.
Tensor
=
None
,
registered
:
bool
=
False
):
"""Performs an out-of-place all reduce.
If registered is True, this assumes inp's pointer is already
IPC-registered. Otherwise, inp is first copied into a pre-registered
buffer.
"""
if
out
is
None
:
out
=
torch
.
empty_like
(
inp
)
ops
.
all_reduce_unreg
(
self
.
_ptr
,
inp
,
self
.
buffer
,
out
)
if
registered
:
ops
.
all_reduce
(
self
.
_ptr
,
inp
,
out
,
0
,
0
)
else
:
ops
.
all_reduce
(
self
.
_ptr
,
inp
,
out
,
self
.
buffer_ptrs
[
self
.
rank
],
self
.
max_size
)
return
out
def
custom_all_reduce
(
self
,
input
:
torch
.
Tensor
)
->
Optional
[
torch
.
Tensor
]:
# when custom allreduce is disabled, this will be None
"""The main allreduce API that provides support for cuda graph."""
# When custom allreduce is disabled, this will be None.
if
self
.
disabled
or
not
self
.
should_custom_ar
(
input
):
return
None
if
self
.
_IS_CAPTURING
:
if
torch
.
cuda
.
is_current_stream_capturing
():
return
self
.
all_reduce
_reg
(
input
)
return
self
.
all_reduce
(
input
,
registered
=
True
)
else
:
# if warm up, mimic the allocation pattern
# since custom allreduce is out-of-place
# If warm up, mimic the allocation pattern since custom
# allreduce is out-of-place.
return
torch
.
empty_like
(
input
)
else
:
# note: outside of cuda graph context,
# custom allreduce incurs a cost of cudaMemcpy, which should
# be small(<=1% of overall latency) compared to the performance
# gains of using custom kernels
return
self
.
all_reduce_unreg
(
input
)
return
None
# Note: outside of cuda graph context, custom allreduce incurs a
# cost of cudaMemcpy, which should be small (<=1% of overall
# latency) compared to the performance gain of using custom kernels
return
self
.
all_reduce
(
input
,
registered
=
False
)
def
close
(
self
):
if
not
self
.
disabled
and
self
.
_ptr
:
ops
.
dispose
(
self
.
_ptr
)
self
.
_ptr
=
0
self
.
free_shared_buffer
(
self
.
meta_ptrs
)
self
.
free_shared_buffer
(
self
.
buffer_ptrs
)
def
__del__
(
self
):
self
.
close
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment