sglang / Commits

Commit 09593e9b (Unverified)
Multi-node Tensor Parallelism (#550)

Authored Jun 17, 2024 by Ying Sheng; committed by GitHub on Jun 17, 2024.
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
Parent: 53a7ebd8

Showing 10 changed files with 167 additions and 46 deletions (+167, -46).
benchmark/latency_throughput/README.md                    (+2, -2)
benchmark/latency_throughput/test_latency.py              (+1, -1)
python/sglang/launch_server.py                            (+2, -1)
python/sglang/srt/managers/controller/manager_single.py   (+2, -1)
python/sglang/srt/managers/controller/model_runner.py     (+7, -3)
python/sglang/srt/managers/controller/tp_worker.py        (+14, -8)
python/sglang/srt/model_config.py                         (+5, -1)
python/sglang/srt/server.py                               (+25, -4)
python/sglang/srt/server_args.py                          (+24, -0)
python/sglang/srt/utils.py                                (+85, -25)
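Taken together, the changes below let a single tensor-parallel group span multiple machines. As a hypothetical illustration of how the new flags fit together (--nnodes, --node-rank, and --nccl-init-addr are added in server_args.py below; the model path, addresses, and the --tp-size spelling are assumed placeholders, not taken from this commit):

# Node 0 (rank 0 gathers worker addresses, then runs the controller):
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf \
    --tp-size 4 --nnodes 2 --node-rank 0 --nccl-init-addr 192.168.0.10:50000

# Node 1 (non-zero ranks send their addresses, then serve rpyc workers):
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf \
    --tp-size 4 --nnodes 2 --node-rank 1 --nccl-init-addr 192.168.0.10:50000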
benchmark/latency_throughput/README.md

@@ -20,7 +20,7 @@ python3 bench_throughput.py --backend srt --tokenizer meta-llama/Llama-2-7b-chat
 ```
 # run synthetic
-python3 synthetic_benchmark.py --backend srt --tokenizer meta-llama/Llama-2-7b-chat-hf --num-prompt 1000 --request-rate 100 --input-len 1024 --output-len 256 --port 30000
+python3 bench_throughput.py --backend srt --tokenizer meta-llama/Llama-2-7b-chat-hf --num-prompt 1000 --request-rate 100 --input-len 1024 --output-len 256 --port 30000
 ```
@@ -36,7 +36,7 @@ python3 bench_throughput.py --backend vllm --tokenizer meta-llama/Llama-2-7b-cha
 ```
 # run synthetic
-python3 synthetic_benchmark.py --backend vllm --tokenizer meta-llama/Llama-2-7b-chat-hf --num-prompt 1000 --request-rate 100 --input-len 1024 --output-len 256 --port 30000
+python3 bench_throughput.py --backend vllm --tokenizer meta-llama/Llama-2-7b-chat-hf --num-prompt 1000 --request-rate 100 --input-len 1024 --output-len 256 --port 30000
 ```
benchmark/latency_throughput/test_latency.py

@@ -24,7 +24,7 @@ if __name__ == "__main__":
         raise ValueError(f"Invalid backend: {args.backend}")
     url = f"{args.host}:{args.port}"

-    a = random.randint(0, 1 << 20)
+    a = 20
     max_new_tokens = 256
     prompt = f"{a,}"
python/sglang/launch_server.py

@@ -2,7 +2,8 @@
 import argparse

-from sglang.srt.server import ServerArgs, launch_server
+from sglang.srt.server import launch_server
+from sglang.srt.server_args import ServerArgs

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
python/sglang/srt/managers/controller/manager_single.py

@@ -76,8 +76,9 @@ def start_controller_process(
     )

     try:
+        tp_size_local = server_args.tp_size // server_args.nnodes
         model_client = ModelTpClient(
-            list(range(server_args.tp_size)),
+            [i for _ in range(server_args.nnodes) for i in range(tp_size_local)],
             server_args,
             port_args.model_port_args[0],
             model_overide_args,
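The new first argument enumerates per-node GPU indices rather than global ranks. A quick illustration of the comprehension with hypothetical sizes:

# With tp_size = 4 and nnodes = 2, tp_size_local = 2 and the GPU id list
# repeats the local range once per node:
nnodes, tp_size_local = 2, 2
gpu_ids = [i for _ in range(nnodes) for i in range(tp_size_local)]
print(gpu_ids)  # [0, 1, 0, 1] -- each node uses its local GPUs 0..tp_size_local-1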
python/sglang/srt/managers/controller/model_runner.py

@@ -246,12 +246,16 @@ class ModelRunner:
         torch.cuda.set_device(self.gpu_id)
         logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
         monkey_patch_vllm_p2p_access_check(self.gpu_id)
+        if server_args.nccl_init_addr:
+            nccl_init_method = f"tcp://{server_args.nccl_init_addr}"
+        else:
+            nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
         init_distributed_environment(
             backend="nccl",
             world_size=self.tp_size,
             rank=self.tp_rank,
             local_rank=self.gpu_id,
-            distributed_init_method=f"tcp://127.0.0.1:{self.nccl_port}",
+            distributed_init_method=nccl_init_method,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
         total_gpu_memory = get_available_gpu_memory(
@@ -311,7 +315,7 @@ class ModelRunner:
             self.gpu_id, distributed=self.tp_size > 1
         )
         head_dim = self.model_config.head_dim
-        head_num = self.model_config.num_key_value_heads // self.tp_size
+        head_num = self.model_config.get_num_kv_heads(self.tp_size)
         cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
@@ -324,7 +328,7 @@ class ModelRunner:
         if self.max_total_num_tokens <= 0:
             raise RuntimeError(
-                "Not enought memory. Please try to increase --mem-fraction-static."
+                "Not enough memory. Please try to increase --mem-fraction-static."
             )
         self.req_to_token_pool = ReqToTokenPool(
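The switch from a bare integer division to get_num_kv_heads(tp_size) matters for grouped-query models, where the total KV-head count can be smaller than the TP degree. A minimal sketch of the likely semantics (assumed behavior, not copied from the commit; see model_config.py below for the per-architecture lookup of the total count):

def get_num_kv_heads_sketch(total_kv_heads: int, tp_size: int) -> int:
    # Assumed: each TP rank gets an equal share of the KV heads, but never
    # fewer than one (heads are replicated when tp_size exceeds the total).
    return max(1, total_kv_heads // tp_size)

assert get_num_kv_heads_sketch(8, 4) == 2   # plain division case
assert get_num_kv_heads_sketch(2, 8) == 1   # replication case: the old `2 // 8` gave 0,
                                            # which would have zeroed out cell_size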
python/sglang/srt/managers/controller/tp_worker.py

@@ -37,7 +37,8 @@ from sglang.srt.utils import (
     get_int_token_logit_bias,
     is_multimodal_model,
     set_random_seed,
-    start_rpyc_process,
+    start_rpyc_service_process,
+    connect_rpyc_service,
     suppress_other_loggers,
 )
 from sglang.utils import get_exception_traceback
@@ -770,12 +771,17 @@ class ModelTpClient:
         else:
             with ThreadPoolExecutor(self.tp_size) as executor:
                 # Launch model processes
-                rets = executor.map(
-                    lambda args: start_rpyc_process(*args),
-                    [(ModelTpService, p) for p in model_port_args.model_tp_ports],
-                )
-                self.model_services = [x[0] for x in rets]
-                self.procs = [x[1] for x in rets]
+                if server_args.nnodes == 1:
+                    self.procs = list(executor.map(
+                        lambda args: start_rpyc_service_process(*args),
+                        [(ModelTpService, p) for p in model_port_args.model_tp_ports],
+                    ))
+                    addrs = [("localhost", p) for p in model_port_args.model_tp_ports]
+                else:
+                    addrs = [(ip, port) for ip, port in zip(model_port_args.model_tp_ips, model_port_args.model_tp_ports)]
+                self.model_services = list(executor.map(
+                    lambda args: connect_rpyc_service(*args), addrs))

                 # Init model
                 def init_model(i):
@@ -787,7 +793,7 @@ class ModelTpClient:
                         model_overide_args,
                     )

-                self.model_servers = executor.map(init_model, range(self.tp_size))
+                self.model_servers = list(executor.map(init_model, range(self.tp_size)))

             # Wrap functions
             def async_wrap(func_name):
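The refactor separates "start a service process" from "connect to it", so remote nodes can host the services while only rank 0 dials out. A minimal self-contained sketch of that pattern with rpyc (EchoService, the host, and the port are hypothetical; the real code uses ModelTpService and the helpers in utils.py below):

import multiprocessing
import time

import rpyc
from rpyc.utils.server import ThreadedServer

class EchoService(rpyc.Service):              # stand-in for ModelTpService
    def exposed_ping(self):
        return "pong"

def serve(port):
    # Same shape as start_rpyc_service in utils.py: block, serving forever.
    ThreadedServer(service=EchoService, port=port).start()

if __name__ == "__main__":
    # Launch side: on multi-node setups this runs on every node.
    proc = multiprocessing.Process(target=serve, args=(18861,))
    proc.start()

    # Connect side: only node 0 dials the (host, port) pairs, local or remote.
    for _ in range(20):                       # retry until the server is up
        try:
            conn = rpyc.connect("localhost", 18861)
            break
        except ConnectionRefusedError:
            time.sleep(0.5)
    print(conn.root.ping())                   # -> "pong"
    conn.close()
    proc.terminate()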
python/sglang/srt/model_config.py

@@ -71,7 +71,11 @@ class ModelConfig:
             return 1

         # For DBRX and MPT
-        if self.hf_config.model_type in ["dbrx", "mpt"]:
+        if self.hf_config.model_type in ["mpt"]:
+            if "kv_n_heads" in self.hf_config.attn_config:
+                return self.hf_config.attn_config["kv_n_heads"]
+            return self.hf_config.num_attention_heads
+        if self.hf_config.model_type in ["dbrx"]:
             return getattr(
                 self.hf_config.attn_config,
                 "kv_n_heads",
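The split reflects that the code now treats MPT's attn_config as dict-like (membership test and subscripting) while DBRX's is attribute-based (getattr). Illustrative only, with made-up values:

mpt_like_attn_config = {"kv_n_heads": 8}          # MPT: dict-like, hypothetical value
kv = mpt_like_attn_config["kv_n_heads"] if "kv_n_heads" in mpt_like_attn_config else None

class DbrxLikeAttnConfig:                         # DBRX: attribute-based, hypothetical
    kv_n_heads = 8

kv = getattr(DbrxLikeAttnConfig, "kv_n_heads", None)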
python/sglang/srt/server.py

@@ -35,6 +35,7 @@ from sglang.srt.managers.controller.manager_multi import (
 from sglang.srt.managers.controller.manager_single import (
     start_controller_process as start_controller_process_single,
 )
+from sglang.srt.managers.controller.tp_worker import ModelTpService
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
 from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -50,9 +51,13 @@ from sglang.srt.utils import (
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
+    send_addrs_to_rank_0,
+    receive_addrs,
+    start_rpyc_service_process,
 )
 from sglang.utils import get_exception_traceback

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -151,21 +156,23 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
         load_chat_template_for_openai_api(server_args.chat_template)

     # Allocate ports
+    assert server_args.tp_size % server_args.nnodes == 0
+    tp_size_local = server_args.tp_size // server_args.nnodes
     server_args.port, server_args.additional_ports = allocate_init_ports(
         server_args.port,
         server_args.additional_ports,
-        server_args.tp_size,
+        tp_size_local,
         server_args.dp_size,
     )

     ports = server_args.additional_ports
-    tp = server_args.tp_size
     model_port_args = []
     for i in range(server_args.dp_size):
         model_port_args.append(
             ModelPortArgs(
-                nccl_port=ports[3 + i * (tp + 1)],
-                model_tp_ports=ports[3 + i * (tp + 1) + 1 : 3 + (i + 1) * (tp + 1)],
+                nccl_port=ports[3 + i * (tp_size_local + 1)],
+                model_tp_ips=[None] * tp_size_local,
+                model_tp_ports=ports[3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)],
             )
         )
     port_args = PortArgs(
@@ -175,6 +182,20 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
         model_port_args=model_port_args,
     )

+    # TODO multi-node dp is not supported
+    assert not (server_args.dp_size > 1 and server_args.node_rank is not None)
+    if server_args.nnodes > 1:
+        if server_args.node_rank != 0:
+            send_addrs_to_rank_0(model_port_args[0], server_args)
+        else:
+            receive_addrs(model_port_args[0], server_args)
+        for i in range(tp_size_local):
+            start_rpyc_service_process(ModelTpService, model_port_args[0].model_tp_ports[i])
+        if server_args.node_rank != 0:
+            print("Listen for connections...")
+            while True:
+                pass
+
     # Launch processes
     tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
     pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
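The indexing above packs, per data-parallel replica, one NCCL port followed by tp_size_local worker ports after three reserved slots. A worked example of the layout with hypothetical port numbers:

ports = list(range(30001, 30010))        # hypothetical additional_ports
tp_size_local, dp_size = 2, 2

for i in range(dp_size):
    base = 3 + i * (tp_size_local + 1)
    nccl_port = ports[base]
    tp_ports = ports[base + 1 : 3 + (i + 1) * (tp_size_local + 1)]
    print(i, nccl_port, tp_ports)
# 0 30004 [30005, 30006]
# 1 30007 [30008, 30009]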
python/sglang/srt/server_args.py

@@ -56,6 +56,11 @@ class ServerArgs:
     disable_regex_jump_forward: bool = False
     disable_disk_cache: bool = False

+    # Distributed args
+    nccl_init_addr: Optional[str] = None
+    nnodes: int = 1
+    node_rank: Optional[int] = None
+
     def __post_init__(self):
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
@@ -252,6 +257,24 @@ class ServerArgs:
             ],
         )

+        # Multi-node distributed serving args
+        parser.add_argument(
+            "--nccl-init-addr",
+            type=str,
+            help="The nccl init address of multi-node server.",
+        )
+        parser.add_argument(
+            "--nnodes",
+            type=int,
+            default=1,
+            help="Number of nodes",
+        )
+        parser.add_argument(
+            "--node-rank",
+            type=int,
+            help="The node rank.",
+        )
+
         # Optimization/debug options
         parser.add_argument(
             "--enable-flashinfer",
@@ -300,6 +323,7 @@ class ServerArgs:
 @dataclasses.dataclass
 class ModelPortArgs:
     nccl_port: int
+    model_tp_ips: List[str]
     model_tp_ports: List[int]
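For completeness, here is how the new fields would look on a programmatically constructed ServerArgs (model path and address are hypothetical; ServerArgs is assumed to accept these as dataclass keyword arguments):

from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="meta-llama/Llama-2-7b-chat-hf",   # hypothetical model
    tp_size=4,
    nnodes=2,
    node_rank=0,
    nccl_init_addr="192.168.0.10:50000",          # hypothetical rank-0 address
)
assert args.tp_size % args.nnodes == 0            # checked in launch_server()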
python/sglang/srt/utils.py

 """Common utilities."""

 import base64
+import fcntl
 import logging
 import multiprocessing
 import os
 import random
 import socket
+import struct
 import time
 from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
@@ -369,23 +371,7 @@ def load_image(image_file):
     return image, image_size


-def init_rpyc_service(service: rpyc.Service, port: int):
-    t = ThreadedServer(
-        service=service,
-        port=port,
-        protocol_config={
-            "allow_public_attrs": True,
-            "allow_pickle": True,
-            "sync_request_timeout": 3600,
-        },
-    )
-    t.logger.setLevel(logging.WARN)
-    t.start()
-
-
-def connect_to_rpyc_service(port, host="localhost"):
-    time.sleep(1)
+def connect_rpyc_service(host, port):
     repeat_count = 0
     while repeat_count < 20:
         try:
@@ -399,22 +385,33 @@ def connect_to_rpyc_service(port, host="localhost"):
                 },
             )
             break
-        except ConnectionRefusedError:
+        except ConnectionRefusedError as e:
             time.sleep(1)
         repeat_count += 1
     if repeat_count == 20:
-        raise RuntimeError("init rpc env error!")
+        raise RuntimeError(f"Connect rpyc error: {e}")

     return con.root


-def start_rpyc_process(service: rpyc.Service, port: int):
-    # Return the proxy and the process
-    proc = multiprocessing.Process(target=init_rpyc_service, args=(service, port))
+def start_rpyc_service(service: rpyc.Service, port: int):
+    t = ThreadedServer(
+        service=service,
+        port=port,
+        protocol_config={
+            "allow_public_attrs": True,
+            "allow_pickle": True,
+            "sync_request_timeout": 3600,
+        },
+    )
+    t.logger.setLevel(logging.WARN)
+    t.start()
+
+
+def start_rpyc_service_process(service: rpyc.Service, port: int):
+    proc = multiprocessing.Process(target=start_rpyc_service, args=(service, port))
     proc.start()
-    proxy = connect_to_rpyc_service(port)
-    return proxy, proc
+    assert proc.is_alive()
+    return proc


 def suppress_other_loggers():
@@ -487,3 +484,66 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
         )
         response = await call_next(request)
         return response
+
+
+def get_ip_address(ifname):
+    """
+    Get the IP address of a network interface.
+
+    :param ifname: Name of the network interface (e.g., 'eth0')
+    :return: IP address of the network interface
+    """
+    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+    ip_address = fcntl.ioctl(
+        s.fileno(),
+        0x8915,  # SIOCGIFADDR
+        struct.pack("256s", bytes(ifname[:15], "utf-8")),
+    )[20:24]
+    return socket.inet_ntoa(ip_address)
+
+
+def send_addrs_to_rank_0(model_port_args, server_args):
+    assert server_args.node_rank != 0 and server_args.dp_size == 1
+    import torch.distributed as dist
+
+    ifname = os.environ.get(
+        "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
+    )
+    ip_addr = get_ip_address(ifname)
+
+    num_tp_ports = server_args.tp_size // server_args.nnodes
+    model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
+    ip_addr = [int(x) for x in ip_addr.split(".")]
+    addrs_tensor = torch.tensor(ip_addr + model_port_args.model_tp_ports, dtype=torch.int)
+
+    init_method = f"tcp://{server_args.nccl_init_addr}"
+    dist.init_process_group(
+        backend="gloo",
+        init_method=init_method,
+        rank=server_args.node_rank,
+        world_size=server_args.nnodes,
+    )
+    dist.send(addrs_tensor, dst=0)
+    print(f"Node {server_args.node_rank} sent: ip_address {ip_addr} and ports {model_port_args.model_tp_ports}")
+
+    dist.barrier()
+    dist.destroy_process_group()
+
+
+def receive_addrs(model_port_args, server_args):
+    assert server_args.node_rank == 0 and server_args.dp_size == 1
+    import torch.distributed as dist
+
+    ifname = os.environ.get(
+        "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
+    )
+    ip_addr = get_ip_address(ifname)
+
+    num_tp_ports = server_args.tp_size // server_args.nnodes
+    model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
+
+    init_method = f"tcp://{server_args.nccl_init_addr}"
+    dist.init_process_group(
+        backend="gloo",
+        init_method=init_method,
+        rank=server_args.node_rank,
+        world_size=server_args.nnodes,
+    )
+    for src_rank in range(1, server_args.nnodes):
+        tensor = torch.zeros(4 + num_tp_ports, dtype=torch.int)
+        dist.recv(tensor, src=src_rank)
+        ip = ".".join([str(x) for x in tensor[:4].tolist()])
+        ports = tensor[4:].tolist()
+        model_port_args.model_tp_ips[num_tp_ports * src_rank : num_tp_ports * (src_rank + 1)] = [ip] * num_tp_ports
+        model_port_args.model_tp_ports[num_tp_ports * src_rank : num_tp_ports * (src_rank + 1)] = ports
+        print(f"Node 0 received from rank {src_rank}: {tensor.tolist()}")
+
+    dist.barrier()
+    dist.destroy_process_group()
\ No newline at end of file
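The gloo exchange above serializes each worker node's IPv4 address as four int octets followed by its TP ports, all in a single int tensor. A round-trip sketch of that encoding (values hypothetical):

import torch

ip, tp_ports = "10.0.0.2", [30011, 30012]         # hypothetical worker address
packed = torch.tensor([int(x) for x in ip.split(".")] + tp_ports, dtype=torch.int)

decoded_ip = ".".join(str(x) for x in packed[:4].tolist())   # "10.0.0.2"
decoded_ports = packed[4:].tolist()                          # [30011, 30012]
assert decoded_ip == ip and decoded_ports == tp_ports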