Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2d6bccd9
Commit
2d6bccd9
authored
Mar 17, 2025
by
xiabo
Browse files
add custom allreduce
parent
d9e67e78
Changes
11
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
668 additions
and
444 deletions
+668
-444
CMakeLists.txt
CMakeLists.txt
+5
-0
csrc/custom_all_reduce.cu
csrc/custom_all_reduce.cu
+41
-0
csrc/custom_all_reduce.cuh
csrc/custom_all_reduce.cuh
+166
-72
csrc/custom_all_reduce_test.cu
csrc/custom_all_reduce_test.cu
+334
-317
csrc/ops.h
csrc/ops.h
+6
-2
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+7
-2
tests/distributed/test_custom_all_reduce.py
tests/distributed/test_custom_all_reduce.py
+2
-1
tests/utils.py
tests/utils.py
+11
-1
vllm/_custom_ops.py
vllm/_custom_ops.py
+12
-0
vllm/config.py
vllm/config.py
+6
-5
vllm/distributed/device_communicators/custom_all_reduce.py
vllm/distributed/device_communicators/custom_all_reduce.py
+78
-44
No files found.
CMakeLists.txt
View file @
2d6bccd9
...
@@ -463,6 +463,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
...
@@ -463,6 +463,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# if CUDA endif
# if CUDA endif
endif
()
endif
()
if
(
VLLM_GPU_LANG STREQUAL
"HIP"
)
list
(
APPEND VLLM_EXT_SRC
"csrc/custom_all_reduce.cu"
)
endif
()
message
(
STATUS
"Enabling C extension."
)
message
(
STATUS
"Enabling C extension."
)
define_gpu_extension_target
(
define_gpu_extension_target
(
_C
_C
...
...
csrc/custom_all_reduce.cu
View file @
2d6bccd9
...
@@ -142,3 +142,44 @@ void register_graph_buffers(fptr_t _fa,
...
@@ -142,3 +142,44 @@ void register_graph_buffers(fptr_t _fa,
bytes
.
reserve
(
handles
.
size
());
bytes
.
reserve
(
handles
.
size
());
fa
->
register_graph_buffers
(
bytes
,
offsets
);
fa
->
register_graph_buffers
(
bytes
,
offsets
);
}
}
std
::
tuple
<
fptr_t
,
torch
::
Tensor
>
allocate_shared_buffer_and_handle
(
int64_t
size
)
{
auto
device_index
=
c10
::
cuda
::
current_device
();
at
::
DeviceGuard
device_guard
(
at
::
Device
(
at
::
DeviceType
::
CUDA
,
device_index
));
void
*
buffer
;
cudaStreamCaptureMode
mode
=
cudaStreamCaptureModeRelaxed
;
auto
stream
=
c10
::
cuda
::
getCurrentCUDAStream
().
stream
();
AT_CUDA_CHECK
(
cudaThreadExchangeStreamCaptureMode
(
&
mode
));
#if defined(USE_ROCM)
// data buffers need to be "uncached" for signal on MI200
AT_CUDA_CHECK
(
hipExtMallocWithFlags
((
void
**
)
&
buffer
,
size
,
hipDeviceMallocUncached
));
#else
AT_CUDA_CHECK
(
cudaMalloc
((
void
**
)
&
buffer
,
size
));
#endif
AT_CUDA_CHECK
(
cudaMemsetAsync
(
buffer
,
0
,
size
,
stream
));
AT_CUDA_CHECK
(
cudaStreamSynchronize
(
stream
));
AT_CUDA_CHECK
(
cudaThreadExchangeStreamCaptureMode
(
&
mode
));
auto
options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
torch
::
kCPU
);
auto
handle
=
torch
::
empty
({
static_cast
<
int64_t
>
(
sizeof
(
cudaIpcMemHandle_t
))},
options
);
AT_CUDA_CHECK
(
cudaIpcGetMemHandle
((
cudaIpcMemHandle_t
*
)
handle
.
data_ptr
(),
buffer
));
return
std
::
make_tuple
(
reinterpret_cast
<
fptr_t
>
(
buffer
),
handle
);
}
fptr_t
open_mem_handle
(
torch
::
Tensor
&
mem_handle
)
{
void
*
ipc_ptr
;
AT_CUDA_CHECK
(
cudaIpcOpenMemHandle
(
(
void
**
)
&
ipc_ptr
,
*
((
const
cudaIpcMemHandle_t
*
)
mem_handle
.
data_ptr
()),
cudaIpcMemLazyEnablePeerAccess
));
return
reinterpret_cast
<
fptr_t
>
(
ipc_ptr
);
}
void
free_shared_buffer
(
fptr_t
buffer
)
{
AT_CUDA_CHECK
(
cudaFree
(
reinterpret_cast
<
void
*>
(
buffer
)));
}
\ No newline at end of file
csrc/custom_all_reduce.cuh
View file @
2d6bccd9
...
@@ -5,6 +5,10 @@
...
@@ -5,6 +5,10 @@
#include <cuda_fp16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cuda_runtime.h>
#if defined(USE_ROCM)
typedef
__hip_bfloat16
nv_bfloat16
;
#endif
#include <iostream>
#include <iostream>
#include <array>
#include <array>
#include <limits>
#include <limits>
...
@@ -12,6 +16,7 @@
...
@@ -12,6 +16,7 @@
#include <unordered_map>
#include <unordered_map>
#include <vector>
#include <vector>
namespace
vllm
{
#define CUDACHECK(cmd) \
#define CUDACHECK(cmd) \
do { \
do { \
cudaError_t e = cmd; \
cudaError_t e = cmd; \
...
@@ -22,24 +27,38 @@
...
@@ -22,24 +27,38 @@
} \
} \
} while (0)
} while (0)
namespace
vllm
{
// Maximal number of blocks in allreduce kernel.
constexpr
int
kMaxBlocks
=
36
;
constexpr
int
kMaxBlocks
=
36
;
// Default number of blocks in allreduce kernel.
#ifndef USE_ROCM
const
int
defaultBlockLimit
=
36
;
CUpointer_attribute
rangeStartAddrAttr
=
CUDA_POINTER_ATTRIBUTE_RANGE_START_ADDR
;
#else
const
int
defaultBlockLimit
=
16
;
hipPointer_attribute
rangeStartAddrAttr
=
HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR
;
#endif
// Counter may overflow, but it's fine since unsigned int overflow is
// Counter may overflow, but it's fine since unsigned int overflow is
// well-defined behavior.
// well-defined behavior.
using
FlagType
=
uint32_t
;
using
FlagType
=
uint32_t
;
// Two sets of peer counters are needed for two syncs: starting and ending an
// operation. The reason is that it's possible for peer GPU block to arrive at
// the second sync point while the current GPU block haven't passed the first
// sync point. Thus, peer GPU may write counter+1 while current GPU is busy
// waiting for counter. We use alternating counter array to avoid this
// possibility.
struct
Signal
{
struct
Signal
{
alignas
(
128
)
FlagType
self_counter
[
kMaxBlocks
][
8
];
alignas
(
128
)
FlagType
start
[
kMaxBlocks
][
8
];
// Two sets of peer counters are needed for two syncs. The reason is that
alignas
(
128
)
FlagType
end
[
kMaxBlocks
][
8
];
// it's possible for peer GPU block to arrive at the second sync point while
alignas
(
128
)
FlagType
_flag
[
kMaxBlocks
];
// incremental flags for each rank
// the current GPU block haven't passed the first sync point. Thus, peer GPU
// may write counter+1 while current GPU is busy waiting for counter. We use
// alternating counter array to avoid this possibility.
alignas
(
128
)
FlagType
peer_counter
[
2
][
kMaxBlocks
][
8
];
};
};
struct
__align__
(
16
)
RankData
{
struct
__align__
(
16
)
RankData
{
const
void
*
__restrict__
ptrs
[
8
];
const
void
*
ptrs
[
8
];
};
};
struct
__align__
(
16
)
RankSignals
{
struct
__align__
(
16
)
RankSignals
{
...
@@ -134,27 +153,29 @@ DINLINE O downcast(array_t<float, O::size> val) {
...
@@ -134,27 +153,29 @@ DINLINE O downcast(array_t<float, O::size> val) {
}
}
}
}
#if !defined(USE_ROCM)
static
DINLINE
void
st_flag_release
(
FlagType
*
flag_addr
,
FlagType
flag
)
{
static
DINLINE
void
st_flag_release
(
FlagType
*
flag_addr
,
FlagType
flag
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm
volatile
(
"st.release.sys.global.u32 [%1], %0;"
::
"r"
(
flag
),
asm
volatile
(
"st.release.sys.global.u32 [%1], %0;"
::
"r"
(
flag
),
"l"
(
flag_addr
));
"l"
(
flag_addr
));
#else
#else
asm
volatile
(
"membar.sys; st.volatile.global.u32 [%1], %0;"
::
"r"
(
flag
),
asm
volatile
(
"membar.sys; st.volatile.global.u32 [%1], %0;"
::
"r"
(
flag
),
"l"
(
flag_addr
));
"l"
(
flag_addr
));
#endif
#endif
}
}
static
DINLINE
FlagType
ld_flag_acquire
(
FlagType
*
flag_addr
)
{
static
DINLINE
FlagType
ld_flag_acquire
(
FlagType
*
flag_addr
)
{
FlagType
flag
;
FlagType
flag
;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm
volatile
(
"ld.acquire.sys.global.u32 %0, [%1];"
asm
volatile
(
"ld.acquire.sys.global.u32 %0, [%1];"
:
"=r"
(
flag
)
:
"=r"
(
flag
)
:
"l"
(
flag_addr
));
:
"l"
(
flag_addr
));
#else
#else
asm
volatile
(
"ld.volatile.global.u32 %0, [%1]; membar.gl;"
asm
volatile
(
"ld.volatile.global.u32 %0, [%1]; membar.gl;"
:
"=r"
(
flag
)
:
"=r"
(
flag
)
:
"l"
(
flag_addr
));
:
"l"
(
flag_addr
));
#endif
#endif
return
flag
;
return
flag
;
}
}
...
@@ -170,37 +191,108 @@ static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) {
...
@@ -170,37 +191,108 @@ static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) {
return
flag
;
return
flag
;
}
}
// is_start: whether this is the very first synchronization barrier.
// This function is meant to be used as the first synchronization in the all
// need_fence: whether a memory fence is needed. If true, a release-acquire
// reduce kernel. Thus, it doesn't need to make any visibility guarantees for
// semantic is used to enforce memory access order before and after this
// prior memory accesses. Note: volatile writes will not be reordered against
// barrier.
// other volatile writes.
template
<
int
ngpus
,
bool
is_start
,
bool
need_fence
=
false
>
template
<
int
ngpus
>
DINLINE
void
multi_gpu_barrier
(
const
RankSignals
&
sg
,
Signal
*
self_sg
,
DINLINE
void
start_sync
(
const
RankSignals
&
sg
,
Signal
*
self_sg
,
int
rank
)
{
int
rank
)
{
uint32_t
flag
=
self_sg
->
_flag
[
blockIdx
.
x
]
+
1
;
if
constexpr
(
!
is_start
)
__syncthreads
();
static_assert
(
!
(
is_start
&&
need_fence
));
// Start barrier shouldn't need fence.
if
(
threadIdx
.
x
<
ngpus
)
{
if
(
threadIdx
.
x
<
ngpus
)
{
// Increment the counter. Technically we only need one counter, but we use
auto
peer_counter_ptr
=
&
sg
.
signals
[
threadIdx
.
x
]
->
start
[
blockIdx
.
x
][
rank
];
// multiple per block to eliminate the need to share the counter via smem.
auto
self_counter_ptr
=
&
self_sg
->
start
[
blockIdx
.
x
][
threadIdx
.
x
];
auto
val
=
self_sg
->
self_counter
[
blockIdx
.
x
][
threadIdx
.
x
]
+=
1
;
// Write the expected counter value to peer and wait for correct value
// from peer.
st_flag_volatile
(
peer_counter_ptr
,
flag
);
while
(
ld_flag_volatile
(
self_counter_ptr
)
!=
flag
);
}
__syncthreads
();
// use one thread to update flag
if
(
threadIdx
.
x
==
0
)
self_sg
->
_flag
[
blockIdx
.
x
]
=
flag
;
}
// This function is meant to be used as the second or the final
// synchronization barrier in the all reduce kernel. If it's the final
// synchronization barrier, we don't need to make any visibility guarantees
// for prior memory accesses.
template
<
int
ngpus
,
bool
final_sync
=
false
>
DINLINE
void
end_sync
(
const
RankSignals
&
sg
,
Signal
*
self_sg
,
int
rank
)
{
__syncthreads
();
uint32_t
flag
=
self_sg
->
_flag
[
blockIdx
.
x
]
+
1
;
if
(
threadIdx
.
x
<
ngpus
)
{
auto
peer_counter_ptr
=
&
sg
.
signals
[
threadIdx
.
x
]
->
end
[
blockIdx
.
x
][
rank
];
auto
self_counter_ptr
=
&
self_sg
->
end
[
blockIdx
.
x
][
threadIdx
.
x
];
// Write the expected counter value to peer and wait for correct value from
// Write the expected counter value to peer and wait for correct value from
// peer.
// peer.
auto
peer_counter_ptr
=
if
constexpr
(
!
final_sync
)
{
&
sg
.
signals
[
threadIdx
.
x
]
->
peer_counter
[
val
%
2
][
blockIdx
.
x
][
rank
];
st_flag_release
(
peer_counter_ptr
,
flag
);
auto
self_counter_ptr
=
while
(
ld_flag_acquire
(
self_counter_ptr
)
!=
flag
);
&
self_sg
->
peer_counter
[
val
%
2
][
blockIdx
.
x
][
threadIdx
.
x
];
if
constexpr
(
need_fence
)
{
st_flag_release
(
peer_counter_ptr
,
val
);
while
(
ld_flag_acquire
(
self_counter_ptr
)
!=
val
);
}
else
{
}
else
{
st_flag_volatile
(
peer_counter_ptr
,
val
);
st_flag_volatile
(
peer_counter_ptr
,
flag
);
while
(
ld_flag_volatile
(
self_counter_ptr
)
!=
val
);
while
(
ld_flag_volatile
(
self_counter_ptr
)
!=
flag
);
}
}
}
}
if
constexpr
(
is_start
||
need_fence
)
__syncthreads
();
if
constexpr
(
!
final_sync
)
__syncthreads
();
// use one thread to update flag
if
(
threadIdx
.
x
==
0
)
self_sg
->
_flag
[
blockIdx
.
x
]
=
flag
;
}
}
#else
template
<
int
ngpus
>
DINLINE
void
start_sync
(
const
RankSignals
&
sg
,
Signal
*
self_sg
,
int
rank
)
{
uint32_t
flag
=
self_sg
->
_flag
[
blockIdx
.
x
]
+
1
;
if
(
threadIdx
.
x
<
ngpus
)
{
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
// __scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank],
// flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
// // wait until we got true from all ranks
// while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
// __ATOMIC_RELAXED,
// __MEMORY_SCOPE_DEVICE) < flag);
__atomic_store_n
(
&
sg
.
signals
[
threadIdx
.
x
]
->
start
[
blockIdx
.
x
][
rank
],
flag
,
__ATOMIC_RELAXED
);
// wait until we got true from all ranks
while
(
__atomic_load_n
(
&
self_sg
->
start
[
blockIdx
.
x
][
threadIdx
.
x
],
__ATOMIC_RELAXED
)
<
flag
);
}
__syncthreads
();
// use one thread to update flag
if
(
threadIdx
.
x
==
0
)
self_sg
->
_flag
[
blockIdx
.
x
]
=
flag
;
}
template
<
int
ngpus
,
bool
final_sync
=
false
>
DINLINE
void
end_sync
(
const
RankSignals
&
sg
,
Signal
*
self_sg
,
int
rank
)
{
__syncthreads
();
uint32_t
flag
=
self_sg
->
_flag
[
blockIdx
.
x
]
+
1
;
if
(
threadIdx
.
x
<
ngpus
)
{
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
// __scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank],
// flag,
// final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
// __MEMORY_SCOPE_SYSTEM);
// // wait until we got true from all ranks
// while (
// __scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
// final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
// __MEMORY_SCOPE_DEVICE) < flag);
__atomic_store_n
(
&
sg
.
signals
[
threadIdx
.
x
]
->
end
[
blockIdx
.
x
][
rank
],
flag
,
final_sync
?
__ATOMIC_RELAXED
:
__ATOMIC_RELEASE
);
// wait until we got true from all ranks
while
(
__atomic_load_n
(
&
self_sg
->
end
[
blockIdx
.
x
][
threadIdx
.
x
],
final_sync
?
__ATOMIC_RELAXED
:
__ATOMIC_ACQUIRE
)
<
flag
);
}
if
constexpr
(
!
final_sync
)
__syncthreads
();
// use one thread to update flag
if
(
threadIdx
.
x
==
0
)
self_sg
->
_flag
[
blockIdx
.
x
]
=
flag
;
}
#endif
template
<
typename
P
,
int
ngpus
,
typename
A
>
template
<
typename
P
,
int
ngpus
,
typename
A
>
DINLINE
P
packed_reduce
(
const
P
*
ptrs
[],
int
idx
)
{
DINLINE
P
packed_reduce
(
const
P
*
ptrs
[],
int
idx
)
{
A
tmp
=
upcast
(
ptrs
[
0
][
idx
]);
A
tmp
=
upcast
(
ptrs
[
0
][
idx
]);
...
@@ -220,13 +312,13 @@ __global__ void __launch_bounds__(512, 1)
...
@@ -220,13 +312,13 @@ __global__ void __launch_bounds__(512, 1)
// note: we don't reorder the address so the accumulation order is the same
// note: we don't reorder the address so the accumulation order is the same
// for all ranks, ensuring bitwise identical results
// for all ranks, ensuring bitwise identical results
auto
dp
=
*
_dp
;
auto
dp
=
*
_dp
;
multi_gpu_barrier
<
ngpus
,
true
>
(
sg
,
self_sg
,
rank
);
start_sync
<
ngpus
>
(
sg
,
self_sg
,
rank
);
// do the actual reduction
// do the actual reduction
for
(
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
idx
<
size
;
for
(
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
idx
<
size
;
idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
((
P
*
)
result
)[
idx
]
=
packed_reduce
<
P
,
ngpus
,
A
>
((
const
P
**
)
&
dp
.
ptrs
[
0
],
idx
);
((
P
*
)
result
)[
idx
]
=
packed_reduce
<
P
,
ngpus
,
A
>
((
const
P
**
)
&
dp
.
ptrs
[
0
],
idx
);
}
}
multi_gpu_barrier
<
ngpus
,
fals
e
>
(
sg
,
self_sg
,
rank
);
end_sync
<
ngpus
,
tru
e
>
(
sg
,
self_sg
,
rank
);
}
}
template
<
typename
P
>
template
<
typename
P
>
...
@@ -255,18 +347,20 @@ __global__ void __launch_bounds__(512, 1)
...
@@ -255,18 +347,20 @@ __global__ void __launch_bounds__(512, 1)
tmps
[
i
]
=
get_tmp_buf
<
P
>
(
sg
.
signals
[
target
]);
tmps
[
i
]
=
get_tmp_buf
<
P
>
(
sg
.
signals
[
target
]);
}
}
auto
tmp_out
=
tmps
[
0
];
auto
tmp_out
=
tmps
[
0
];
multi_gpu_barrier
<
ngpus
,
true
>
(
sg
,
self_sg
,
rank
);
start_sync
<
ngpus
>
(
sg
,
self_sg
,
rank
);
// stage 1: reduce scatter
// stage 1: reduce scatter
for
(
int
idx
=
start
+
tid
;
idx
<
end
;
idx
+=
stride
)
{
for
(
int
idx
=
start
+
tid
;
idx
<
end
;
idx
+=
stride
)
{
tmp_out
[
idx
-
start
]
=
packed_reduce
<
P
,
ngpus
,
A
>
(
ptrs
,
idx
);
tmp_out
[
idx
-
start
]
=
packed_reduce
<
P
,
ngpus
,
A
>
(
ptrs
,
idx
);
}
}
multi_gpu_barrier
<
ngpus
,
false
,
true
>
(
sg
,
self_sg
,
rank
);
end_sync
<
ngpus
>
(
sg
,
self_sg
,
rank
);
// stage 2: allgather. Note: it's important to match the tid between
// stage 2: allgather. Note: it's important to match the tid between
// the two stages, because visibility across devices is only guaranteed
// the two stages, because visibility across devices is only guaranteed
// between threads that have the same tid. If thread i computes the sum of
// between threads that have the same tid. If thread i computes the sum of
// start + i in the first stage, then thread i also gathers start + i from all
// start + i in the first stage, then thread i also gathers start + i from
// ranks.
// all ranks.
for
(
int
idx
=
tid
;
idx
<
largest_part
;
idx
+=
stride
)
{
for
(
int
idx
=
tid
;
idx
<
largest_part
;
idx
+=
stride
)
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
ngpus
;
i
++
)
{
for
(
int
i
=
0
;
i
<
ngpus
;
i
++
)
{
...
@@ -290,18 +384,18 @@ class CustomAllreduce {
...
@@ -290,18 +384,18 @@ class CustomAllreduce {
bool
full_nvlink_
;
bool
full_nvlink_
;
RankSignals
sg_
;
RankSignals
sg_
;
// Stores an map from a pointer to its peer point
t
ers from all ranks.
// Stores an map from a pointer to its peer pointers from all ranks.
std
::
unordered_map
<
void
*
,
RankData
*>
buffers_
;
std
::
unordered_map
<
void
*
,
RankData
*>
buffers_
;
Signal
*
self_sg_
;
Signal
*
self_sg_
;
// Stores rank data from all ranks. This is mainly for cuda graph purposes.
// Stores rank data from all ranks. This is mainly for cuda graph purposes.
// For cuda graph to work, all kernel arguments must be fixed during graph
// For cuda graph to work, all kernel arguments must be fixed during graph
// capture time. However, the peer pointers are not known during graph
capture
// capture time. However, the peer pointers are not known during graph
// time. Therefore, during capture, we increment the rank data
pointer and use
//
capture
time. Therefore, during capture, we increment the rank data
// that as the argument to the kernel. The kernel arguments
are stored in
//
pointer and use
that as the argument to the kernel. The kernel arguments
// graph_unreg_buffers_. The actual peer pointers will be
filled in at the
//
are stored in
graph_unreg_buffers_. The actual peer pointers will be
// memory pointed to by the pointers in
graph_unreg_buffers_ when
//
filled in at the
memory pointed to by the pointers in
// the IPC handles are exchanged between ranks.
//
graph_unreg_buffers_ when
the IPC handles are exchanged between ranks.
//
//
// The overall process looks like this:
// The overall process looks like this:
// 1. Graph capture.
// 1. Graph capture.
...
@@ -319,8 +413,9 @@ class CustomAllreduce {
...
@@ -319,8 +413,9 @@ class CustomAllreduce {
* Signals are an array of ipc-enabled buffers from all ranks.
* Signals are an array of ipc-enabled buffers from all ranks.
* For each of the buffer, the layout is as follows:
* For each of the buffer, the layout is as follows:
* | -- sizeof(Signal) -- | ------ a few MB ----- |
* | -- sizeof(Signal) -- | ------ a few MB ----- |
* The first section is for allreduce synchronization, and the second section
* The first section is for allreduce synchronization, and the second
* is for storing the intermediate results required by some allreduce algos.
* section is for storing the intermediate results required by some
* allreduce algos.
*
*
* Note: this class does not own any device memory. Any required buffers
* Note: this class does not own any device memory. Any required buffers
* are passed in from the constructor.
* are passed in from the constructor.
...
@@ -361,8 +456,7 @@ class CustomAllreduce {
...
@@ -361,8 +456,7 @@ class CustomAllreduce {
void
*
base_ptr
;
void
*
base_ptr
;
// note: must share the base address of each allocation, or we get wrong
// note: must share the base address of each allocation, or we get wrong
// address
// address
if
(
cuPointerGetAttribute
(
&
base_ptr
,
if
(
cuPointerGetAttribute
(
&
base_ptr
,
rangeStartAddrAttr
,
CU_POINTER_ATTRIBUTE_RANGE_START_ADDR
,
(
CUdeviceptr
)
ptr
)
!=
CUDA_SUCCESS
)
(
CUdeviceptr
)
ptr
)
!=
CUDA_SUCCESS
)
throw
std
::
runtime_error
(
"failed to get pointer attr"
);
throw
std
::
runtime_error
(
"failed to get pointer attr"
);
CUDACHECK
(
cudaIpcGetMemHandle
(
CUDACHECK
(
cudaIpcGetMemHandle
(
...
@@ -396,11 +490,11 @@ class CustomAllreduce {
...
@@ -396,11 +490,11 @@ class CustomAllreduce {
// Note: when registering graph buffers, we intentionally choose to not
// Note: when registering graph buffers, we intentionally choose to not
// deduplicate the addresses. That means if the allocator reuses some
// deduplicate the addresses. That means if the allocator reuses some
// addresses, they will be registered again. This is to account for the
remote
// addresses, they will be registered again. This is to account for the
// possibility of different allocation patterns between ranks. For
example,
//
remote
possibility of different allocation patterns between ranks. For
// rank 1 may get the same input address for the second allreduce,
but rank 2
//
example,
rank 1 may get the same input address for the second allreduce,
// got a different address. IPC handles have internal reference
counting
//
but rank 2
got a different address. IPC handles have internal reference
// mechanism so overhead should be small.
//
counting
mechanism so overhead should be small.
void
register_graph_buffers
(
void
register_graph_buffers
(
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
offsets
)
{
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
offsets
)
{
...
@@ -431,15 +525,15 @@ class CustomAllreduce {
...
@@ -431,15 +525,15 @@ class CustomAllreduce {
/**
/**
* Performs allreduce, assuming input has already been registered.
* Performs allreduce, assuming input has already been registered.
*
*
* Block and grid default configs are results after careful grid search.
Using
* Block and grid default configs are results after careful grid search.
* 36 blocks give the best or close to the best runtime on the devices
I
*
Using
36 blocks give the best or close to the best runtime on the devices
* tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also
only
*
I
tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also
* take a small amount of SMs. Not quite sure the underlying reason,
but my
*
only
take a small amount of SMs. Not quite sure the underlying reason,
* guess is that too many SMs will cause contention on NVLink bus.
*
but my
guess is that too many SMs will cause contention on NVLink bus.
*/
*/
template
<
typename
T
>
template
<
typename
T
>
void
allreduce
(
cudaStream_t
stream
,
T
*
input
,
T
*
output
,
int
size
,
void
allreduce
(
cudaStream_t
stream
,
T
*
input
,
T
*
output
,
int
size
,
int
threads
=
512
,
int
block_limit
=
36
)
{
int
threads
=
512
,
int
block_limit
=
defaultBlockLimit
)
{
auto
d
=
packed_t
<
T
>::
P
::
size
;
auto
d
=
packed_t
<
T
>::
P
::
size
;
if
(
size
%
d
!=
0
)
if
(
size
%
d
!=
0
)
throw
std
::
runtime_error
(
throw
std
::
runtime_error
(
...
@@ -473,8 +567,6 @@ class CustomAllreduce {
...
@@ -473,8 +567,6 @@ class CustomAllreduce {
#define KL(ngpus, name) \
#define KL(ngpus, name) \
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
rank_, size);
rank_, size);
// TODO(hanzhi713): Threshold is different for A100 and H100.
// Add per device threshold.
#define REDUCE_CASE(ngpus) \
#define REDUCE_CASE(ngpus) \
case ngpus: { \
case ngpus: { \
if (world_size_ == 2) { \
if (world_size_ == 2) { \
...
@@ -497,7 +589,8 @@ class CustomAllreduce {
...
@@ -497,7 +589,8 @@ class CustomAllreduce {
REDUCE_CASE
(
8
)
REDUCE_CASE
(
8
)
default:
default:
throw
std
::
runtime_error
(
throw
std
::
runtime_error
(
"custom allreduce only supports num gpus in (2,4,6,8). Actual num "
"custom allreduce only supports num gpus in (2,4,6,8). Actual "
"num "
"gpus = "
+
"gpus = "
+
std
::
to_string
(
world_size_
));
std
::
to_string
(
world_size_
));
}
}
...
@@ -511,9 +604,10 @@ class CustomAllreduce {
...
@@ -511,9 +604,10 @@ class CustomAllreduce {
}
}
}
}
};
};
/**
/**
* To inspect PTX/SASS, copy paste this header file to compiler explorer and
add
* To inspect PTX/SASS, copy paste this header file to compiler explorer and
a template instantiation:
add
a template instantiation:
* template void vllm::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
* template void vllm::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
half *, int, int, int);
half *, int, int, int);
*/
*/
...
...
csrc/custom_all_reduce_test.cu
View file @
2d6bccd9
This diff is collapsed.
Click to expand it.
csrc/ops.h
View file @
2d6bccd9
...
@@ -374,7 +374,7 @@ void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
...
@@ -374,7 +374,7 @@ void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
const
std
::
optional
<
at
::
Tensor
>&
has_initial_state
,
const
std
::
optional
<
at
::
Tensor
>&
has_initial_state
,
bool
silu_activation
,
int64_t
pad_slot_id
);
bool
silu_activation
,
int64_t
pad_slot_id
);
#ifndef USE_ROCM
using
fptr_t
=
int64_t
;
using
fptr_t
=
int64_t
;
fptr_t
init_custom_ar
(
const
std
::
vector
<
int64_t
>&
fake_ipc_ptrs
,
fptr_t
init_custom_ar
(
const
std
::
vector
<
int64_t
>&
fake_ipc_ptrs
,
torch
::
Tensor
&
rank_data
,
int64_t
rank
,
bool
full_nvlink
);
torch
::
Tensor
&
rank_data
,
int64_t
rank
,
bool
full_nvlink
);
...
@@ -388,4 +388,8 @@ get_graph_buffer_ipc_meta(fptr_t _fa);
...
@@ -388,4 +388,8 @@ get_graph_buffer_ipc_meta(fptr_t _fa);
void
register_graph_buffers
(
fptr_t
_fa
,
void
register_graph_buffers
(
fptr_t
_fa
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
handles
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
handles
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
offsets
);
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
offsets
);
#endif
std
::
tuple
<
int64_t
,
torch
::
Tensor
>
allocate_shared_buffer_and_handle
(
int64_t
size
);
int64_t
open_mem_handle
(
torch
::
Tensor
&
mem_handle
);
void
free_shared_buffer
(
int64_t
buffer
);
\ No newline at end of file
csrc/torch_bindings.cpp
View file @
2d6bccd9
...
@@ -710,7 +710,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
...
@@ -710,7 +710,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
&
get_max_shared_memory_per_block_device_attribute
);
&
get_max_shared_memory_per_block_device_attribute
);
}
}
#ifndef USE_ROCM
TORCH_LIBRARY_EXPAND
(
CONCAT
(
TORCH_EXTENSION_NAME
,
_custom_ar
),
custom_ar
)
{
TORCH_LIBRARY_EXPAND
(
CONCAT
(
TORCH_EXTENSION_NAME
,
_custom_ar
),
custom_ar
)
{
// Custom all-reduce kernels
// Custom all-reduce kernels
custom_ar
.
def
(
custom_ar
.
def
(
...
@@ -728,7 +728,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
...
@@ -728,7 +728,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
custom_ar
.
def
(
"register_buffer"
,
&
register_buffer
);
custom_ar
.
def
(
"register_buffer"
,
&
register_buffer
);
custom_ar
.
def
(
"get_graph_buffer_ipc_meta"
,
&
get_graph_buffer_ipc_meta
);
custom_ar
.
def
(
"get_graph_buffer_ipc_meta"
,
&
get_graph_buffer_ipc_meta
);
custom_ar
.
def
(
"register_graph_buffers"
,
&
register_graph_buffers
);
custom_ar
.
def
(
"register_graph_buffers"
,
&
register_graph_buffers
);
custom_ar
.
def
(
"allocate_shared_buffer_and_handle"
,
&
allocate_shared_buffer_and_handle
);
custom_ar
.
def
(
"open_mem_handle(Tensor mem_handle) -> int"
,
&
open_mem_handle
);
custom_ar
.
impl
(
"open_mem_handle"
,
torch
::
kCPU
,
&
open_mem_handle
);
custom_ar
.
def
(
"free_shared_buffer"
,
&
free_shared_buffer
);
}
}
#endif
REGISTER_EXTENSION
(
TORCH_EXTENSION_NAME
)
REGISTER_EXTENSION
(
TORCH_EXTENSION_NAME
)
tests/distributed/test_custom_all_reduce.py
View file @
2d6bccd9
...
@@ -93,7 +93,8 @@ def eager_allreduce(tp_size, pp_size, rank, distributed_init_port):
...
@@ -93,7 +93,8 @@ def eager_allreduce(tp_size, pp_size, rank, distributed_init_port):
# communicate independently
# communicate independently
num_communication
=
rank
//
tp_size
+
1
num_communication
=
rank
//
tp_size
+
1
sz
=
1024
sz
=
1024
fa
=
get_tp_group
().
ca_comm
# fa = get_tp_group().ca_comm xiabo
fa
=
get_tp_group
().
device_communicator
.
ca_comm
inp
=
torch
.
ones
(
sz
,
dtype
=
torch
.
float32
,
device
=
device
)
inp
=
torch
.
ones
(
sz
,
dtype
=
torch
.
float32
,
device
=
device
)
out
=
inp
out
=
inp
for
_
in
range
(
num_communication
):
for
_
in
range
(
num_communication
):
...
...
tests/utils.py
View file @
2d6bccd9
...
@@ -574,7 +574,17 @@ def multi_process_parallel(
...
@@ -574,7 +574,17 @@ def multi_process_parallel(
# as compared to multiprocessing.
# as compared to multiprocessing.
# NOTE: We need to set working_dir for distributed tests,
# NOTE: We need to set working_dir for distributed tests,
# otherwise we may get import errors on ray workers
# otherwise we may get import errors on ray workers
ray
.
init
(
num_gpus
=
tp_size
,
runtime_env
=
{
"working_dir"
:
VLLM_PATH
})
# ray.init(num_gpus=tp_size, runtime_env={"working_dir": VLLM_PATH}) xiabo
# NOTE: Force ray not to use gitignore file as excluding, otherwise
# it will not move .so files to working dir.
# So we have to manually add some of large directories
os
.
environ
[
"RAY_RUNTIME_ENV_IGNORE_GITIGNORE"
]
=
"1"
ray
.
init
(
runtime_env
=
{
"working_dir"
:
VLLM_PATH
,
"excludes"
:
[
"build"
,
".git"
,
"cmake-build-*"
,
"shellcheck"
,
"dist"
]
})
distributed_init_port
=
get_open_port
()
distributed_init_port
=
get_open_port
()
refs
=
[]
refs
=
[]
...
...
vllm/_custom_ops.py
View file @
2d6bccd9
...
@@ -1527,6 +1527,18 @@ def register_graph_buffers(fa: int, handles: List[List[int]],
...
@@ -1527,6 +1527,18 @@ def register_graph_buffers(fa: int, handles: List[List[int]],
torch
.
ops
.
_C_custom_ar
.
register_graph_buffers
(
fa
,
handles
,
offsets
)
torch
.
ops
.
_C_custom_ar
.
register_graph_buffers
(
fa
,
handles
,
offsets
)
def
allocate_shared_buffer_and_handle
(
size
:
int
)
->
tuple
[
int
,
torch
.
Tensor
]:
return
torch
.
ops
.
_C_custom_ar
.
allocate_shared_buffer_and_handle
(
size
)
def
open_mem_handle
(
mem_handle
:
torch
.
Tensor
):
return
torch
.
ops
.
_C_custom_ar
.
open_mem_handle
(
mem_handle
)
def
free_shared_buffer
(
ptr
:
int
)
->
None
:
torch
.
ops
.
_C_custom_ar
.
free_shared_buffer
(
ptr
)
def
read_cache
(
def
read_cache
(
keys
:
torch
.
Tensor
,
keys
:
torch
.
Tensor
,
values
:
torch
.
Tensor
,
values
:
torch
.
Tensor
,
...
...
vllm/config.py
View file @
2d6bccd9
...
@@ -1411,11 +1411,12 @@ class ParallelConfig:
...
@@ -1411,11 +1411,12 @@ class ParallelConfig:
if
self
.
use_ray
:
if
self
.
use_ray
:
from
vllm.executor
import
ray_utils
from
vllm.executor
import
ray_utils
ray_utils
.
assert_ray_available
()
ray_utils
.
assert_ray_available
()
if
current_platform
.
is_rocm
():
# xiabo
self
.
disable_custom_all_reduce
=
True
# if current_platform.is_rocm():
logger
.
info
(
# self.disable_custom_all_reduce = True
"Disabled the custom all-reduce kernel because it is not "
# logger.info(
"supported on hcus."
)
# "Disabled the custom all-reduce kernel because it is not "
# "supported on hcus.")
if
self
.
ray_workers_use_nsight
and
not
self
.
use_ray
:
if
self
.
ray_workers_use_nsight
and
not
self
.
use_ray
:
raise
ValueError
(
"Unable to use nsight profiling unless workers "
raise
ValueError
(
"Unable to use nsight profiling unless workers "
"run with Ray."
)
"run with Ray."
)
...
...
vllm/distributed/device_communicators/custom_all_reduce.py
View file @
2d6bccd9
...
@@ -10,7 +10,7 @@ from torch.distributed import ProcessGroup
...
@@ -10,7 +10,7 @@ from torch.distributed import ProcessGroup
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.distributed.device_communicators.cuda_wrapper
import
CudaRTLibrary
#
from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
from
vllm.distributed.device_communicators.custom_all_reduce_utils
import
(
from
vllm.distributed.device_communicators.custom_all_reduce_utils
import
(
gpu_p2p_access_check
)
gpu_p2p_access_check
)
from
vllm.distributed.parallel_state
import
in_the_same_node_as
from
vllm.distributed.parallel_state
import
in_the_same_node_as
...
@@ -73,6 +73,8 @@ class CustomAllreduce:
...
@@ -73,6 +73,8 @@ class CustomAllreduce:
if
not
custom_ar
:
if
not
custom_ar
:
# disable because of missing custom allreduce library
# disable because of missing custom allreduce library
# e.g. in a non-cuda environment
# e.g. in a non-cuda environment
logger
.
warning
(
"Custom allreduce is disabled because "
"of missing custom allreduce library"
)
return
return
self
.
group
=
group
self
.
group
=
group
...
@@ -129,10 +131,12 @@ class CustomAllreduce:
...
@@ -129,10 +131,12 @@ class CustomAllreduce:
# test nvlink first, this will filter out most of the cases
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
# where custom allreduce is not supported
# this checks hardware and driver support for NVLink
# this checks hardware and driver support for NVLink
assert
current_platform
.
is_cuda
()
# xiabo
from
vllm.platforms.cuda
import
CudaPlatform
# assert current_platform.is_cuda()
cuda_platform
:
CudaPlatform
=
current_platform
# from vllm.platforms.cuda import CudaPlatform
full_nvlink
=
cuda_platform
.
is_full_nvlink
(
physical_device_ids
)
# cuda_platform: CudaPlatform = current_platform
# full_nvlink = cuda_platform.is_full_nvlink(physical_device_ids)
full_nvlink
=
True
if
world_size
>
2
and
not
full_nvlink
:
if
world_size
>
2
and
not
full_nvlink
:
logger
.
warning
(
logger
.
warning
(
"Custom allreduce is disabled because it's not supported on"
"Custom allreduce is disabled because it's not supported on"
...
@@ -142,19 +146,22 @@ class CustomAllreduce:
...
@@ -142,19 +146,22 @@ class CustomAllreduce:
# test P2P capability, this checks software/cudaruntime support
# test P2P capability, this checks software/cudaruntime support
# this is expensive to compute at the first time
# this is expensive to compute at the first time
# then we cache the result
# then we cache the result
if
not
_can_p2p
(
rank
,
world_size
):
# xiabo
logger
.
warning
(
# if not _can_p2p(rank, world_size):
"Custom allreduce is disabled because your platform lacks "
# logger.warning(
"GPU P2P capability or P2P test failed. To silence this "
# "Custom allreduce is disabled because your platform lacks "
"warning, specify disable_custom_all_reduce=True explicitly."
)
# "GPU P2P capability or P2P test failed. To silence this "
return
# "warning, specify disable_custom_all_reduce=True explicitly.")
# return
self
.
disabled
=
False
self
.
disabled
=
False
# Buffers memory are owned by this Python class and passed to C++.
# Buffers memory are owned by this Python class and passed to C++.
# Meta data composes of two parts: meta data for synchronization and a
# Meta data composes of two parts: meta data for synchronization and a
# temporary buffer for storing intermediate allreduce results.
# temporary buffer for storing intermediate allreduce results.
self
.
meta_ptrs
=
self
.
create_shared_buffer
(
ops
.
meta_size
()
+
max_size
,
self
.
meta_ptrs
=
self
.
create_shared_buffer
(
ops
.
meta_size
()
+
max_size
,
group
=
group
)
group
=
group
,
uncached
=
True
)
# group=group) xiabo
# This is a pre-registered IPC buffer. In eager mode, input tensors
# This is a pre-registered IPC buffer. In eager mode, input tensors
# are first copied into this buffer before allreduce is performed
# are first copied into this buffer before allreduce is performed
self
.
buffer_ptrs
=
self
.
create_shared_buffer
(
max_size
,
group
=
group
)
self
.
buffer_ptrs
=
self
.
create_shared_buffer
(
max_size
,
group
=
group
)
...
@@ -174,38 +181,38 @@ class CustomAllreduce:
...
@@ -174,38 +181,38 @@ class CustomAllreduce:
self
.
full_nvlink
)
self
.
full_nvlink
)
ops
.
register_buffer
(
self
.
_ptr
,
self
.
buffer_ptrs
)
ops
.
register_buffer
(
self
.
_ptr
,
self
.
buffer_ptrs
)
@
staticmethod
#
@staticmethod
def
create_shared_buffer
(
#
def create_shared_buffer(
size_in_bytes
:
int
,
#
size_in_bytes: int,
group
:
Optional
[
ProcessGroup
]
=
None
)
->
List
[
int
]:
#
group: Optional[ProcessGroup] = None) -> List[int]:
"""
#
"""
Creates a shared buffer and returns a list of pointers
#
Creates a shared buffer and returns a list of pointers
representing the buffer on all processes in the group.
#
representing the buffer on all processes in the group.
"""
#
"""
lib
=
CudaRTLibrary
()
#
lib = CudaRTLibrary()
pointer
=
lib
.
cudaMalloc
(
size_in_bytes
)
#
pointer = lib.cudaMalloc(size_in_bytes)
handle
=
lib
.
cudaIpcGetMemHandle
(
pointer
)
#
handle = lib.cudaIpcGetMemHandle(pointer)
world_size
=
dist
.
get_world_size
(
group
=
group
)
#
world_size = dist.get_world_size(group=group)
rank
=
dist
.
get_rank
(
group
=
group
)
#
rank = dist.get_rank(group=group)
handles
=
[
None
]
*
world_size
#
handles = [None] * world_size
dist
.
all_gather_object
(
handles
,
handle
,
group
=
group
)
#
dist.all_gather_object(handles, handle, group=group)
pointers
:
List
[
int
]
=
[]
#
pointers: List[int] = []
for
i
,
h
in
enumerate
(
handles
):
#
for i, h in enumerate(handles):
if
i
==
rank
:
#
if i == rank:
pointers
.
append
(
pointer
.
value
)
# type: ignore
#
pointers.append(pointer.value) # type: ignore
else
:
#
else:
pointers
.
append
(
#
pointers.append(
lib
.
cudaIpcOpenMemHandle
(
h
).
value
)
# type: ignore
#
lib.cudaIpcOpenMemHandle(h).value) # type: ignore
return
pointers
#
return pointers
@
staticmethod
#
@staticmethod
def
free_shared_buffer
(
pointers
:
List
[
int
],
#
def free_shared_buffer(pointers: List[int],
group
:
Optional
[
ProcessGroup
]
=
None
)
->
None
:
#
group: Optional[ProcessGroup] = None) -> None:
rank
=
dist
.
get_rank
(
group
=
group
)
#
rank = dist.get_rank(group=group)
lib
=
CudaRTLibrary
()
#
lib = CudaRTLibrary()
lib
.
cudaFree
(
ctypes
.
c_void_p
(
pointers
[
rank
]))
#
lib.cudaFree(ctypes.c_void_p(pointers[rank]))
@
contextmanager
@
contextmanager
def
capture
(
self
):
def
capture
(
self
):
...
@@ -304,3 +311,30 @@ class CustomAllreduce:
...
@@ -304,3 +311,30 @@ class CustomAllreduce:
def
__del__
(
self
):
def
__del__
(
self
):
self
.
close
()
self
.
close
()
@
staticmethod
def
create_shared_buffer
(
size_in_bytes
:
int
,
group
:
Optional
[
ProcessGroup
]
=
None
,
uncached
:
Optional
[
bool
]
=
False
)
->
List
[
int
]:
pointer
,
handle
=
ops
.
allocate_shared_buffer_and_handle
(
size_in_bytes
)
world_size
=
dist
.
get_world_size
(
group
=
group
)
rank
=
dist
.
get_rank
(
group
=
group
)
handles
=
[
None
]
*
world_size
dist
.
all_gather_object
(
handles
,
handle
,
group
=
group
)
pointers
:
List
[
int
]
=
[]
for
i
,
h
in
enumerate
(
handles
):
if
i
==
rank
:
pointers
.
append
(
pointer
)
# type: ignore
else
:
pointers
.
append
(
ops
.
open_mem_handle
(
h
))
return
pointers
@
staticmethod
def
free_shared_buffer
(
pointers
:
List
[
int
],
group
:
Optional
[
ProcessGroup
]
=
None
,
rank
:
Optional
[
int
]
=
0
)
->
None
:
if
rank
is
None
:
rank
=
dist
.
get_rank
(
group
=
group
)
ops
.
free_shared_buffer
(
pointers
[
rank
])
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment