Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f5bb85b4
"tests/vscode:/vscode.git/clone" did not exist on "9239bf718e5ebb5ab871ac8ed09fb80ed02fa82b"
Unverified
Commit
f5bb85b4
authored
Jun 14, 2024
by
youkaichao
Committed by
GitHub
Jun 14, 2024
Browse files
[Core][Distributed] improve p2p cache generation (#5528)
parent
28c145eb
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
265 additions
and
96 deletions
+265
-96
vllm/distributed/device_communicators/cuda_wrapper.py
vllm/distributed/device_communicators/cuda_wrapper.py
+146
-0
vllm/distributed/device_communicators/custom_all_reduce_utils.py
...stributed/device_communicators/custom_all_reduce_utils.py
+119
-96
No files found.
vllm/distributed/device_communicators/cuda_wrapper.py
0 → 100644
View file @
f5bb85b4
"""This file is a pure Python wrapper for the cudart library.
It avoids the need to compile a separate shared library, and is
convenient for use when we just need to call a few functions.
"""
import
ctypes
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
# this line makes it possible to directly load `libcudart.so` using `ctypes`
import
torch
# noqa
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
# === export types and functions from cudart to Python ===
# for the original cudart definition, please check
# https://docs.nvidia.com/cuda/cuda-runtime-api/index.html
cudaError_t
=
ctypes
.
c_int
cudaMemcpyKind
=
ctypes
.
c_int
class
cudaIpcMemHandle_t
(
ctypes
.
Structure
):
_fields_
=
[(
"internal"
,
ctypes
.
c_byte
*
128
)]
@
dataclass
class
Function
:
name
:
str
restype
:
Any
argtypes
:
List
[
Any
]
class
CudaRTLibrary
:
exported_functions
=
[
# cudaError_t cudaSetDevice ( int device )
Function
(
"cudaSetDevice"
,
cudaError_t
,
[
ctypes
.
c_int
]),
# cudaError_t cudaDeviceSynchronize ( void )
Function
(
"cudaDeviceSynchronize"
,
cudaError_t
,
[]),
# cudaError_t cudaDeviceReset ( void )
Function
(
"cudaDeviceReset"
,
cudaError_t
,
[]),
# const char* cudaGetErrorString ( cudaError_t error )
Function
(
"cudaGetErrorString"
,
ctypes
.
c_char_p
,
[
cudaError_t
]),
# cudaError_t cudaMalloc ( void** devPtr, size_t size )
Function
(
"cudaMalloc"
,
cudaError_t
,
[
ctypes
.
POINTER
(
ctypes
.
c_void_p
),
ctypes
.
c_size_t
]),
# cudaError_t cudaFree ( void* devPtr )
Function
(
"cudaFree"
,
cudaError_t
,
[
ctypes
.
c_void_p
]),
# cudaError_t cudaMemset ( void* devPtr, int value, size_t count )
Function
(
"cudaMemset"
,
cudaError_t
,
[
ctypes
.
c_void_p
,
ctypes
.
c_int
,
ctypes
.
c_size_t
]),
# cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa
Function
(
"cudaMemcpy"
,
cudaError_t
,
[
ctypes
.
c_void_p
,
ctypes
.
c_void_p
,
ctypes
.
c_size_t
,
cudaMemcpyKind
]),
# cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa
Function
(
"cudaIpcGetMemHandle"
,
cudaError_t
,
[
ctypes
.
POINTER
(
cudaIpcMemHandle_t
),
ctypes
.
c_void_p
]),
# cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa
Function
(
"cudaIpcOpenMemHandle"
,
cudaError_t
,
[
ctypes
.
POINTER
(
ctypes
.
c_void_p
),
cudaIpcMemHandle_t
,
ctypes
.
c_uint
]),
]
# class attribute to store the mapping from the path to the library
# to avoid loading the same library multiple times
path_to_library_cache
:
Dict
[
str
,
Any
]
=
{}
# class attribute to store the mapping from library path
# to the corresponding dictionary
path_to_dict_mapping
:
Dict
[
str
,
Dict
[
str
,
Any
]]
=
{}
def
__init__
(
self
,
so_file
:
Optional
[
str
]
=
None
):
if
so_file
is
None
:
assert
torch
.
version
.
cuda
is
not
None
major_version
=
torch
.
version
.
cuda
.
split
(
"."
)[
0
]
so_file
=
f
"libcudart.so.
{
major_version
}
"
if
so_file
not
in
CudaRTLibrary
.
path_to_library_cache
:
lib
=
ctypes
.
CDLL
(
so_file
)
CudaRTLibrary
.
path_to_library_cache
[
so_file
]
=
lib
self
.
lib
=
CudaRTLibrary
.
path_to_library_cache
[
so_file
]
if
so_file
not
in
CudaRTLibrary
.
path_to_dict_mapping
:
_funcs
=
{}
for
func
in
CudaRTLibrary
.
exported_functions
:
f
=
getattr
(
self
.
lib
,
func
.
name
)
f
.
restype
=
func
.
restype
f
.
argtypes
=
func
.
argtypes
_funcs
[
func
.
name
]
=
f
CudaRTLibrary
.
path_to_dict_mapping
[
so_file
]
=
_funcs
self
.
funcs
=
CudaRTLibrary
.
path_to_dict_mapping
[
so_file
]
def
CUDART_CHECK
(
self
,
result
:
cudaError_t
)
->
None
:
if
result
!=
0
:
error_str
=
self
.
cudaGetErrorString
(
result
)
raise
RuntimeError
(
f
"CUDART error:
{
error_str
}
"
)
def
cudaGetErrorString
(
self
,
error
:
cudaError_t
)
->
str
:
return
self
.
funcs
[
"cudaGetErrorString"
](
error
).
decode
(
"utf-8"
)
def
cudaSetDevice
(
self
,
device
:
int
)
->
None
:
self
.
CUDART_CHECK
(
self
.
funcs
[
"cudaSetDevice"
](
device
))
def
cudaDeviceSynchronize
(
self
)
->
None
:
self
.
CUDART_CHECK
(
self
.
funcs
[
"cudaDeviceSynchronize"
]())
def
cudaDeviceReset
(
self
)
->
None
:
self
.
CUDART_CHECK
(
self
.
funcs
[
"cudaDeviceReset"
]())
def
cudaMalloc
(
self
,
size
:
int
)
->
ctypes
.
c_void_p
:
devPtr
=
ctypes
.
c_void_p
()
self
.
CUDART_CHECK
(
self
.
funcs
[
"cudaMalloc"
](
ctypes
.
byref
(
devPtr
),
size
))
return
devPtr
def
cudaFree
(
self
,
devPtr
:
ctypes
.
c_void_p
)
->
None
:
self
.
CUDART_CHECK
(
self
.
funcs
[
"cudaFree"
](
devPtr
))
def
cudaMemset
(
self
,
devPtr
:
ctypes
.
c_void_p
,
value
:
int
,
count
:
int
)
->
None
:
self
.
CUDART_CHECK
(
self
.
funcs
[
"cudaMemset"
](
devPtr
,
value
,
count
))
def
cudaMemcpy
(
self
,
dst
:
ctypes
.
c_void_p
,
src
:
ctypes
.
c_void_p
,
count
:
int
)
->
None
:
cudaMemcpyDefault
=
4
kind
=
cudaMemcpyDefault
self
.
CUDART_CHECK
(
self
.
funcs
[
"cudaMemcpy"
](
dst
,
src
,
count
,
kind
))
def
cudaIpcGetMemHandle
(
self
,
devPtr
:
ctypes
.
c_void_p
)
->
cudaIpcMemHandle_t
:
handle
=
cudaIpcMemHandle_t
()
self
.
CUDART_CHECK
(
self
.
funcs
[
"cudaIpcGetMemHandle"
](
ctypes
.
byref
(
handle
),
devPtr
))
return
handle
def
cudaIpcOpenMemHandle
(
self
,
handle
:
cudaIpcMemHandle_t
)
->
ctypes
.
c_void_p
:
cudaIpcMemLazyEnablePeerAccess
=
1
devPtr
=
ctypes
.
c_void_p
()
self
.
CUDART_CHECK
(
self
.
funcs
[
"cudaIpcOpenMemHandle"
](
ctypes
.
byref
(
devPtr
),
handle
,
cudaIpcMemLazyEnablePeerAccess
))
return
devPtr
vllm/distributed/device_communicators/custom_all_reduce_utils.py
View file @
f5bb85b4
import
ctypes
import
json
import
json
import
os
import
os
import
sys
from
itertools
import
product
import
tempfile
from
typing
import
Dict
,
Optional
,
Sequence
import
time
from
contextlib
import
contextmanager
from
typing
import
Callable
,
Dict
,
List
,
Optional
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
import
torch.multiprocessing
as
mp
import
torch.multiprocessing
as
mp
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.distributed.device_communicators.cuda_wrapper
import
CudaRTLibrary
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
cuda_device_count_stateless
from
vllm.utils
import
cuda_device_count_stateless
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
@
contextmanager
def
producer
(
batch_src
:
Sequence
[
int
],
def
mute_output
():
producer_queue
,
with
open
(
os
.
devnull
,
"w"
)
as
f
:
consumer_queue
,
sys
.
stderr
=
f
result_queue
,
sys
.
stdout
=
f
yield
def
producer
(
i
:
int
,
init_method
:
str
,
cuda_visible_devices
:
Optional
[
str
]
=
None
):
cuda_visible_devices
:
Optional
[
str
]
=
None
):
if
cuda_visible_devices
is
not
None
:
if
cuda_visible_devices
is
not
None
:
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
cuda_visible_devices
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
cuda_visible_devices
with
mute_output
():
dist
.
init_process_group
(
lib
=
CudaRTLibrary
()
backend
=
"gloo"
,
for
i
in
batch_src
:
init_method
=
init_method
,
lib
.
cudaSetDevice
(
i
)
world_size
=
2
,
pointer
=
lib
.
cudaMalloc
(
1024
)
rank
=
0
,
lib
.
cudaMemset
(
pointer
,
1
,
1024
)
)
lib
.
cudaDeviceSynchronize
()
# produce a tensor in GPU i
handle
=
lib
.
cudaIpcGetMemHandle
(
pointer
)
data
=
torch
.
zeros
((
128
,
),
device
=
f
"cuda:
{
i
}
"
)
producer_queue
.
put
(
handle
)
# get the information to reconstruct the shared tensor
open_success
=
consumer_queue
.
get
()
func
,
args
=
torch
.
multiprocessing
.
reductions
.
reduce_tensor
(
data
)
if
open_success
:
args
=
list
(
args
)
# use two queues to simulate barrier
dist
.
broadcast_object_list
([(
func
,
args
)],
src
=
0
)
producer_queue
.
put
(
0
)
dist
.
barrier
()
consumer_queue
.
get
()
torch
.
cuda
.
synchronize
()
# check if the memory is modified
assert
torch
.
all
(
data
==
1
).
item
()
host_data
=
(
ctypes
.
c_char
*
1024
)()
lib
.
cudaMemcpy
(
host_data
,
pointer
,
1024
)
# type: ignore
for
i
in
range
(
1024
):
def
consumer
(
j
:
int
,
if
ord
(
host_data
[
i
])
!=
2
:
init_method
:
str
,
open_success
=
False
break
result_queue
.
put
(
open_success
)
lib
.
cudaDeviceReset
()
def
consumer
(
batch_tgt
:
Sequence
[
int
],
producer_queue
,
consumer_queue
,
result_queue
,
cuda_visible_devices
:
Optional
[
str
]
=
None
):
cuda_visible_devices
:
Optional
[
str
]
=
None
):
if
cuda_visible_devices
is
not
None
:
if
cuda_visible_devices
is
not
None
:
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
cuda_visible_devices
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
cuda_visible_devices
with
mute_output
():
dist
.
init_process_group
(
lib
=
CudaRTLibrary
()
backend
=
"gloo"
,
for
j
in
batch_tgt
:
init_method
=
init_method
,
lib
.
cudaSetDevice
(
j
)
world_size
=
2
,
handle
=
producer_queue
.
get
()
rank
=
1
,
open_success
=
False
)
try
:
torch
.
cuda
.
set_device
(
j
)
pointer
=
lib
.
cudaIpcOpenMemHandle
(
handle
)
# type: ignore
recv
=
[
None
]
open_success
=
True
dist
.
broadcast_object_list
(
recv
,
src
=
0
)
except
RuntimeError
:
func
:
Callable
# cannot error out here, because the producer process
args
:
List
# is still waiting for the response.
func
,
args
=
recv
[
0
]
# type: ignore
pass
# `args[6]` is the device id
consumer_queue
.
put
(
open_success
)
# by default pytorch will use `i` from the producer
if
open_success
:
# here we need to set it to `j` to test P2P access
# modify the memory
args
[
6
]
=
j
lib
.
cudaMemset
(
pointer
,
2
,
1024
)
data
=
func
(
*
args
)
# use two queues to simulate barrier
data
+=
1
producer_queue
.
get
()
dist
.
barrier
()
consumer_queue
.
put
(
0
)
torch
.
cuda
.
synchronize
()
# check if the memory is modified
assert
torch
.
all
(
data
==
1
).
item
()
host_data
=
(
ctypes
.
c_char
*
1024
)()
lib
.
cudaMemcpy
(
host_data
,
pointer
,
1024
)
# type: ignore
for
i
in
range
(
1024
):
def
can_actually_p2p
(
i
,
j
):
if
ord
(
host_data
[
i
])
!=
2
:
open_success
=
False
break
result_queue
.
put
(
open_success
)
lib
.
cudaDeviceReset
()
def
can_actually_p2p
(
batch_src
:
Sequence
[
int
],
batch_tgt
:
Sequence
[
int
],
):
"""
"""
Usually, checking if P2P access is enabled can be done by
Usually, checking if P2P access is enabled can be done by
`torch.cuda.can_device_access_peer(
i, j
)`. However, sometimes
`torch.cuda.can_device_access_peer(
src, tgt
)`. However, sometimes
the driver might be broken, and `torch.cuda.can_device_access_peer(
i, j
)`
the driver might be broken, and `torch.cuda.can_device_access_peer(
src, tgt
)`
returns `True` even if P2P access is not actually possible.
returns `True` even if P2P access is not actually possible.
See https://github.com/vllm-project/vllm/issues/2728 and
See https://github.com/vllm-project/vllm/issues/2728 and
https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10
https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10
...
@@ -90,41 +101,50 @@ def can_actually_p2p(i, j):
...
@@ -90,41 +101,50 @@ def can_actually_p2p(i, j):
Note on p2p and cuda IPC:
Note on p2p and cuda IPC:
Usually, one process uses one GPU:
Usually, one process uses one GPU:
GPU
i
--> cuda context
i
--> tensor
i
--> process
i
GPU
src
--> cuda context
src
--> tensor
src
--> process
src
We need to combine p2p and cuda IPC, so that:
We need to combine p2p and cuda IPC, so that:
GPU
i
--> cuda context
i
--> tensor
i
--> process
i
GPU
src
--> cuda context
src
--> tensor
src
--> process
src
|shared|
|shared|
GPU
j
--> cuda context
j
--> tensor
j
--> process
j
GPU
tgt
--> cuda context
tgt
--> tensor
tgt
--> process
tgt
That is to say, process
i
creates a tensor in GPU
i
, passes IPC handle to
That is to say, process
src
creates a tensor in GPU
src
, passes IPC handle to
process
j
, and process
j
accesses the tensor in GPU
j
. Any operation on the
process
tgt
, and process
tgt
accesses the tensor in GPU
tgt
. Any operation on the
tensor in process
j
will be reflected in the tensor in process
i
, because
tensor in process
tgt
will be reflected in the tensor in process
src
, because
they are the same memory segment.
they are the same memory segment.
It is important to note that process j accesses the tensor in GPU j, not
It is important to note that process tgt accesses the tensor in GPU tgt, not
GPU i. That's why we need p2p access. # noqa
GPU src. That's why we need p2p access.
"""
The most time-consuming part is the process creation. To avoid creating
processes for every pair of GPUs, we use batched testing. We create two
processes for testing all pairs of GPUs in batch. The trick is to reset
the device after each test (which is not available in PyTorch).
"""
# noqa
cuda_visible_devices
=
os
.
getenv
(
'CUDA_VISIBLE_DEVICES'
,
None
)
cuda_visible_devices
=
os
.
getenv
(
'CUDA_VISIBLE_DEVICES'
,
None
)
# pass the CUDA_VISIBLE_DEVICES to the child process
# pass the CUDA_VISIBLE_DEVICES to the child process
# to make sure they see the same set of GPUs
# to make sure they see the same set of GPUs
# make sure the temp file is not the same across different calls
temp_path
=
tempfile
.
mktemp
()
+
str
(
time
.
time
())
# create an empty file
with
open
(
temp_path
,
"w"
):
pass
init_method
=
f
"file://
{
temp_path
}
"
# make sure the processes are spawned
# make sure the processes are spawned
smp
=
mp
.
get_context
(
"spawn"
)
smp
=
mp
.
get_context
(
"spawn"
)
pi
=
smp
.
Process
(
target
=
producer
,
producer_queue
=
smp
.
Queue
()
args
=
(
i
,
init_method
,
cuda_visible_devices
))
consumer_queue
=
smp
.
Queue
()
pj
=
smp
.
Process
(
target
=
consumer
,
result_queue
=
smp
.
Queue
()
args
=
(
j
,
init_method
,
cuda_visible_devices
))
p_src
=
smp
.
Process
(
target
=
producer
,
pi
.
start
()
args
=
(
batch_src
,
producer_queue
,
consumer_queue
,
pj
.
start
()
result_queue
,
cuda_visible_devices
))
pi
.
join
()
p_tgt
=
smp
.
Process
(
target
=
consumer
,
pj
.
join
()
args
=
(
batch_tgt
,
producer_queue
,
consumer_queue
,
return
pi
.
exitcode
==
0
and
pj
.
exitcode
==
0
result_queue
,
cuda_visible_devices
))
p_src
.
start
()
p_tgt
.
start
()
p_src
.
join
()
p_tgt
.
join
()
result
=
[]
for
src
,
tgt
in
zip
(
batch_src
,
batch_tgt
):
a
=
result_queue
.
get
()
b
=
result_queue
.
get
()
assert
a
==
b
result
.
append
(
a
)
return
result
# why do we need this cache?
# why do we need this cache?
...
@@ -142,14 +162,14 @@ def can_actually_p2p(i, j):
...
@@ -142,14 +162,14 @@ def can_actually_p2p(i, j):
_gpu_p2p_access_cache
:
Optional
[
Dict
[
str
,
bool
]]
=
None
_gpu_p2p_access_cache
:
Optional
[
Dict
[
str
,
bool
]]
=
None
def
gpu_p2p_access_check
(
i
:
int
,
j
:
int
)
->
bool
:
def
gpu_p2p_access_check
(
src
:
int
,
tgt
:
int
)
->
bool
:
"""Check if GPU
i
can access GPU
j
."""
"""Check if GPU
src
can access GPU
tgt
."""
# if the cache variable is already calculated,
# if the cache variable is already calculated,
# read from the cache instead of checking it again
# read from the cache instead of checking it again
global
_gpu_p2p_access_cache
global
_gpu_p2p_access_cache
if
_gpu_p2p_access_cache
is
not
None
:
if
_gpu_p2p_access_cache
is
not
None
:
return
_gpu_p2p_access_cache
[
f
"
{
i
}
->
{
j
}
"
]
return
_gpu_p2p_access_cache
[
f
"
{
src
}
->
{
tgt
}
"
]
is_distributed
=
dist
.
is_initialized
()
is_distributed
=
dist
.
is_initialized
()
...
@@ -169,9 +189,12 @@ def gpu_p2p_access_check(i: int, j: int) -> bool:
...
@@ -169,9 +189,12 @@ def gpu_p2p_access_check(i: int, j: int) -> bool:
# enter this block to calculate the cache
# enter this block to calculate the cache
logger
.
info
(
"generating GPU P2P access cache in %s"
,
path
)
logger
.
info
(
"generating GPU P2P access cache in %s"
,
path
)
cache
=
{}
cache
=
{}
for
_i
in
range
(
num_dev
):
ids
=
list
(
range
(
num_dev
))
for
_j
in
range
(
num_dev
):
# batch of all pairs of GPUs
cache
[
f
"
{
_i
}
->
{
_j
}
"
]
=
can_actually_p2p
(
_i
,
_j
)
batch_src
,
batch_tgt
=
zip
(
*
list
(
product
(
ids
,
ids
)))
result
=
can_actually_p2p
(
batch_src
,
batch_tgt
)
for
_i
,
_j
,
r
in
zip
(
batch_src
,
batch_tgt
,
result
):
cache
[
f
"
{
_i
}
->
{
_j
}
"
]
=
r
with
open
(
path
,
"w"
)
as
f
:
with
open
(
path
,
"w"
)
as
f
:
json
.
dump
(
cache
,
f
,
indent
=
4
)
json
.
dump
(
cache
,
f
,
indent
=
4
)
if
is_distributed
:
if
is_distributed
:
...
@@ -180,7 +203,7 @@ def gpu_p2p_access_check(i: int, j: int) -> bool:
...
@@ -180,7 +203,7 @@ def gpu_p2p_access_check(i: int, j: int) -> bool:
with
open
(
path
,
"r"
)
as
f
:
with
open
(
path
,
"r"
)
as
f
:
cache
=
json
.
load
(
f
)
cache
=
json
.
load
(
f
)
_gpu_p2p_access_cache
=
cache
_gpu_p2p_access_cache
=
cache
return
_gpu_p2p_access_cache
[
f
"
{
i
}
->
{
j
}
"
]
return
_gpu_p2p_access_cache
[
f
"
{
src
}
->
{
tgt
}
"
]
__all__
=
[
"gpu_p2p_access_check"
]
__all__
=
[
"gpu_p2p_access_check"
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment