Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4be3a451
Unverified
Commit
4be3a451
authored
Nov 05, 2024
by
youkaichao
Committed by
GitHub
Nov 05, 2024
Browse files
[distributed] add function to create ipc buffers directly (#10064)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
40899855
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
91 additions
and
0 deletions
+91
-0
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+1
-0
tests/distributed/test_ca_buffer_sharing.py
tests/distributed/test_ca_buffer_sharing.py
+59
-0
vllm/distributed/device_communicators/custom_all_reduce.py
vllm/distributed/device_communicators/custom_all_reduce.py
+31
-0
No files found.
.buildkite/test-pipeline.yaml
View file @
4be3a451
...
...
@@ -510,6 +510,7 @@ steps:
# NOTE: don't test llama model here, it seems hf implementation is buggy
# see https://github.com/vllm-project/vllm/pull/5689 for details
-
pytest -v -s distributed/test_custom_all_reduce.py
-
torchrun --nproc_per_node=2 distributed/test_ca_buffer_sharing.py
-
TARGET_TEST_SUITE=A100 pytest basic_correctness/ -v -s -m distributed_2_gpus
-
pytest -v -s -x lora/test_mixtral.py
...
...
tests/distributed/test_ca_buffer_sharing.py
0 → 100644
View file @
4be3a451
# can only run on machines with p2p access across GPUs
# can only run with torchrun:
# torchrun --nproc_per_node=2 tests/distributed/test_ca_buffer_sharing.py
import
ctypes
import
torch
import
torch.distributed
as
dist
from
vllm.distributed.device_communicators.cuda_wrapper
import
CudaRTLibrary
from
vllm.distributed.device_communicators.custom_all_reduce
import
(
# noqa
CustomAllreduce
)
def _sync_all_ranks() -> None:
    """Run a process-group barrier, then wait for this process's CUDA work."""
    dist.barrier()
    torch.cuda.synchronize()


# create a cpu process group for communicating metadata (ipc handle)
dist.init_process_group(backend="gloo")
rank = local_rank = dist.get_rank()
world_size = dist.get_world_size()

# every process sets its own device (differently)
lib = CudaRTLibrary()
lib.cudaSetDevice(rank)

buffer_size_in_bytes = 1024
byte_value = 2  # the value we write to the buffer for verification

# One integer address per rank comes back from the shared-buffer helper.
pointers = CustomAllreduce.create_shared_buffer(buffer_size_in_bytes)

print(f"Rank {rank} has pointers {pointers}")

_sync_all_ranks()

if rank == 0:
    # the first rank tries to write to all buffers
    for p in pointers:
        lib.cudaMemset(ctypes.c_void_p(p), byte_value, buffer_size_in_bytes)

# Nobody reads before rank 0 has finished issuing its writes.
_sync_all_ranks()

# Host-side staging area that each buffer is copied into for inspection.
host_data = (ctypes.c_char * buffer_size_in_bytes)()

# all ranks read from all buffers, and check if the data is correct
for p in pointers:
    lib.cudaMemcpy(host_data, ctypes.c_void_p(p), buffer_size_in_bytes)
    for i in range(buffer_size_in_bytes):
        assert ord(host_data[i]) == byte_value, (
            f"Rank {rank} failed"
            f" to verify buffer {p}. Expected {byte_value}, "
            f"got {ord(host_data[i])}")

print(f"Rank {rank} verified all buffers")

# Every rank must be done reading before any buffer is released.
_sync_all_ranks()

CustomAllreduce.free_shared_buffer(pointers)
vllm/distributed/device_communicators/custom_all_reduce.py
View file @
4be3a451
import
ctypes
from
contextlib
import
contextmanager
from
typing
import
Any
,
List
,
Optional
,
Union
...
...
@@ -7,6 +8,7 @@ from torch.distributed import ProcessGroup
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm.distributed.device_communicators.cuda_wrapper
import
CudaRTLibrary
from
vllm.distributed.device_communicators.custom_all_reduce_utils
import
(
gpu_p2p_access_check
)
from
vllm.distributed.parallel_state
import
in_the_same_node_as
...
...
@@ -174,6 +176,35 @@ class CustomAllreduce:
offsets
,
rank
,
self
.
full_nvlink
)
self
.
register_buffer
(
self
.
buffer
)
@staticmethod
def create_shared_buffer(
        size_in_bytes: int,
        group: Optional[ProcessGroup] = None) -> List[int]:
    """Allocate a buffer on this rank and collect a pointer to every rank's.

    The calling rank allocates ``size_in_bytes`` with ``cudaMalloc``,
    exports an IPC handle for that allocation and all-gathers the handles
    over ``group``. Handles received from other ranks are opened with
    ``cudaIpcOpenMemHandle``; for its own slot the rank reuses the pointer
    it allocated instead of opening its own handle.

    Args:
        size_in_bytes: Number of bytes to allocate on the calling rank.
        group: Process group used for the rank lookup and the handle
            exchange. ``None`` is forwarded to ``torch.distributed``
            unchanged.

    Returns:
        One integer address per rank, in rank order within ``group``. The
        entry at the caller's own rank is its local allocation.
    """
    cudart = CudaRTLibrary()
    local_ptr = cudart.cudaMalloc(size_in_bytes)
    local_handle = cudart.cudaIpcGetMemHandle(local_ptr)

    group_size = dist.get_world_size(group=group)
    my_rank = dist.get_rank(group=group)

    # all_gather_object fills the pre-sized list in place, one slot per rank.
    gathered_handles = [None] * group_size
    dist.all_gather_object(gathered_handles, local_handle, group=group)

    addresses: List[int] = [
        local_ptr.value  # type: ignore
        if peer == my_rank else
        cudart.cudaIpcOpenMemHandle(peer_handle).value  # type: ignore
        for peer, peer_handle in enumerate(gathered_handles)
    ]
    return addresses
@staticmethod
def free_shared_buffer(pointers: List[int],
                       group: Optional[ProcessGroup] = None) -> None:
    """Free the calling rank's own entry of a shared-buffer pointer list.

    Only ``pointers[rank]`` is passed to ``cudaFree``; the entries that
    belong to other ranks are left untouched by this function.

    Args:
        pointers: Per-rank integer addresses, indexed by rank within
            ``group`` (the layout ``create_shared_buffer`` returns).
        group: Process group used to look up the caller's rank. ``None``
            is forwarded to ``torch.distributed`` unchanged.
    """
    own_index = dist.get_rank(group=group)
    cudart = CudaRTLibrary()
    # NOTE(review): mappings opened from peer IPC handles are not closed
    # here -- confirm whether that is intentional.
    own_allocation = ctypes.c_void_p(pointers[own_index])
    cudart.cudaFree(own_allocation)
@
contextmanager
def
capture
(
self
):
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment