Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1dba2c4e
Unverified
Commit
1dba2c4e
authored
Jul 04, 2025
by
Ning Xie
Committed by
GitHub
Jul 03, 2025
Browse files
[Misc] adjust for ipv6 for mooncake url parse (#20107)
Signed-off-by:
Andy Xie
<
andy.xning@gmail.com
>
parent
71d6de3a
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
99 additions
and
27 deletions
+99
-27
tests/test_utils.py
tests/test_utils.py
+46
-4
vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py
vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py
+29
-19
vllm/utils/__init__.py
vllm/utils/__init__.py
+24
-4
No files found.
tests/test_utils.py
View file @
1dba2c4e
...
...
@@ -20,10 +20,11 @@ from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from
vllm.utils
import
(
CacheInfo
,
FlexibleArgumentParser
,
LRUCache
,
MemorySnapshot
,
PlaceholderModule
,
StoreBoolean
,
bind_kv_cache
,
common_broadcastable_dtype
,
deprecate_kwargs
,
get_open_port
,
is_lossless_cast
,
make_zmq_path
,
make_zmq_socket
,
memory_profiling
,
merge_async_iterators
,
sha256
,
split_zmq_path
,
supports_kw
,
swap_dict_values
)
deprecate_kwargs
,
get_open_port
,
get_tcp_uri
,
is_lossless_cast
,
join_host_port
,
make_zmq_path
,
make_zmq_socket
,
memory_profiling
,
merge_async_iterators
,
sha256
,
split_host_port
,
split_zmq_path
,
supports_kw
,
swap_dict_values
)
from
.utils
import
create_new_process_for_each_test
,
error_on_warning
...
...
@@ -876,3 +877,44 @@ def test_make_zmq_socket_ipv6():
def test_make_zmq_path():
    """make_zmq_path should bracket an IPv6 host and leave IPv4 bare."""
    ipv4_path = make_zmq_path("tcp", "127.0.0.1", "5555")
    assert ipv4_path == "tcp://127.0.0.1:5555"

    ipv6_path = make_zmq_path("tcp", "::1", "5555")
    assert ipv6_path == "tcp://[::1]:5555"
def test_get_tcp_uri():
    """get_tcp_uri should bracket an IPv6 host and leave IPv4 bare."""
    ipv4_uri = get_tcp_uri("127.0.0.1", 5555)
    assert ipv4_uri == "tcp://127.0.0.1:5555"

    ipv6_uri = get_tcp_uri("::1", 5555)
    assert ipv6_uri == "tcp://[::1]:5555"
def test_split_host_port():
    """Exercise split_host_port on well-formed and malformed inputs."""
    # Well-formed IPv4.
    assert split_host_port("127.0.0.1:5555") == ("127.0.0.1", 5555)

    # Malformed IPv4, in order: doubled colon, trailing colon,
    # missing colon, non-integer port.
    malformed_ipv4 = (
        "127.0.0.1::5555",
        "127.0.0.1:5555:",
        "127.0.0.15555",
        "127.0.0.1:5555a",
    )
    for candidate in malformed_ipv4:
        with pytest.raises(ValueError):
            split_host_port(candidate)

    # Well-formed IPv6: the brackets are stripped from the returned host.
    assert split_host_port("[::1]:5555") == ("::1", 5555)

    # Malformed IPv6, in order: doubled colon, missing colon (surfaces as
    # IndexError rather than ValueError), non-integer port.
    malformed_ipv6 = (
        ("[::1]::5555", ValueError),
        ("[::1]5555", IndexError),
        ("[::1]:5555a", ValueError),
    )
    for candidate, expected_error in malformed_ipv6:
        with pytest.raises(expected_error):
            split_host_port(candidate)
def test_join_host_port():
    """join_host_port should bracket an IPv6 host and leave IPv4 bare."""
    ipv4_joined = join_host_port("127.0.0.1", 5555)
    assert ipv4_joined == "127.0.0.1:5555"

    ipv6_joined = join_host_port("::1", 5555)
    assert ipv6_joined == "[::1]:5555"
vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py
View file @
1dba2c4e
...
...
@@ -16,6 +16,7 @@ from safetensors.torch import save as safetensors_save
from
vllm.config
import
KVTransferConfig
from
vllm.distributed.kv_transfer.kv_pipe.base
import
KVPipeBase
from
vllm.logger
import
init_logger
from
vllm.utils
import
join_host_port
,
make_zmq_path
,
split_host_port
logger
=
init_logger
(
__name__
)
NONE_INT
=
-
150886311
...
...
@@ -79,18 +80,19 @@ class MooncakeTransferEngine:
logger
.
error
(
"An error occurred while loading the configuration: %s"
,
exc
)
raise
prefill_host
,
base_prefill_port
=
self
.
config
.
prefill_url
.
split
(
':'
)
decode_host
,
base_decode_port
=
self
.
config
.
decode_url
.
split
(
':'
)
prefill_host
,
base_prefill_port
=
split_host_port
(
self
.
config
.
prefill_url
)
decode_host
,
base_decode_port
=
split_host_port
(
self
.
config
.
decode_url
)
# Avoid ports conflict when running prefill and decode on the same node
if
prefill_host
==
decode_host
and
\
base_prefill_port
==
base_decode_port
:
base_decode_port
=
str
(
int
(
base_decode_port
)
+
100
)
base_decode_port
=
base_decode_port
+
100
prefill_port
=
int
(
base_prefill_port
)
+
self
.
local_rank
decode_port
=
int
(
base_decode_port
)
+
self
.
local_rank
self
.
prefill_url
=
':'
.
join
([
prefill_host
,
str
(
prefill_port
)
])
self
.
decode_url
=
':'
.
join
([
decode_host
,
str
(
decode_port
)
])
prefill_port
=
base_prefill_port
+
self
.
local_rank
decode_port
=
base_decode_port
+
self
.
local_rank
self
.
prefill_url
=
join_host_port
(
prefill_host
,
prefill_port
)
self
.
decode_url
=
join_host_port
(
decode_host
,
decode_port
)
self
.
initialize
(
self
.
prefill_url
if
kv_rank
==
0
else
self
.
decode_url
,
self
.
config
.
metadata_server
,
self
.
config
.
protocol
,
...
...
@@ -110,22 +112,30 @@ class MooncakeTransferEngine:
self
.
_setup_metadata_sockets
(
kv_rank
,
prefill_host
,
base_prefill_port
,
decode_host
,
base_decode_port
)
def _setup_metadata_sockets(self, kv_rank: int, p_host: str, p_port: int,
                            d_host: str, d_port: int) -> None:
    """Set up ZeroMQ sockets for sending and receiving data.

    Args:
        kv_rank: Selects which side this engine plays. With ``0`` the
            sender socket and receiver-ack socket bind on ``p_host`` and
            the receiver socket and sender-ack socket connect to
            ``d_host``; any other value mirrors that arrangement.
        p_host: Host used for the ``p_*`` endpoints.
        p_port: Integer base port for the ``p_*`` endpoints.
        d_host: Host used for the ``d_*`` endpoints.
        d_port: Integer base port for the ``d_*`` endpoints.
    """
    # Offsets < 8 are left for initialization in case tp and pp are enabled
    p_rank_offset = p_port + 8 + self.local_rank * 2
    d_rank_offset = d_port + 8 + self.local_rank * 2

    # Endpoints are built with make_zmq_path (not hand-written f-strings)
    # so that IPv6 hosts are bracketed correctly. Each socket is bound or
    # connected exactly once.
    if kv_rank == 0:
        self.sender_socket.bind(
            make_zmq_path("tcp", p_host, p_rank_offset + 1))
        self.receiver_socket.connect(
            make_zmq_path("tcp", d_host, d_rank_offset + 1))
        self.sender_ack.connect(
            make_zmq_path("tcp", d_host, d_rank_offset + 2))
        self.receiver_ack.bind(
            make_zmq_path("tcp", p_host, p_rank_offset + 2))
    else:
        self.receiver_socket.connect(
            make_zmq_path("tcp", p_host, p_rank_offset + 1))
        self.sender_socket.bind(
            make_zmq_path("tcp", d_host, d_rank_offset + 1))
        self.receiver_ack.bind(
            make_zmq_path("tcp", d_host, d_rank_offset + 2))
        self.sender_ack.connect(
            make_zmq_path("tcp", p_host, p_rank_offset + 2))
def
initialize
(
self
,
local_hostname
:
str
,
metadata_server
:
str
,
protocol
:
str
,
device_name
:
str
,
...
...
vllm/utils/__init__.py
View file @
1dba2c4e
...
...
@@ -46,7 +46,7 @@ from dataclasses import dataclass, field
from
functools
import
cache
,
lru_cache
,
partial
,
wraps
from
types
import
MappingProxyType
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Generic
,
Literal
,
NamedTuple
,
Optional
,
TypeVar
,
Union
,
cast
,
overload
)
Optional
,
Tuple
,
TypeVar
,
Union
,
cast
,
overload
)
from
urllib.parse
import
urlparse
from
uuid
import
uuid4
...
...
@@ -628,14 +628,34 @@ def is_valid_ipv6_address(address: str) -> bool:
return
False
def split_host_port(host_port: str) -> Tuple[str, int]:
    """Split a ``host:port`` string into a ``(host, port)`` pair.

    IPv6 hosts must be written in brackets (``"[::1]:5555"``); the brackets
    are stripped from the returned host. Any other input is treated as
    ``host:port`` containing exactly one colon.

    Args:
        host_port: The combined address, e.g. ``"127.0.0.1:5555"`` or
            ``"[::1]:5555"``.

    Returns:
        The host string and the port converted to ``int``.

    Raises:
        ValueError: If the input does not have the expected shape (wrong
            number of colons, missing ``]``, stray characters between ``]``
            and ``:``) or the port is not an integer.
        IndexError: If a bracketed host is not followed by any ``:``.
            This exception type is kept for backward compatibility.
    """
    if host_port.startswith('['):
        # ipv6: "[host]:port"
        host, remainder = host_port.rsplit(']', 1)
        host = host[1:]
        stray, colon, port = remainder.partition(':')
        if not colon:
            raise IndexError(
                f"missing ':' after bracketed host in {host_port!r}")
        if stray:
            # Text between ']' and ':' used to be silently discarded.
            raise ValueError(
                f"unexpected characters after ']' in {host_port!r}")
        # Everything after the first ':' must be the port, so trailing
        # segments such as ":5555:80" are rejected by int().
        return host, int(port)

    host, port = host_port.split(':')
    return host, int(port)
def join_host_port(host: str, port: int) -> str:
    """Join ``host`` and ``port`` as ``host:port``.

    A host that ``is_valid_ipv6_address`` accepts is wrapped in square
    brackets; any other host is used as given.
    """
    rendered_host = f"[{host}]" if is_valid_ipv6_address(host) else host
    return f"{rendered_host}:{port}"
def get_distributed_init_method(ip: str, port: int) -> str:
    """Return the URI that ``get_tcp_uri`` builds for ``ip`` and ``port``."""
    uri = get_tcp_uri(ip, port)
    return uri
def get_tcp_uri(ip: str, port: int) -> str:
    """Build a ``tcp://`` URI from ``ip`` and ``port``.

    A host that ``is_valid_ipv6_address`` accepts is wrapped in square
    brackets; any other host is used as given.
    """
    # Brackets are not permitted in ipv4 addresses,
    # see https://github.com/python/cpython/issues/103848
    if is_valid_ipv6_address(ip):
        return f"tcp://[{ip}]:{port}"
    else:
        return f"tcp://{ip}:{port}"
def
get_open_zmq_ipc_path
()
->
str
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment