Unverified Commit c5832d2a authored by Murali Andoorveedu, committed by GitHub
Browse files

[Core] Pipeline Parallel Support (#4412)


Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
parent 15aba081
......@@ -74,6 +74,16 @@ steps:
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
- label: Pipeline Parallelism Test
working_dir: "/vllm-workspace/tests"
num_gpus: 4
commands:
- TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=1 pytest -v -s distributed/test_pipeline_parallel.py
- TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
- PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=1 pytest -v -s distributed/test_pipeline_parallel.py
- PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
- label: Engine Test
mirror_hardwares: [amd]
command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py
......
......@@ -5,6 +5,7 @@ import pytest
import torch
from vllm import SamplingParams
from vllm.config import ParallelConfig
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
from ..utils import wait_for_gpu_memory_to_clear
......@@ -23,8 +24,11 @@ class MockEngine:
self.add_request_calls = 0
self.abort_request_calls = 0
self.request_id = None
# Ugly, remove dependency when possible
self.parallel_config = ParallelConfig(1, 1, False)
async def step_async(self):
async def step_async(self, virtual_engine):
# PP size is 1, ignore virtual engine
self.step_calls += 1
return [RequestOutput(
request_id=self.request_id)] if self.request_id else []
......@@ -32,6 +36,9 @@ class MockEngine:
async def process_model_inputs_async(self, *args, **kwargs):
pass
async def stop_remote_worker_execution_loop_async(self):
pass
def generate(self, request_id):
self.request_id = request_id
......@@ -41,6 +48,7 @@ class MockEngine:
def add_request(self, **kwargs):
del kwargs # Unused
self.add_request_calls += 1
print(f'Request calls: {self.add_request_calls}')
async def add_request_async(self, **kwargs):
self.add_request_calls += 1
......@@ -53,6 +61,9 @@ class MockEngine:
def has_unfinished_requests(self):
return self.request_id is not None
def has_unfinished_requests_for_virtual_engine(self, virtual_engine):
return self.request_id is not None
class MockAsyncLLMEngine(AsyncLLMEngine):
......@@ -76,6 +87,7 @@ async def test_new_requests_event():
engine.engine.generate("2")
await asyncio.sleep(0)
await asyncio.sleep(0)
await asyncio.sleep(0)
assert engine.engine.add_request_calls == 2
assert engine.engine.step_calls >= 2
await asyncio.sleep(0.001)
......
......@@ -4,7 +4,7 @@ import pytest
# and debugging.
import ray
from ..utils import RemoteOpenAIServer
from ..utils import VLLM_PATH, RemoteOpenAIServer
# any model with a chat template should work here
MODEL_NAME = "facebook/opt-125m"
......@@ -12,7 +12,7 @@ MODEL_NAME = "facebook/opt-125m"
@pytest.fixture(scope="module")
def ray_ctx():
ray.init()
ray.init(runtime_env={"working_dir": VLLM_PATH})
yield
ray.shutdown()
......
......@@ -56,8 +56,8 @@ def test_chunked_prefill_recompute(
max_num_seqs=max_num_seqs,
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
ARTIFICIAL_PREEMPTION_MAX_CNT)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
< ARTIFICIAL_PREEMPTION_MAX_CNT)
for i in range(len(example_prompts)):
hf_output_ids, hf_output_str = hf_outputs[i]
......@@ -91,10 +91,10 @@ def test_preemption(
disable_log_stats=False,
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
ARTIFICIAL_PREEMPTION_MAX_CNT)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
< ARTIFICIAL_PREEMPTION_MAX_CNT)
total_preemption = (
vllm_model.model.llm_engine.scheduler.num_cumulative_preemption)
vllm_model.model.llm_engine.scheduler[0].num_cumulative_preemption)
check_outputs_equal(
outputs_0_lst=hf_outputs,
......@@ -147,10 +147,10 @@ def test_swap(
) as vllm_model:
vllm_outputs = vllm_model.generate_beam_search(example_prompts,
beam_width, max_tokens)
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
ARTIFICIAL_PREEMPTION_MAX_CNT)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
< ARTIFICIAL_PREEMPTION_MAX_CNT)
total_preemption = (
vllm_model.model.llm_engine.scheduler.num_cumulative_preemption)
vllm_model.model.llm_engine.scheduler[0].num_cumulative_preemption)
for i in range(len(example_prompts)):
hf_output_ids, _ = hf_outputs[i]
......@@ -214,8 +214,8 @@ def test_swap_infeasible(
example_prompts,
sampling_params=sampling_params,
)
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
ARTIFICIAL_PREEMPTION_MAX_CNT)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
< ARTIFICIAL_PREEMPTION_MAX_CNT)
# Verify the request is ignored and not hang.
assert req_outputs[0].outputs[0].finish_reason == "length"
......@@ -252,8 +252,8 @@ def test_preemption_infeasible(
sampling_params=sampling_params,
)
assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
ARTIFICIAL_PREEMPTION_MAX_CNT)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
< ARTIFICIAL_PREEMPTION_MAX_CNT)
# Verify the request is ignored and not hang.
for req_output in req_outputs:
......
......@@ -32,7 +32,7 @@ def all_reduce_test_worker(tp_size: int, pp_size: int, rank: int,
(r + 1) for r in range(tp_size)
]
expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
t = all_tensors[rank]
t = all_tensors[rank % tp_size]
t = tensor_model_parallel_all_reduce(t)
assert torch.allclose(t, expected)
......@@ -60,7 +60,7 @@ def all_gather_test_worker(tp_size: int, pp_size: int, rank: int,
for r in range(tp_size)
]
expected = torch.cat(all_tensors, dim=all_gather_dimension)
t = all_tensors[rank]
t = all_tensors[rank % tp_size]
t = tensor_model_parallel_all_gather(t, all_gather_dimension)
assert torch.allclose(t, expected)
......@@ -91,7 +91,7 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
"f": torch.tensor([], dtype=torch.float32, device="cuda"),
}
if rank == 0:
if (rank % tp_size) == 0:
broadcast_tensor_dict(test_dict, src=0)
else:
recv_dict = broadcast_tensor_dict(src=0)
......@@ -184,3 +184,17 @@ def test_multi_process_tensor_parallel(tp_size, test_target):
"test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker])
def test_multi_process_pipeline_parallel(pp_size, test_target):
multi_process_parallel(1, pp_size, test_target)
@pytest.mark.skipif(torch.cuda.device_count() < 4,
                    reason="Need at least 4 GPUs to run the test.")
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("pp_size", [2])
@pytest.mark.parametrize(
    "test_target",
    [
        send_recv_test_worker,
        send_recv_tensor_dict_test_worker,
        all_reduce_test_worker,
        all_gather_test_worker,
        broadcast_tensor_dict_test_worker,
    ],
)
def test_multi_process_tensor_parallel_pipeline_parallel(
        tp_size, pp_size, test_target):
    """Run every worker target with tensor and pipeline parallelism combined.

    Each parametrized ``test_target`` is handed to ``multi_process_parallel``
    with ``tp_size=2`` and ``pp_size=2``; the test is skipped when fewer than
    four CUDA devices are visible.
    """
    multi_process_parallel(tp_size, pp_size, test_target)
import os
import openai # use the official client for correctness check
import pytest
# using Ray for overall ease of process management, parallel requests,
# and debugging.
import ray
from ..utils import VLLM_PATH, RemoteOpenAIServer
def _env_int(name: str, default: int = 0) -> int:
    """Return environment variable *name* parsed as an int, or *default*."""
    return int(os.environ.get(name, default))


# Model served by the OpenAI-compatible server under test.
MODEL_NAME = "meta-llama/Meta-Llama-3-8B"

# Test-matrix knobs read from the environment; when unset the flags are
# False and both parallel sizes are 1.
EAGER_MODE = bool(_env_int("EAGER_MODE"))
CHUNKED_PREFILL = bool(_env_int("CHUNKED_PREFILL"))
TP_SIZE = _env_int("TP_SIZE", 1)
PP_SIZE = _env_int("PP_SIZE", 1)
# Module-level pytest mark: applies the asyncio marker to the tests in this
# module, which are written as ``async def`` functions.
pytestmark = pytest.mark.asyncio
@pytest.fixture(scope="module")
def ray_ctx():
    """Start Ray once per module and shut it down after the module's tests.

    Ray is initialised with ``VLLM_PATH`` as the runtime working directory.
    """
    runtime_env = {"working_dir": VLLM_PATH}
    ray.init(runtime_env=runtime_env)
    yield
    ray.shutdown()
@pytest.fixture(scope="module")
def server(ray_ctx):
    """Launch one ``RemoteOpenAIServer`` per module for the configured layout.

    The command line always selects ``MODEL_NAME``, bfloat16, the Ray
    executor backend, and the ``PP_SIZE`` / ``TP_SIZE`` parallel sizes;
    chunked prefill and eager mode are appended only when their flags are
    set. The server is requested with ``PP_SIZE * TP_SIZE`` GPUs.
    """
    cli_args = [
        "--model",
        MODEL_NAME,
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--pipeline-parallel-size",
        str(PP_SIZE),
        "--tensor-parallel-size",
        str(TP_SIZE),
        "--distributed-executor-backend",
        "ray",
    ]
    if CHUNKED_PREFILL:
        cli_args.append("--enable-chunked-prefill")
    if EAGER_MODE:
        cli_args.append("--enforce-eager")

    gpu_count = PP_SIZE * TP_SIZE
    return RemoteOpenAIServer(cli_args, num_gpus=gpu_count)
@pytest.fixture(scope="module")
def client(server):
    """Return the async client provided by the module's ``server`` fixture."""
    async_client = server.get_async_client()
    return async_client
async def test_check_models(server, client: openai.AsyncOpenAI):
    """The first listed model is ``MODEL_NAME`` and every entry has it as root."""
    listing = await client.models.list()
    served = listing.data
    assert served[0].id == MODEL_NAME
    for model in served:
        assert model.root == MODEL_NAME
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_single_completion(server, client: openai.AsyncOpenAI,
                                 model_name: str):
    """Check a single greedy completion for a text prompt and a token-ID prompt.

    Both requests go to the parametrized ``model_name`` with ``max_tokens=5``
    and ``temperature=0.0``. The text-prompt response must contain exactly
    one choice that finished because of the length limit, and its usage
    must report 5 completion tokens, 6 prompt tokens and 11 total tokens.
    """
    completion = await client.completions.create(model=model_name,
                                                 prompt="Hello, my name is",
                                                 max_tokens=5,
                                                 temperature=0.0)

    assert completion.id is not None
    assert completion.choices is not None and len(completion.choices) == 1
    assert completion.choices[0].text is not None and len(
        completion.choices[0].text) >= 5
    assert completion.choices[0].finish_reason == "length"
    assert completion.usage == openai.types.CompletionUsage(
        completion_tokens=5, prompt_tokens=6, total_tokens=11)

    # test using token IDs
    completion = await client.completions.create(
        model=model_name,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
    )
    assert completion.choices[0].text is not None and len(
        completion.choices[0].text) >= 5
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_batch_completions(server, client: openai.AsyncOpenAI,
                                 model_name: str):
    """Check batched completions: plain list, beam search with n=2, streaming.

    The same prompt is submitted twice in every request, so outputs for the
    two copies are expected to match each other.
    """
    # Arguments shared by all three requests below.
    request = dict(
        model=model_name,
        prompt=["Hello, my name is", "Hello, my name is"],
        max_tokens=5,
        temperature=0.0,
    )

    # test simple list
    batch = await client.completions.create(**request)
    assert len(batch.choices) == 2
    first, second = batch.choices
    assert first.text == second.text

    # test n = 2
    batch = await client.completions.create(
        **request,
        n=2,
        extra_body=dict(
            # NOTE: this has to be true for n > 1 in vLLM, but not necessary
            # for official client.
            use_beam_search=True),
    )
    assert len(batch.choices) == 4
    beams = [choice.text for choice in batch.choices]
    assert beams[0] != beams[1], "beam search should be different"
    assert beams[0] == beams[
        2], "two copies of the same prompt should be the same"
    assert beams[1] == beams[
        3], "two copies of the same prompt should be the same"

    # test streaming
    batch = await client.completions.create(**request, stream=True)
    texts = ["", ""]
    async for chunk in batch:
        assert len(chunk.choices) == 1
        piece = chunk.choices[0]
        # Each chunk carries one choice; route its text by the choice index.
        texts[piece.index] = texts[piece.index] + piece.text
    assert texts[0] == texts[1]
......@@ -32,7 +32,7 @@ def test_appends_token_ids(num_new_tokens: int, seq_output_len: int):
output_processor = MultiStepOutputProcessor(
detokenizer=detokenizer,
scheduler=scheduler,
scheduler=[scheduler],
seq_counter=seq_counter,
get_tokenizer_for_seq=lambda _: mock_tokenizer(),
stop_checker=stop_checker,
......@@ -86,7 +86,7 @@ def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int,
output_processor = MultiStepOutputProcessor(
detokenizer=detokenizer,
scheduler=scheduler,
scheduler=[scheduler],
seq_counter=seq_counter,
get_tokenizer_for_seq=lambda _: mock_tokenizer(),
stop_checker=stop_checker,
......@@ -148,7 +148,7 @@ def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
output_processor = MultiStepOutputProcessor(
detokenizer=detokenizer,
scheduler=scheduler,
scheduler=[scheduler],
seq_counter=seq_counter,
get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id),
stop_checker=stop_checker,
......@@ -215,7 +215,7 @@ def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
output_processor = MultiStepOutputProcessor(
detokenizer=detokenizer,
scheduler=scheduler,
scheduler=[scheduler],
seq_counter=seq_counter,
get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id),
stop_checker=stop_checker,
......
......@@ -14,7 +14,7 @@ import torch
from huggingface_hub import snapshot_download
from openai import BadRequestError
from ...utils import RemoteOpenAIServer
from ...utils import VLLM_PATH, RemoteOpenAIServer
# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
......@@ -77,7 +77,7 @@ def zephyr_lora_files():
@pytest.fixture(scope="module")
def ray_ctx():
ray.init()
ray.init(runtime_env={"working_dir": VLLM_PATH})
yield
ray.shutdown()
......
......@@ -16,7 +16,7 @@ from openai import BadRequestError
from vllm.transformers_utils.tokenizer import get_tokenizer
from ...utils import RemoteOpenAIServer
from ...utils import VLLM_PATH, RemoteOpenAIServer
# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
......@@ -79,7 +79,7 @@ def zephyr_lora_files():
@pytest.fixture(scope="module")
def ray_ctx():
ray.init()
ray.init(runtime_env={"working_dir": VLLM_PATH})
yield
ray.shutdown()
......
......@@ -5,14 +5,14 @@ import openai
import pytest
import ray
from ...utils import RemoteOpenAIServer
from ...utils import VLLM_PATH, RemoteOpenAIServer
EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
@pytest.fixture(scope="module")
def ray_ctx():
ray.init()
ray.init(runtime_env={"working_dir": VLLM_PATH})
yield
ray.shutdown()
......
......@@ -6,7 +6,7 @@ import ray
# downloading lora to test lora requests
from huggingface_hub import snapshot_download
from ...utils import RemoteOpenAIServer
from ...utils import VLLM_PATH, RemoteOpenAIServer
# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
......@@ -22,7 +22,7 @@ def zephyr_lora_files():
@pytest.fixture(scope="module")
def ray_ctx():
ray.init()
ray.init(runtime_env={"working_dir": VLLM_PATH})
yield
ray.shutdown()
......
......@@ -24,13 +24,13 @@ TEST_IMAGE_URLS = [
@pytest.fixture(scope="module")
def ray_ctx():
ray.init()
ray.init(runtime_env={"working_dir": VLLM_PATH})
yield
ray.shutdown()
@pytest.fixture(scope="module")
def server():
def server(ray_ctx):
return RemoteOpenAIServer([
"--model",
MODEL_NAME,
......
......@@ -54,9 +54,9 @@ def patch_execute_model_with_seeds(worker: Worker, rand_seeds: List[int]):
return new_execute_model
def zero_kv_cache(cache_engine: CacheEngine):
assert cache_engine.gpu_cache
for key_blocks, value_blocks in cache_engine.gpu_cache:
def zero_kv_cache(cache_engine: List[CacheEngine]):
assert cache_engine[0].gpu_cache
for key_blocks, value_blocks in cache_engine[0].gpu_cache:
key_blocks.zero_()
value_blocks.zero_()
......
......@@ -22,7 +22,7 @@ from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
tensorize_vllm_model)
from ..conftest import VllmRunner, cleanup
from ..utils import RemoteOpenAIServer
from ..utils import VLLM_PATH, RemoteOpenAIServer
# yapf conflicts with isort for this docstring
......@@ -220,6 +220,8 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
json.dumps(model_loader_extra_config),
]
ray.init(runtime_env={"working_dir": VLLM_PATH})
server = RemoteOpenAIServer(openai_args)
print("Server ready.")
......
......@@ -49,7 +49,6 @@ class RemoteOpenAIServer:
DUMMY_API_KEY = "token-abc123" # vLLM's OpenAI server does not need API key
MAX_SERVER_START_WAIT_S = 600 # wait up to 600 seconds for the server to start
@ray.remote(num_gpus=1)
class _RemoteRunner:
def __init__(self, cli_args: List[str], *, wait_url: str,
......@@ -92,7 +91,11 @@ class RemoteOpenAIServer:
if hasattr(self, "proc"):
self.proc.terminate()
def __init__(self, cli_args: List[str], *, auto_port: bool = True) -> None:
def __init__(self,
cli_args: List[str],
*,
auto_port: bool = True,
num_gpus: int = 1) -> None:
if auto_port:
if "-p" in cli_args or "--port" in cli_args:
raise ValueError("You have manually specified the port"
......@@ -105,10 +108,11 @@ class RemoteOpenAIServer:
self.host = str(args.host or 'localhost')
self.port = int(args.port)
self._runner = self._RemoteRunner.remote( # type: ignore
cli_args,
wait_url=self.url_for("health"),
wait_timeout=self.MAX_SERVER_START_WAIT_S)
self._runner = ray.remote(num_gpus=num_gpus)(
self._RemoteRunner).remote(
cli_args,
wait_url=self.url_for("health"),
wait_timeout=self.MAX_SERVER_START_WAIT_S)
self._wait_until_ready()
......
......@@ -39,8 +39,8 @@ def test_swap() -> None:
num_cpu_blocks=engine_config.cache_config.num_cpu_blocks)
# Randomly initialize the cache.
gpu_cache = worker.cache_engine.gpu_cache
cpu_cache = worker.cache_engine.cpu_cache
gpu_cache = worker.cache_engine[0].gpu_cache
cpu_cache = worker.cache_engine[0].cpu_cache
num_layers = len(gpu_cache)
for i in range(num_layers):
gpu_key_cache, gpu_value_cache = gpu_cache[i]
......
......@@ -27,6 +27,17 @@ logger = init_logger(__name__)
_GB = 1 << 30
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
# Architecture names for which pipeline parallelism is allowed. When
# pipeline_parallel_size > 1, every entry of the HF config's ``architectures``
# must appear here, otherwise NotImplementedError is raised. The list's repr
# is interpolated into that error message.
_PP_SUPPORTED_MODELS = [
"AquilaModel",
"AquilaForCausalLM",
"InternLMForCausalLM",
"LlamaForCausalLM",
"LLaMAForCausalLM",
"MistralForCausalLM",
"Phi3ForCausalLM",
"GPT2LMHeadModel",
]
class ModelConfig:
"""Configuration for the model.
......@@ -258,6 +269,13 @@ class ModelConfig:
total_num_hidden_layers = getattr(self.hf_text_config,
"num_hidden_layers", 0)
pipeline_parallel_size = parallel_config.pipeline_parallel_size
architectures = getattr(self.hf_config, "architectures", [])
if not all(arch in _PP_SUPPORTED_MODELS
for arch in architectures) and pipeline_parallel_size > 1:
raise NotImplementedError(
"Pipeline parallelism is only supported for the following "
f" architectures: {_PP_SUPPORTED_MODELS}.")
if total_num_hidden_layers % pipeline_parallel_size != 0:
raise ValueError(
f"Total number of hidden layers ({total_num_hidden_layers}) "
......@@ -665,9 +683,10 @@ class ParallelConfig:
self._verify_args()
def _verify_args(self) -> None:
if self.pipeline_parallel_size > 1:
raise NotImplementedError(
"Pipeline parallelism is not supported yet.")
if (self.pipeline_parallel_size > 1
and self.distributed_executor_backend == "mp"):
raise NotImplementedError("Pipeline parallelism is not supported "
"yet with multiprocessing.")
if self.distributed_executor_backend not in ("ray", "mp", None):
raise ValueError(
"Unrecognized distributed executor backend. Supported values "
......
......@@ -471,6 +471,9 @@ class BlockSpaceManagerV1(BlockSpaceManager):
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
# NOTE: fork does not allocate a new physical block.
# Thus, it is always safe from OOM.
if parent_seq.seq_id not in self.block_tables:
# Parent sequence has either been freed or never existed.
return
src_block_table = self.block_tables[parent_seq.seq_id]
self.block_tables[child_seq.seq_id] = src_block_table.copy()
# When using a sliding window, blocks will be eventually reused.
......
......@@ -317,6 +317,9 @@ class BlockSpaceManagerV2(BlockSpaceManager):
computed_seq_block_ids) # type: ignore
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
if parent_seq.seq_id not in self.block_tables:
# Parent sequence has either been freed or never existed.
return
src_block_table = self.block_tables[parent_seq.seq_id]
self.block_tables[child_seq.seq_id] = src_block_table.fork()
......
......@@ -256,6 +256,7 @@ class Scheduler:
scheduler_config: SchedulerConfig,
cache_config: CacheConfig,
lora_config: Optional[LoRAConfig],
pipeline_parallel_size: int = 1,
) -> None:
self.scheduler_config = scheduler_config
self.cache_config = cache_config
......@@ -273,11 +274,19 @@ class Scheduler:
BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class(
version)
num_gpu_blocks = cache_config.num_gpu_blocks
if num_gpu_blocks:
num_gpu_blocks //= pipeline_parallel_size
num_cpu_blocks = cache_config.num_cpu_blocks
if num_cpu_blocks:
num_cpu_blocks //= pipeline_parallel_size
# Create the block space manager.
self.block_manager = BlockSpaceManagerImpl(
block_size=self.cache_config.block_size,
num_gpu_blocks=self.cache_config.num_gpu_blocks,
num_cpu_blocks=self.cache_config.num_cpu_blocks,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
sliding_window=self.cache_config.sliding_window,
enable_caching=self.cache_config.enable_prefix_caching)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment