Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
11383cec
"vscode:/vscode.git/clone" did not exist on "12cb115d381cc19605c2fd3aa696ddf550f480de"
Unverified
Commit
11383cec
authored
Apr 30, 2025
by
Ying Sheng
Committed by
GitHub
Apr 30, 2025
Browse files
[PP] Add pipeline parallelism (#5724)
parent
e97e57e6
Changes
25
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
255 additions
and
4 deletions
+255
-4
python/sglang/srt/utils.py
python/sglang/srt/utils.py
+80
-4
python/sglang/test/test_utils.py
python/sglang/test/test_utils.py
+28
-0
test/srt/run_suite.py
test/srt/run_suite.py
+2
-0
test/srt/test_pp_single_node.py
test/srt/test_pp_single_node.py
+143
-0
test/srt/test_vlm_accuracy.py
test/srt/test_vlm_accuracy.py
+2
-0
No files found.
python/sglang/srt/utils.py
View file @
11383cec
...
...
@@ -12,6 +12,7 @@
# limitations under the License.
# ==============================================================================
"""Common utilities."""
import
base64
import
builtins
import
ctypes
...
...
@@ -414,16 +415,40 @@ class LayerFn(Protocol):
def make_layers(
    num_hidden_layers: int,
    layer_fn: LayerFn,
    pp_rank: Optional[int] = None,
    pp_size: Optional[int] = None,
    prefix: str = "",
    return_tuple: bool = False,
) -> "torch.nn.ModuleList | Tuple[torch.nn.ModuleList, int, int]":
    """Make a list of layers with the given layer function.

    When ``pp_rank``/``pp_size`` are both given, only the layers in
    ``[start_layer, end_layer)`` assigned to this pipeline stage are
    materialized; every other position is filled with a ``PPMissingLayer``
    placeholder so absolute layer indices stay stable across stages.

    Args:
        num_hidden_layers: Total number of decoder layers in the model.
        layer_fn: Factory called as ``layer_fn(idx=..., prefix=...)`` for each
            layer this stage owns.
        pp_rank: This stage's pipeline-parallel rank, or None to build all layers.
        pp_size: Pipeline-parallel world size, or None to build all layers.
        prefix: Weight-name prefix passed through to ``layer_fn``.
        return_tuple: Forwarded to ``PPMissingLayer`` placeholders.

    Returns:
        The ``ModuleList`` alone when pipeline parallelism is disabled
        (``pp_rank`` or ``pp_size`` is None); otherwise
        ``(modules, start_layer, end_layer)``.
    """
    # Imported here to avoid circular imports.
    from sglang.srt.distributed import get_pp_indices
    from sglang.srt.layers.utils import PPMissingLayer

    # Every pipeline stage must own at least one layer.
    assert not pp_size or num_hidden_layers >= pp_size
    start_layer, end_layer = (
        get_pp_indices(
            num_hidden_layers,
            pp_rank,
            pp_size,
        )
        if pp_rank is not None and pp_size is not None
        else (0, num_hidden_layers)
    )
    modules = torch.nn.ModuleList(
        # Placeholders for layers owned by earlier stages ...
        [PPMissingLayer(return_tuple=return_tuple) for _ in range(start_layer)]
        # ... real layers for this stage ...
        + [
            maybe_offload_to_cpu(layer_fn(idx=idx, prefix=add_prefix(idx, prefix)))
            for idx in range(start_layer, end_layer)
        ]
        # ... placeholders for layers owned by later stages.
        + [
            PPMissingLayer(return_tuple=return_tuple)
            for _ in range(end_layer, num_hidden_layers)
        ]
    )
    if pp_rank is None or pp_size is None:
        return modules
    return modules, start_layer, end_layer
def
set_random_seed
(
seed
:
int
)
->
None
:
...
...
@@ -877,7 +902,7 @@ def broadcast_pyobj(
"cuda"
if
torch
.
cuda
.
is_available
()
and
not
force_cpu_device
else
"cpu"
)
if
rank
==
0
:
if
rank
==
src
:
if
len
(
data
)
==
0
:
tensor_size
=
torch
.
tensor
([
0
],
dtype
=
torch
.
long
,
device
=
device
)
dist
.
broadcast
(
tensor_size
,
src
=
src
,
group
=
dist_group
)
...
...
@@ -909,6 +934,50 @@ def broadcast_pyobj(
return
data
def point_to_point_pyobj(
    data: List[Any],
    rank: int,
    group: Optional[torch.distributed.ProcessGroup] = None,
    src: int = 0,
    dst: int = 1,
) -> List[Any]:
    """Send data from src to dst in group.

    Pickle-based point-to-point transfer: the sender first transmits a
    one-element long tensor holding the payload size, then (if non-empty)
    the pickled bytes; the receiver mirrors that order exactly. Ranks that
    are neither ``src`` nor ``dst`` return an empty list.

    NOTE(review): uses pickle, so both endpoints must be trusted peers of
    the same job — never feed this untrusted bytes.
    """
    if rank == src:
        if len(data) == 0:
            # Empty payload: send size 0 only; no data message follows.
            tensor_size = torch.tensor([0], dtype=torch.long)
            dist.send(tensor_size, dst=dst, group=group)
        else:
            serialized_data = pickle.dumps(data)
            size = len(serialized_data)
            tensor_data = torch.ByteTensor(
                np.frombuffer(serialized_data, dtype=np.uint8)
            )
            tensor_size = torch.tensor([size], dtype=torch.long)

            # Size first, then payload — must match the recv order below.
            dist.send(tensor_size, dst=dst, group=group)
            dist.send(tensor_data, dst=dst, group=group)
        # Sender returns its own data unchanged.
        return data
    elif rank == dst:
        # Receive the size header first.
        tensor_size = torch.tensor([0], dtype=torch.long)
        dist.recv(tensor_size, src=src, group=group)
        size = tensor_size.item()
        if size == 0:
            # Sender had no data; no payload message was sent.
            return []
        tensor_data = torch.empty(size, dtype=torch.uint8)
        dist.recv(tensor_data, src=src, group=group)
        serialized_data = bytes(tensor_data.cpu().numpy())
        data = pickle.loads(serialized_data)
        return data

    # Other ranks in pp_group do nothing
    return []
# Module-level counter; presumably incremented by a caller elsewhere in this
# file (not visible in this chunk) — verify usage before relying on it.
step_counter = 0
...
...
@@ -1732,6 +1801,13 @@ def configure_ipv6(dist_init_addr):
return
port
,
host
def rank0_log(msg: str):
    """Emit *msg* through logger.info, but only on tensor-parallel rank 0."""
    # Imported lazily to avoid a circular import with the distributed module.
    from sglang.srt.distributed import get_tensor_model_parallel_rank

    # Guard clause: non-zero ranks stay silent.
    if get_tensor_model_parallel_rank() != 0:
        return
    logger.info(msg)
def
rank0_print
(
msg
:
str
):
from
sglang.srt.distributed
import
get_tensor_model_parallel_rank
...
...
python/sglang/test/test_utils.py
View file @
11383cec
...
...
@@ -770,6 +770,34 @@ def run_bench_offline_throughput(model, other_args):
return
output_throughput
def run_bench_one_batch_server(
    model,
    base_url,
    server_args,
    bench_args,
    other_server_args,
    simulate_spec_acc_lens=None,
):
    """Launch a server for *model* and run the one-batch benchmark against it.

    The server process is always torn down, even if the benchmark raises.
    When ``simulate_spec_acc_lens`` is given, it is exported to the server via
    the SIMULATE_ACC_LEN environment variable.
    """
    from sglang.bench_one_batch_server import run_benchmark

    env = (
        None
        if simulate_spec_acc_lens is None
        else {**os.environ, "SIMULATE_ACC_LEN": str(simulate_spec_acc_lens)}
    )

    process = popen_launch_server(
        model,
        base_url,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=other_server_args,
        env=env,
    )
    try:
        run_benchmark(server_args=server_args, bench_args=bench_args)
    finally:
        # Guarantee cleanup of the launched server and its children.
        kill_process_tree(process.pid)
def
lcs
(
X
,
Y
):
m
=
len
(
X
)
n
=
len
(
Y
)
...
...
test/srt/run_suite.py
View file @
11383cec
...
...
@@ -96,6 +96,8 @@ suites = {
"per-commit-8-gpu"
:
[
TestFile
(
"test_local_attn.py"
,
250
),
TestFile
(
"test_full_deepseek_v3.py"
,
250
),
TestFile
(
"test_fa3.py"
,
30
),
TestFile
(
"test_pp_single_node.py"
,
150
),
],
"nightly"
:
[
TestFile
(
"test_nightly_gsm8k_eval.py"
),
...
...
test/srt/test_pp_single_node.py
0 → 100644
View file @
11383cec
"""
Usage:
python3 -m unittest test_pp_single_node.TestPPAccuracy.test_gsm8k
python3 -m unittest test_pp_single_node.TestFixedBugs.test_chunked_prefill_with_small_bs
"""
import
os
import
time
import
unittest
from
types
import
SimpleNamespace
import
requests
from
sglang.bench_one_batch_server
import
BenchArgs
as
OneBatchBenchArgs
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.utils
import
kill_process_tree
from
sglang.test.few_shot_gsm8k
import
run_eval
from
sglang.test.runners
import
DEFAULT_PROMPTS
from
sglang.test.test_utils
import
(
DEFAULT_MODEL_NAME_FOR_TEST
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
popen_launch_server
,
run_bench_one_batch_server
,
)
class TestPPAccuracy(unittest.TestCase):
    """End-to-end GSM8K accuracy check with pipeline parallelism (pp=4)."""

    @classmethod
    def setUpClass(cls):
        # These configs help find a leak.
        os.environ["SGLANG_IS_IN_CI"] = "1"
        cls.base_url = "http://127.0.0.1:23333"
        # Launch the server once for the whole class: pp=4, overlap schedule
        # disabled, small chunked-prefill size to exercise chunking under PP.
        cls.process = popen_launch_server(
            DEFAULT_MODEL_NAME_FOR_TEST,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--pp-size",
                4,
                "--disable-overlap-schedule",
                "--chunked-prefill-size",
                256,
            ],
        )

    @classmethod
    def tearDownClass(cls):
        # Kill the server and all of its child processes.
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        # Few-shot GSM8K eval against the already-running server.
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
        )
        metrics = run_eval(args)
        print(f"{metrics=}")

        self.assertGreater(metrics["accuracy"], 0.75)
        # Wait a little bit so that the memory check happens.
        time.sleep(5)
# class TestPPAccuracyFlashInfer(unittest.TestCase):
# @classmethod
# def setUpClass(cls):
# # These config helps find a leak.
# os.environ["SGLANG_IS_IN_CI"] = "1"
# cls.base_url = "http://127.0.0.1:23333"
# cls.process = popen_launch_server(
# DEFAULT_MODEL_NAME_FOR_TEST,
# cls.base_url,
# timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
# other_args=[
# "--pp-size",
# 4,
# "--disable-overlap-schedule",
# "--attention-backend",
# "flashinfer",
# "--chunked-prefill-size",
# 256,
# ],
# )
#
# @classmethod
# def tearDownClass(cls):
# kill_process_tree(cls.process.pid)
#
# def test_gsm8k(self):
# args = SimpleNamespace(
# num_shots=5,
# data_path=None,
# num_questions=200,
# max_new_tokens=512,
# parallel=128,
# host="http://127.0.0.1",
# port=int(self.base_url.split(":")[-1]),
# )
# metrics = run_eval(args)
# print(f"{metrics=}")
#
# self.assertGreater(metrics["accuracy"], 0.75)
# # Wait a little bit so that the memory check happens.
# time.sleep(5)
class TestFixedBugs(unittest.TestCase):
    """Regression tests for previously-fixed pipeline-parallelism bugs."""

    def test_chunked_prefill_with_small_bs(self):
        # Regression: chunked prefill with a tiny batch under tp=2 / pp=2 and
        # a small max-running-requests; run one minimal benchmark batch
        # end-to-end through a freshly launched server.
        model = DEFAULT_MODEL_NAME_FOR_TEST
        server_args = ServerArgs(model_path=model)
        bench_args = OneBatchBenchArgs(
            batch_size=(1,),
            input_len=(1,),
            output_len=(1,),
            base_url=DEFAULT_URL_FOR_TEST,
        )
        other_server_args = [
            "--tp-size",
            2,
            "--pp-size",
            2,
            "--disable-overlap-schedule",
            # Fixed: use the full flag name. "--chunked-prefill" only worked
            # via argparse prefix abbreviation; the full name matches the
            # sibling TestPPAccuracy usage and is robust to new flags.
            "--chunked-prefill-size",
            256,
            "--max-running-requests",
            2,
        ]
        run_bench_one_batch_server(
            model,
            DEFAULT_URL_FOR_TEST,
            server_args,
            bench_args,
            other_server_args,
        )
if __name__ == "__main__":
    # Allow running this test file directly: python test_pp_single_node.py
    unittest.main()
test/srt/test_vlm_accuracy.py
View file @
11383cec
...
...
@@ -147,6 +147,8 @@ class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase):
gpu_id
=
0
,
tp_rank
=
0
,
tp_size
=
1
,
pp_rank
=
0
,
pp_size
=
1
,
nccl_port
=
12435
,
server_args
=
ServerArgs
(
model_path
=
self
.
model_path
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment