Unverified Commit 5a481f43 authored by Baber Abbasi, committed by GitHub

[vllm] data parallel for V1 (#3011)

* add data_parallel for V1

* use Process instead of Queue

* ray used if V0 DP

* better error handling

* fix truncation warning comparison
parent 7aaceeec
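At a glance, the new V1 data-parallel path spawns one non-daemonic Process per DP rank, points every rank at a shared master through the VLLM_DP_* environment variables, and collects per-rank results over a multiprocessing Queue before re-ordering them by rank. A minimal, self-contained sketch of that pattern follows (illustrative only, not the harness code; fake_generate and the fixed port are hypothetical stand-ins for building a vLLM engine and for get_open_port()):

import os
from multiprocessing import Process, Queue


def fake_generate(rank: int, prompts: list) -> list:
    # Placeholder for building a vLLM engine (LLM(**model_args)) and calling generate().
    return [f"rank {rank} handled {p}" for p in prompts]


def worker(rank: int, world_size: int, port: int, prompts: list, out) -> None:
    # vLLM's data-parallel engines coordinate through these environment variables.
    os.environ["VLLM_DP_RANK"] = os.environ["VLLM_DP_RANK_LOCAL"] = str(rank)
    os.environ["VLLM_DP_SIZE"] = str(world_size)
    os.environ["VLLM_DP_MASTER_IP"] = "127.0.0.1"
    os.environ["VLLM_DP_MASTER_PORT"] = str(port)
    out.put((rank, fake_generate(rank, prompts)))


if __name__ == "__main__":
    shards = [["a", "b"], ["c", "d"]]  # requests pre-split across DP ranks
    resq = Queue()
    procs = [
        Process(target=worker, args=(rank, len(shards), 12345, shard, resq))
        for rank, shard in enumerate(shards)
    ]
    for p in procs:
        p.start()
    rank_res = {}
    while len(rank_res) < len(procs):
        rank, res = resq.get()  # blocks until some worker reports back
        rank_res[rank] = res
    for p in procs:
        p.join()
    print([rank_res[i] for i in range(len(procs))])  # results re-ordered by rank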
import copy
import gc
import inspect
import logging
import os
from importlib.metadata import version
from importlib.util import find_spec
from multiprocessing import Process, Queue
from queue import Empty
from time import sleep
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
from more_itertools import distribute
@@ -29,6 +34,7 @@ try:
    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest
    from vllm.transformers_utils.tokenizer import get_tokenizer
    from vllm.utils import get_open_port

    if parse_version(version("vllm")) >= parse_version("0.8.3"):
        from vllm.entrypoints.chat_utils import resolve_hf_chat_template
@@ -41,6 +47,63 @@ if TYPE_CHECKING:
eval_logger = logging.getLogger(__name__)
def _vllm_mp_worker(
    model_args: dict,
    sampling_params: "SamplingParams",
    requests: list[list[int]],
    lora_request: "LoRARequest",
    result_queue: "Queue",
    dp_size: int,
    local_dp_rank: int,
    dp_master_port: int,
    dp_master_ip: str = "127.0.0.1",
) -> None:
    """
    Worker process for vLLM multiprocessing.

    Initializes a vLLM engine, processes requests, and puts results or errors
    onto the result_queue.
    """
    if not requests:
        result_queue.put((local_dp_rank, []))
        return None

    os.environ["VLLM_DP_RANK"] = os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
    os.environ["VLLM_DP_SIZE"] = str(dp_size)
    os.environ["VLLM_DP_MASTER_IP"] = str(dp_master_ip)
    os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)

    llm = None
    try:
        llm = LLM(**model_args)
        res = llm.generate(
            prompt_token_ids=requests,
            sampling_params=sampling_params,
            lora_request=lora_request,
        )
        # Give engines time to pause their processing loops before exiting.
        sleep(1)
        result_queue.put((local_dp_rank, res))
    except Exception as e:
        error_message = f"Worker {local_dp_rank} failed during generation: {type(e).__name__}: {str(e)}"
        eval_logger.error(error_message, exc_info=True)
        result_queue.put((local_dp_rank, {"error": error_message}))
    finally:
        if llm is not None:
            try:
                del llm
                gc.collect()
            except Exception as e_cleanup:
                eval_logger.warning(
                    f"Worker {local_dp_rank} encountered an error during LLM cleanup: {type(e_cleanup).__name__}: {str(e_cleanup)}",
                    exc_info=True,
                )
    return None
@register_model("vllm")
class VLLM(TemplateLM):
    _DEFAULT_MAX_LENGTH = 2048
@@ -83,7 +146,7 @@ class VLLM(TemplateLM):
        assert max_length is None or max_model_len is None, (
            "Either max_length or max_model_len may be provided, but not both"
        )
        self.V1 = os.environ.get("VLLM_USE_V1", "1") != "0"
        self._max_length = max_model_len if max_model_len is not None else max_length
        self.tensor_parallel_size = int(tensor_parallel_size)
        self.data_parallel_size = int(data_parallel_size)
@@ -98,6 +161,7 @@ class VLLM(TemplateLM):
"trust_remote_code": trust_remote_code,
"tensor_parallel_size": int(tensor_parallel_size),
"max_model_len": int(self._max_length) if self._max_length else None,
"max_num_seqs": kwargs.get("max_num_seqs", max_batch_size),
"swap_space": int(swap_space),
"quantization": quantization,
"seed": int(seed),
@@ -115,7 +179,11 @@ class VLLM(TemplateLM):
            eval_logger.warning(
                "You might experience occasional issues with model weight downloading when data_parallel is in use. To ensure stable performance, run with data_parallel_size=1 until the weights are downloaded and cached."
            )
self.model_args["distributed_executor_backend"] = "ray"
self.model_args["distributed_executor_backend"] = (
"ray"
if not self.V1
else self.model_args.get("distributed_executor_backend", None)
)
self.batch_size = "auto"
eval_logger.info("Manual batching is not compatible with data parallelism.")
@@ -279,7 +347,7 @@ class VLLM(TemplateLM):
            sampling_params = SamplingParams(
                temperature=0, prompt_logprobs=1, max_tokens=1, detokenize=False
            )
-        if self.data_parallel_size > 1:
+        if self.data_parallel_size > 1 and not self.V1:
            # vLLM hangs if resources are set in ray.remote
            # also seems to only work with decorator and not with ray.remote() fn
            # see https://github.com/vllm-project/vllm/issues/973
@@ -310,14 +378,83 @@ class VLLM(TemplateLM):
            ray.shutdown()
            # flatten results
            return undistribute(results)
        elif self.data_parallel_size > 1:
            # based on https://github.com/vllm-project/vllm/blob/a04720bc36401d831cb048c3917b9e58173d9c1d/examples/offline_inference/data_parallel.py
            dp_size = self.data_parallel_size
            dp_master_ip = os.environ.get("VLLM_DP_MASTER_IP", "127.0.0.1")
            dp_master_port = os.environ.get("VLLM_DP_MASTER_PORT") or get_open_port()
            requests = (list(x) for x in distribute(self.data_parallel_size, requests))
            procs, resq = [], Queue()
            # We use Process as it is non-daemonic
            try:
                for rank, req in enumerate(requests):
                    proc = Process(
                        target=_vllm_mp_worker,
                        args=(
                            self.model_args.copy(),
                            sampling_params,
                            req,
                            self.lora_request,
                            resq,
                            dp_size,
                            rank,
                            dp_master_port,
                            dp_master_ip,
                        ),
                    )
                    proc.start()
                    procs.append(proc)

                # Collect results
                rank_res = {}
                while len(rank_res) < len(procs):
                    try:
                        rank, result = resq.get(timeout=30)
                        if isinstance(result, dict) and "error" in result:
                            raise RuntimeError(result["error"])
                        rank_res[rank] = result
                    except Empty:
                        dead_procs = [
                            idx
                            for idx, p in enumerate(procs)
                            if not p.is_alive() and idx not in rank_res
                        ]
                        if dead_procs:
                            raise RuntimeError(
                                f"Worker processes {dead_procs} died unexpectedly"
                            )
                        continue
                results = [rank_res[i] for i in range(len(procs))]
                return undistribute(results)
            # cleanup
            finally:
                try:
                    resq.close()
                    resq.join_thread()
                except Exception:
                    eval_logger.debug(
                        "Failed to close vllm DP results queue", exc_info=True
                    )
                for proc in procs:
                    proc.join(timeout=10)
                    if proc.is_alive():
                        proc.terminate()
                        proc.join(timeout=5)
                        if proc.is_alive():
                            proc.kill()
-        outputs = self.model.generate(
-            prompt_token_ids=requests,
-            sampling_params=sampling_params,
-            use_tqdm=True if self.batch_size == "auto" else False,
-            lora_request=self.lora_request,
-        )
-        return outputs
+        else:
+            outputs = self.model.generate(
+                prompt_token_ids=requests,
+                sampling_params=sampling_params,
+                use_tqdm=True if self.batch_size == "auto" else False,
+                lora_request=self.lora_request,
+            )
+            return outputs
    def loglikelihood_rolling(
        self, requests: List[Instance], disable_tqdm: bool = False
@@ -507,8 +644,7 @@ class VLLM(TemplateLM):
            for cache_key, context_enc, continuation_enc in chunk:
                if (
                    full_length := len(context_enc + continuation_enc)
-                    >= self.max_length
-                ):
+                ) > self.max_length:
                    eval_logger.warning(
                        f"Context length {full_length} exceeds max length ({self.max_length}). Truncating context."
                    )
......
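A note on the last hunk ("fix truncation warning comparison"): because := binds more loosely than >=, the old expression assigned the result of the comparison (a bool) to full_length, so the warning reported "Context length True ..."; the new form binds the actual token count first and then compares it with >. A small illustration with hypothetical values:

n, max_len = 10, 5
if (flag := n >= max_len):  # old shape: flag is True (a bool), not the length
    print(f"Context length {flag} exceeds max length ({max_len}).")
if (length := n) > max_len:  # new shape: length is 10, as intended
    print(f"Context length {length} exceeds max length ({max_len}).")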