Unverified Commit 11383cec authored by Ying Sheng's avatar Ying Sheng Committed by GitHub

[PP] Add pipeline parallelism (#5724)

parent e97e57e6
......@@ -12,6 +12,7 @@
# limitations under the License.
# ==============================================================================
"""Common utilities."""
import base64
import builtins
import ctypes
......@@ -414,16 +415,40 @@ class LayerFn(Protocol):
def make_layers(
num_hidden_layers: int,
layer_fn: LayerFn,
pp_rank: Optional[int] = None,
pp_size: Optional[int] = None,
prefix: str = "",
return_tuple: bool = False,
) -> Tuple[int, int, torch.nn.ModuleList]:
"""Make a list of layers with the given layer function"""
    # circular imports
from sglang.srt.distributed import get_pp_indices
from sglang.srt.layers.utils import PPMissingLayer
assert not pp_size or num_hidden_layers >= pp_size
start_layer, end_layer = (
get_pp_indices(
num_hidden_layers,
pp_rank,
pp_size,
)
if pp_rank is not None and pp_size is not None
else (0, num_hidden_layers)
)
modules = torch.nn.ModuleList(
[
[PPMissingLayer(return_tuple=return_tuple) for _ in range(start_layer)]
+ [
maybe_offload_to_cpu(layer_fn(idx=idx, prefix=add_prefix(idx, prefix)))
for idx in range(num_hidden_layers)
for idx in range(start_layer, end_layer)
]
+ [
PPMissingLayer(return_tuple=return_tuple)
for _ in range(end_layer, num_hidden_layers)
]
)
return modules
if pp_rank is None or pp_size is None:
return modules
return modules, start_layer, end_layer
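
For reference, a minimal sketch of how a pipeline-parallel layer range can be derived from pp_rank and pp_size. This only illustrates the idea behind get_pp_indices; the real function may assign any remainder layers differently, so treat the split below as an assumption.

    # Illustrative only: an even split of num_hidden_layers across pp_size stages,
    # with remainder layers given to the earliest ranks. Not sglang's exact logic.
    def split_layers_evenly(num_hidden_layers: int, pp_rank: int, pp_size: int):
        base = num_hidden_layers // pp_size
        remainder = num_hidden_layers % pp_size
        start = pp_rank * base + min(pp_rank, remainder)
        end = start + base + (1 if pp_rank < remainder else 0)
        # This rank builds layers [start, end); the rest become PPMissingLayer.
        return start, end

    # Example: 10 layers over 4 stages -> (0, 3), (3, 6), (6, 8), (8, 10)
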
def set_random_seed(seed: int) -> None:
......@@ -877,7 +902,7 @@ def broadcast_pyobj(
"cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu"
)
if rank == 0:
if rank == src:
if len(data) == 0:
tensor_size = torch.tensor([0], dtype=torch.long, device=device)
dist.broadcast(tensor_size, src=src, group=dist_group)
......@@ -909,6 +934,50 @@ def broadcast_pyobj(
return data
def point_to_point_pyobj(
data: List[Any],
rank: int,
group: Optional[torch.distributed.ProcessGroup] = None,
src: int = 0,
dst: int = 1,
):
"""Send data from src to dst in group."""
if rank == src:
if len(data) == 0:
tensor_size = torch.tensor([0], dtype=torch.long)
dist.send(tensor_size, dst=dst, group=group)
else:
serialized_data = pickle.dumps(data)
size = len(serialized_data)
tensor_data = torch.ByteTensor(
np.frombuffer(serialized_data, dtype=np.uint8)
)
tensor_size = torch.tensor([size], dtype=torch.long)
dist.send(tensor_size, dst=dst, group=group)
dist.send(tensor_data, dst=dst, group=group)
return data
elif rank == dst:
tensor_size = torch.tensor([0], dtype=torch.long)
dist.recv(tensor_size, src=src, group=group)
size = tensor_size.item()
if size == 0:
return []
tensor_data = torch.empty(size, dtype=torch.uint8)
dist.recv(tensor_data, src=src, group=group)
serialized_data = bytes(tensor_data.cpu().numpy())
data = pickle.loads(serialized_data)
return data
# Other ranks in pp_group do nothing
return []
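
A minimal usage sketch of point_to_point_pyobj, assuming the helper lives in sglang.srt.utils and that torch.distributed is initialized with two ranks (for example via torchrun --nproc_per_node=2); the payload contents are purely illustrative.

    # Illustrative only: rank 0 sends a pickled Python object to rank 1 over the
    # default process group. Launch with: torchrun --nproc_per_node=2 this_script.py
    import torch.distributed as dist

    from sglang.srt.utils import point_to_point_pyobj  # assumed import path

    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()

    payload = [{"req_id": 0, "tokens": [1, 2, 3]}] if rank == 0 else []
    result = point_to_point_pyobj(payload, rank, group=None, src=0, dst=1)
    if rank == 1:
        print(result)  # the list originally sent from rank 0
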
step_counter = 0
......@@ -1732,6 +1801,13 @@ def configure_ipv6(dist_init_addr):
return port, host
def rank0_log(msg: str):
from sglang.srt.distributed import get_tensor_model_parallel_rank
if get_tensor_model_parallel_rank() == 0:
logger.info(msg)
def rank0_print(msg: str):
from sglang.srt.distributed import get_tensor_model_parallel_rank
......
......@@ -770,6 +770,34 @@ def run_bench_offline_throughput(model, other_args):
return output_throughput
def run_bench_one_batch_server(
model,
base_url,
server_args,
bench_args,
other_server_args,
simulate_spec_acc_lens=None,
):
from sglang.bench_one_batch_server import run_benchmark
if simulate_spec_acc_lens is not None:
env = {**os.environ, "SIMULATE_ACC_LEN": str(simulate_spec_acc_lens)}
else:
env = None
process = popen_launch_server(
model,
base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=other_server_args,
env=env,
)
try:
run_benchmark(server_args=server_args, bench_args=bench_args)
finally:
kill_process_tree(process.pid)
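
A hedged usage sketch of the helper above; the argument values are illustrative, and it mainly shows how simulate_spec_acc_lens is forwarded to the launched server through the SIMULATE_ACC_LEN environment variable.

    # Illustrative only: run the one-batch benchmark against a temporary server,
    # simulating an average speculative acceptance length of 2.5 tokens.
    from sglang.bench_one_batch_server import BenchArgs
    from sglang.srt.server_args import ServerArgs
    from sglang.test.test_utils import (
        DEFAULT_MODEL_NAME_FOR_TEST,
        DEFAULT_URL_FOR_TEST,
        run_bench_one_batch_server,
    )

    run_bench_one_batch_server(
        DEFAULT_MODEL_NAME_FOR_TEST,
        DEFAULT_URL_FOR_TEST,
        ServerArgs(model_path=DEFAULT_MODEL_NAME_FOR_TEST),
        BenchArgs(batch_size=(1,), input_len=(8,), output_len=(8,), base_url=DEFAULT_URL_FOR_TEST),
        ["--tp-size", 1],
        simulate_spec_acc_lens=2.5,
    )
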
def lcs(X, Y):
m = len(X)
n = len(Y)
......
......@@ -96,6 +96,8 @@ suites = {
"per-commit-8-gpu": [
TestFile("test_local_attn.py", 250),
TestFile("test_full_deepseek_v3.py", 250),
TestFile("test_fa3.py", 30),
TestFile("test_pp_single_node.py", 150),
],
"nightly": [
TestFile("test_nightly_gsm8k_eval.py"),
......
"""
Usage:
python3 -m unittest test_pp_single_node.TestPPAccuracy.test_gsm8k
python3 -m unittest test_pp_single_node.TestFixedBugs.test_chunked_prefill_with_small_bs
"""
import os
import time
import unittest
from types import SimpleNamespace
import requests
from sglang.bench_one_batch_server import BenchArgs as OneBatchBenchArgs
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.runners import DEFAULT_PROMPTS
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
run_bench_one_batch_server,
)
class TestPPAccuracy(unittest.TestCase):
@classmethod
def setUpClass(cls):
        # This config helps find a leak.
os.environ["SGLANG_IS_IN_CI"] = "1"
cls.base_url = "http://127.0.0.1:23333"
cls.process = popen_launch_server(
DEFAULT_MODEL_NAME_FOR_TEST,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--pp-size",
4,
"--disable-overlap-schedule",
"--chunked-prefill-size",
256,
],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval(args)
print(f"{metrics=}")
self.assertGreater(metrics["accuracy"], 0.75)
# Wait a little bit so that the memory check happens.
time.sleep(5)
# class TestPPAccuracyFlashInfer(unittest.TestCase):
# @classmethod
# def setUpClass(cls):
# # This config helps find a leak.
# os.environ["SGLANG_IS_IN_CI"] = "1"
# cls.base_url = "http://127.0.0.1:23333"
# cls.process = popen_launch_server(
# DEFAULT_MODEL_NAME_FOR_TEST,
# cls.base_url,
# timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
# other_args=[
# "--pp-size",
# 4,
# "--disable-overlap-schedule",
# "--attention-backend",
# "flashinfer",
# "--chunked-prefill-size",
# 256,
# ],
# )
#
# @classmethod
# def tearDownClass(cls):
# kill_process_tree(cls.process.pid)
#
# def test_gsm8k(self):
# args = SimpleNamespace(
# num_shots=5,
# data_path=None,
# num_questions=200,
# max_new_tokens=512,
# parallel=128,
# host="http://127.0.0.1",
# port=int(self.base_url.split(":")[-1]),
# )
# metrics = run_eval(args)
# print(f"{metrics=}")
#
# self.assertGreater(metrics["accuracy"], 0.75)
# # Wait a little bit so that the memory check happens.
# time.sleep(5)
class TestFixedBugs(unittest.TestCase):
def test_chunked_prefill_with_small_bs(self):
model = DEFAULT_MODEL_NAME_FOR_TEST
server_args = ServerArgs(model_path=model)
bench_args = OneBatchBenchArgs(
batch_size=(1,),
input_len=(1,),
output_len=(1,),
base_url=DEFAULT_URL_FOR_TEST,
)
other_server_args = [
"--tp-size",
2,
"--pp-size",
2,
"--disable-overlap-schedule",
"--chunked-prefill",
256,
"--max-running-requests",
2,
]
run_bench_one_batch_server(
model,
DEFAULT_URL_FOR_TEST,
server_args,
bench_args,
other_server_args,
)
if __name__ == "__main__":
unittest.main()
......@@ -147,6 +147,8 @@ class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase):
gpu_id=0,
tp_rank=0,
tp_size=1,
pp_rank=0,
pp_size=1,
nccl_port=12435,
server_args=ServerArgs(
model_path=self.model_path,
......