sglang / Commits / 09593e9b

Commit 09593e9b (Unverified)
Authored Jun 17, 2024 by Ying Sheng; committed via GitHub on Jun 17, 2024

Multi-node Tensor Parallelism (#550)

Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>

Parent: 53a7ebd8
Showing 10 changed files with 167 additions and 46 deletions.
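Taken together, the changes let one tensor-parallel group span several machines: `tp_size` must divide evenly by `nnodes`, each node hosts `tp_size // nnodes` model workers, and the new `--nccl-init-addr`, `--nnodes`, and `--node-rank` arguments coordinate the nodes. A minimal sketch of a two-node launch using the new fields directly (the model path, address, and the `None` trailing arguments are assumptions, not taken from this page):

```python
# Hypothetical two-node launch: run with node_rank=0 on the first
# machine and node_rank=1 on the second. 8 GPUs total, 4 per node.
from sglang.srt.server import launch_server
from sglang.srt.server_args import ServerArgs

server_args = ServerArgs(
    model_path="meta-llama/Llama-2-7b-chat-hf",  # hypothetical model
    tp_size=8,                        # total TP degree across all nodes
    nnodes=2,                         # new in this commit
    node_rank=0,                      # 0 here; 1 on the other machine
    nccl_init_addr="10.0.0.1:50000",  # hypothetical rank-0 address
)

if __name__ == "__main__":
    # pipe_finish_writer / model override arguments assumed optional here
    launch_server(server_args, None, None)
```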
Changed files:

- benchmark/latency_throughput/README.md (+2, -2)
- benchmark/latency_throughput/test_latency.py (+1, -1)
- python/sglang/launch_server.py (+2, -1)
- python/sglang/srt/managers/controller/manager_single.py (+2, -1)
- python/sglang/srt/managers/controller/model_runner.py (+7, -3)
- python/sglang/srt/managers/controller/tp_worker.py (+14, -8)
- python/sglang/srt/model_config.py (+5, -1)
- python/sglang/srt/server.py (+25, -4)
- python/sglang/srt/server_args.py (+24, -0)
- python/sglang/srt/utils.py (+85, -25)
benchmark/latency_throughput/README.md

@@ -20,7 +20,7 @@ python3 bench_throughput.py --backend srt --tokenizer meta-llama/Llama-2-7b-chat
 ```
 # run synthetic
-python3 synthetic_benchmark.py --backend srt --tokenizer meta-llama/Llama-2-7b-chat-hf --num-prompt 1000 --request-rate 100 --input-len 1024 --output-len 256 --port 30000
+python3 bench_throughput.py --backend srt --tokenizer meta-llama/Llama-2-7b-chat-hf --num-prompt 1000 --request-rate 100 --input-len 1024 --output-len 256 --port 30000
 ```
@@ -36,7 +36,7 @@ python3 bench_throughput.py --backend vllm --tokenizer meta-llama/Llama-2-7b-cha
 ```
 # run synthetic
-python3 synthetic_benchmark.py --backend vllm --tokenizer meta-llama/Llama-2-7b-chat-hf --num-prompt 1000 --request-rate 100 --input-len 1024 --output-len 256 --port 30000
+python3 bench_throughput.py --backend vllm --tokenizer meta-llama/Llama-2-7b-chat-hf --num-prompt 1000 --request-rate 100 --input-len 1024 --output-len 256 --port 30000
 ```
benchmark/latency_throughput/test_latency.py

@@ -24,7 +24,7 @@ if __name__ == "__main__":
         raise ValueError(f"Invalid backend: {args.backend}")
     url = f"{args.host}:{args.port}"

-    a = random.randint(0, 1 << 20)
+    a = 20
     max_new_tokens = 256
     prompt = f"{a,}"
python/sglang/launch_server.py

@@ -2,7 +2,8 @@
 import argparse

-from sglang.srt.server import ServerArgs, launch_server
+from sglang.srt.server import launch_server
+from sglang.srt.server_args import ServerArgs

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
python/sglang/srt/managers/controller/manager_single.py

@@ -76,8 +76,9 @@ def start_controller_process(
     )

     try:
+        tp_size_local = server_args.tp_size // server_args.nnodes
         model_client = ModelTpClient(
-            list(range(server_args.tp_size)),
+            [i for _ in range(server_args.nnodes) for i in range(tp_size_local)],
            server_args,
            port_args.model_port_args[0],
            model_overide_args,
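The comprehension replaces `list(range(server_args.tp_size))` so that device ids restart from zero on every node while the flat TP ranks still cover the whole group. A quick sketch with hypothetical sizes:

```python
# tp_size=8 split over nnodes=2 -> 4 GPUs per node
tp_size, nnodes = 8, 2
tp_size_local = tp_size // nnodes
gpu_ids = [i for _ in range(nnodes) for i in range(tp_size_local)]
print(gpu_ids)  # [0, 1, 2, 3, 0, 1, 2, 3]
# TP ranks 0..7 map to local CUDA devices 0..3 on each of the two nodes.
```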
python/sglang/srt/managers/controller/model_runner.py

@@ -246,12 +246,16 @@ class ModelRunner:
         torch.cuda.set_device(self.gpu_id)
         logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")

         monkey_patch_vllm_p2p_access_check(self.gpu_id)
+        if server_args.nccl_init_addr:
+            nccl_init_method = f"tcp://{server_args.nccl_init_addr}"
+        else:
+            nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
         init_distributed_environment(
             backend="nccl",
             world_size=self.tp_size,
             rank=self.tp_rank,
             local_rank=self.gpu_id,
-            distributed_init_method=f"tcp://127.0.0.1:{self.nccl_port}",
+            distributed_init_method=nccl_init_method,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
         total_gpu_memory = get_available_gpu_memory(
@@ -311,7 +315,7 @@ class ModelRunner:
             self.gpu_id, distributed=self.tp_size > 1
         )
         head_dim = self.model_config.head_dim
-        head_num = self.model_config.num_key_value_heads // self.tp_size
+        head_num = self.model_config.get_num_kv_heads(self.tp_size)
         cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
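For context, `cell_size` here is the KV-cache footprint of a single token on one TP rank: KV heads per rank × head dim × layers × 2 tensors (K and V) × 2 bytes for fp16. A worked example with hypothetical 7B-class numbers:

```python
# Hypothetical config: 32 KV heads of dim 128, 32 layers, tp_size=4.
head_num = 32 // 4   # what get_num_kv_heads(tp_size) would return here
head_dim = 128
num_hidden_layers = 32
cell_size = head_num * head_dim * num_hidden_layers * 2 * 2
print(cell_size)  # 131072 bytes, i.e. 128 KiB of KV cache per token per GPU
```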
@@ -324,7 +328,7 @@ class ModelRunner:
         if self.max_total_num_tokens <= 0:
             raise RuntimeError(
-                "Not enought memory. Please try to increase --mem-fraction-static."
+                "Not enough memory. Please try to increase --mem-fraction-static."
             )
         self.req_to_token_pool = ReqToTokenPool(
python/sglang/srt/managers/controller/tp_worker.py

@@ -37,7 +37,8 @@ from sglang.srt.utils import (
     get_int_token_logit_bias,
     is_multimodal_model,
     set_random_seed,
-    start_rpyc_process,
+    start_rpyc_service_process,
+    connect_rpyc_service,
     suppress_other_loggers,
 )
 from sglang.utils import get_exception_traceback
@@ -770,12 +771,17 @@ class ModelTpClient:
         else:
             with ThreadPoolExecutor(self.tp_size) as executor:
                 # Launch model processes
-                rets = executor.map(
-                    lambda args: start_rpyc_process(*args),
-                    [(ModelTpService, p) for p in model_port_args.model_tp_ports],
-                )
-                self.model_services = [x[0] for x in rets]
-                self.procs = [x[1] for x in rets]
+                if server_args.nnodes == 1:
+                    self.procs = list(executor.map(
+                        lambda args: start_rpyc_service_process(*args),
+                        [(ModelTpService, p) for p in model_port_args.model_tp_ports],
+                    ))
+                    addrs = [("localhost", p) for p in model_port_args.model_tp_ports]
+                else:
+                    addrs = [(ip, port) for ip, port in zip(
+                        model_port_args.model_tp_ips, model_port_args.model_tp_ports)]
+                self.model_services = list(executor.map(
+                    lambda args: connect_rpyc_service(*args), addrs))

                 # Init model
                 def init_model(i):
@@ -787,7 +793,7 @@ class ModelTpClient:
                         model_overide_args,
                     )

-                self.model_servers = executor.map(init_model, range(self.tp_size))
+                self.model_servers = list(executor.map(init_model, range(self.tp_size)))

             # Wrap functions
             def async_wrap(func_name):
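The launch/connect split is the crux of this file's change: on a single node, the client both spawns the rpyc services and dials them on localhost, while in the multi-node case the services are started by `launch_server` on each machine and the client only connects to the gathered `(ip, port)` pairs. A sketch of the two address lists (values hypothetical):

```python
# Hypothetical ports for tp_size_local=2 workers on each of two nodes.
model_tp_ports = [50011, 50012, 50011, 50012]

# nnodes == 1: services were just spawned locally.
addrs = [("localhost", p) for p in model_tp_ports]

# nnodes > 1: IPs were gathered from the other nodes by receive_addrs.
model_tp_ips = ["10.0.0.1", "10.0.0.1", "10.0.0.2", "10.0.0.2"]
addrs = [(ip, port) for ip, port in zip(model_tp_ips, model_tp_ports)]
# -> [("10.0.0.1", 50011), ("10.0.0.1", 50012),
#     ("10.0.0.2", 50011), ("10.0.0.2", 50012)]
```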
python/sglang/srt/model_config.py

@@ -71,7 +71,11 @@ class ModelConfig:
             return 1

         # For DBRX and MPT
-        if self.hf_config.model_type in ["dbrx", "mpt"]:
+        if self.hf_config.model_type in ["mpt"]:
             if "kv_n_heads" in self.hf_config.attn_config:
                 return self.hf_config.attn_config["kv_n_heads"]
             return self.hf_config.num_attention_heads
+        if self.hf_config.model_type in ["dbrx"]:
+            return getattr(
+                self.hf_config.attn_config,
+                "kv_n_heads",
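The split into two branches appears to account for the two configs exposing `attn_config` differently: the MPT branch treats it as dict-like (hence the `in` test and subscripting), while the DBRX branch reads it with `getattr`, i.e. attribute access with a fallback. (The fallback value is cut off by the collapsed diff context above.)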
python/sglang/srt/server.py

@@ -35,6 +35,7 @@ from sglang.srt.managers.controller.manager_multi import (
 from sglang.srt.managers.controller.manager_single import (
     start_controller_process as start_controller_process_single,
 )
+from sglang.srt.managers.controller.tp_worker import ModelTpService
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
 from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -50,9 +51,13 @@ from sglang.srt.utils import (
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
+    send_addrs_to_rank_0,
+    receive_addrs,
+    start_rpyc_service_process,
 )
 from sglang.utils import get_exception_traceback

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -151,21 +156,23 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
         load_chat_template_for_openai_api(server_args.chat_template)

     # Allocate ports
+    assert server_args.tp_size % server_args.nnodes == 0
+    tp_size_local = server_args.tp_size // server_args.nnodes
     server_args.port, server_args.additional_ports = allocate_init_ports(
         server_args.port,
         server_args.additional_ports,
-        server_args.tp_size,
+        tp_size_local,
         server_args.dp_size,
     )

     ports = server_args.additional_ports
-    tp = server_args.tp_size
     model_port_args = []
     for i in range(server_args.dp_size):
         model_port_args.append(
             ModelPortArgs(
-                nccl_port=ports[3 + i * (tp + 1)],
-                model_tp_ports=ports[3 + i * (tp + 1) + 1 : 3 + (i + 1) * (tp + 1)],
+                nccl_port=ports[3 + i * (tp_size_local + 1)],
+                model_tp_ips=[None] * tp_size_local,
+                model_tp_ports=ports[3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)],
             )
         )
     port_args = PortArgs(
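The slicing packs `additional_ports` as: three ports reserved up front (their consumers sit outside this hunk), then, for each data-parallel replica, one NCCL port followed by `tp_size_local` rpyc worker ports. A sketch with hypothetical values:

```python
# Hypothetical: tp_size=8, nnodes=2 -> tp_size_local=4; dp_size=1.
ports = list(range(50000, 50008))  # 3 + 1 * (4 + 1) = 8 ports needed
tp_size_local, i = 4, 0
nccl_port = ports[3 + i * (tp_size_local + 1)]      # ports[3] -> 50003
model_tp_ports = ports[3 + i * (tp_size_local + 1) + 1
                       : 3 + (i + 1) * (tp_size_local + 1)]
print(nccl_port, model_tp_ports)  # 50003 [50004, 50005, 50006, 50007]
```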
@@ -175,6 +182,20 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
         model_port_args=model_port_args,
     )

+    # TODO multi-node dp is not supported
+    assert not (server_args.dp_size > 1 and server_args.node_rank is not None)
+    if server_args.nnodes > 1:
+        if server_args.node_rank != 0:
+            send_addrs_to_rank_0(model_port_args[0], server_args)
+        else:
+            receive_addrs(model_port_args[0], server_args)
+        for i in range(tp_size_local):
+            start_rpyc_service_process(ModelTpService, model_port_args[0].model_tp_ports[i])
+        if server_args.node_rank != 0:
+            print("Listen for connections...")
+            while True:
+                pass
+
     # Launch processes
     tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
     pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
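The multi-node startup order is therefore: every node allocates its local ports; non-zero ranks push their `(ip, ports)` to rank 0 via `send_addrs_to_rank_0` while rank 0 collects them in `receive_addrs`; each node then starts its `tp_size_local` rpyc model services. Non-zero ranks park in the idle loop, while rank 0 alone continues on to launch the tokenizer, controller, and detokenizer processes, and its controller's `ModelTpClient` connects back to every service in the cluster (see the tp_worker.py hunk above).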
python/sglang/srt/server_args.py

@@ -56,6 +56,11 @@ class ServerArgs:
     disable_regex_jump_forward: bool = False
     disable_disk_cache: bool = False

+    # Distributed args
+    nccl_init_addr: Optional[str] = None
+    nnodes: int = 1
+    node_rank: Optional[int] = None
+
     def __post_init__(self):
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
@@ -252,6 +257,24 @@ class ServerArgs:
             ],
         )

+        # Multi-node distributed serving args
+        parser.add_argument(
+            "--nccl-init-addr",
+            type=str,
+            help="The nccl init address of multi-node server.",
+        )
+        parser.add_argument(
+            "--nnodes", type=int, default=1, help="Number of nodes."
+        )
+        parser.add_argument(
+            "--node-rank", type=int, help="The node rank."
+        )
+
         # Optimization/debug options
         parser.add_argument(
             "--enable-flashinfer",
@@ -300,6 +323,7 @@ class ServerArgs:
 @dataclasses.dataclass
 class ModelPortArgs:
     nccl_port: int
+    model_tp_ips: List[str]
     model_tp_ports: List[int]
python/sglang/srt/utils.py

"""Common utilities."""

import base64
import fcntl
import logging
import multiprocessing
import os
import random
import socket
import struct
import time
from importlib.metadata import PackageNotFoundError, version
from io import BytesIO
@@ -369,23 +371,7 @@ def load_image(image_file):
     return image, image_size


-def init_rpyc_service(service: rpyc.Service, port: int):
-    t = ThreadedServer(
-        service=service,
-        port=port,
-        protocol_config={
-            "allow_public_attrs": True,
-            "allow_pickle": True,
-            "sync_request_timeout": 3600,
-        },
-    )
-    t.logger.setLevel(logging.WARN)
-    t.start()
-
-
-def connect_to_rpyc_service(port, host="localhost"):
-    time.sleep(1)
+def connect_rpyc_service(host, port):
     repeat_count = 0
     while repeat_count < 20:
         try:
@@ -399,22 +385,33 @@ def connect_to_rpyc_service(port, host="localhost"):
             },
         )
         break
-    except ConnectionRefusedError:
+    except ConnectionRefusedError as e:
         time.sleep(1)
     repeat_count += 1
     if repeat_count == 20:
-        raise RuntimeError("init rpc env error!")
+        raise RuntimeError(f"Connect rpyc error: {e}")

     return con.root


-def start_rpyc_process(service: rpyc.Service, port: int):
-    # Return the proxy and the process
-    proc = multiprocessing.Process(target=init_rpyc_service, args=(service, port))
+def start_rpyc_service(service: rpyc.Service, port: int):
+    t = ThreadedServer(
+        service=service,
+        port=port,
+        protocol_config={
+            "allow_public_attrs": True,
+            "allow_pickle": True,
+            "sync_request_timeout": 3600,
+        },
+    )
+    t.logger.setLevel(logging.WARN)
+    t.start()
+
+
+def start_rpyc_service_process(service: rpyc.Service, port: int):
+    proc = multiprocessing.Process(target=start_rpyc_service, args=(service, port))
     proc.start()
-    proxy = connect_to_rpyc_service(port)
-    assert proc.is_alive()
-    return proxy, proc
+    assert proc.is_alive()
+    return proc


 def suppress_other_loggers():
@@ -487,3 +484,66 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
         )
         response = await call_next(request)
         return response
+
+
+def get_ip_address(ifname):
+    """
+    Get the IP address of a network interface.
+
+    :param ifname: Name of the network interface (e.g., 'eth0')
+    :return: IP address of the network interface
+    """
+    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+    ip_address = fcntl.ioctl(
+        s.fileno(),
+        0x8915,  # SIOCGIFADDR
+        struct.pack('256s', bytes(ifname[:15], 'utf-8'))
+    )[20:24]
+    return socket.inet_ntoa(ip_address)
+
+
+def send_addrs_to_rank_0(model_port_args, server_args):
+    assert server_args.node_rank != 0 and server_args.dp_size == 1
+    import torch.distributed as dist
+
+    ifname = os.environ.get("SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0"))
+    ip_addr = get_ip_address(ifname)
+
+    num_tp_ports = server_args.tp_size // server_args.nnodes
+    model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
+    ip_addr = [int(x) for x in ip_addr.split(".")]
+    addrs_tensor = torch.tensor(ip_addr + model_port_args.model_tp_ports, dtype=torch.int)
+
+    init_method = f"tcp://{server_args.nccl_init_addr}"
+    dist.init_process_group(backend="gloo", init_method=init_method, rank=server_args.node_rank, world_size=server_args.nnodes)
+    dist.send(addrs_tensor, dst=0)
+    print(f"Node {server_args.node_rank} sent: ip_address {ip_addr} and ports {model_port_args.model_tp_ports}")
+
+    dist.barrier()
+    dist.destroy_process_group()
+
+
+def receive_addrs(model_port_args, server_args):
+    assert server_args.node_rank == 0 and server_args.dp_size == 1
+    import torch.distributed as dist
+
+    ifname = os.environ.get("SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0"))
+    ip_addr = get_ip_address(ifname)
+
+    num_tp_ports = server_args.tp_size // server_args.nnodes
+    model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
+
+    init_method = f"tcp://{server_args.nccl_init_addr}"
+    dist.init_process_group(backend="gloo", init_method=init_method, rank=server_args.node_rank, world_size=server_args.nnodes)
+    for src_rank in range(1, server_args.nnodes):
+        tensor = torch.zeros(4 + num_tp_ports, dtype=torch.int)
+        dist.recv(tensor, src=src_rank)
+        ip = ".".join([str(x) for x in tensor[:4].tolist()])
+        ports = tensor[4:].tolist()
+        model_port_args.model_tp_ips[num_tp_ports * src_rank : num_tp_ports * (src_rank + 1)] = [ip] * num_tp_ports
+        model_port_args.model_tp_ports[num_tp_ports * src_rank : num_tp_ports * (src_rank + 1)] = ports
+        print(f"Node 0 received from rank {src_rank}: {tensor.tolist()}")
+
+    dist.barrier()
+    dist.destroy_process_group()
\ No newline at end of file
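Two notes on the helpers above. `get_ip_address` resolves an interface's IPv4 address via the `SIOCGIFADDR` ioctl (0x8915), which is Linux-specific; the interface name comes from `SGLANG_SOCKET_IFNAME`/`NCCL_SOCKET_IFNAME`, defaulting to `eth0`. And the rendezvous wire format is simply an int tensor of the four IPv4 octets followed by the node's TP ports, exchanged over a temporary gloo process group. A sketch of the encoding (values hypothetical):

```python
import torch

# What a non-zero rank sends for ip 10.0.0.2 with two local TP ports:
ip_addr = [int(x) for x in "10.0.0.2".split(".")]
model_tp_ports = [50011, 50012]
addrs_tensor = torch.tensor(ip_addr + model_tp_ports, dtype=torch.int)
print(addrs_tensor.tolist())  # [10, 0, 0, 2, 50011, 50012]

# Rank 0 decodes it the same way receive_addrs does:
ip = ".".join(str(x) for x in addrs_tensor[:4].tolist())   # "10.0.0.2"
ports = addrs_tensor[4:].tolist()                          # [50011, 50012]
```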