Unverified Commit 5a481f43 authored by Baber Abbasi, committed by GitHub

[vllm] data parallel for V1 (#3011)

* add data_parallel for V1

* use Process instead of Queue

* ray used if V0 DP

* better error handling

* fix truncation warning comparison
parent 7aaceeec
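
The new V1 path below follows vLLM's offline data-parallel example: each rank runs its own engine in a separate process, joins the data-parallel group via VLLM_DP_* environment variables, and reports results back over a multiprocessing.Queue. A minimal sketch of that pattern outside the harness (the model name and prompt are placeholders, and this assumes a vLLM build that supports V1 data parallelism):

import os
from multiprocessing import Process, Queue


def rank_worker(rank: int, dp_size: int, port: int, outq: "Queue") -> None:
    # Advertise this rank's position in the DP group before building an engine.
    os.environ["VLLM_DP_RANK"] = os.environ["VLLM_DP_RANK_LOCAL"] = str(rank)
    os.environ["VLLM_DP_SIZE"] = str(dp_size)
    os.environ["VLLM_DP_MASTER_IP"] = "127.0.0.1"
    os.environ["VLLM_DP_MASTER_PORT"] = str(port)

    from vllm import LLM  # import after the env vars are set

    llm = LLM(model="facebook/opt-125m")  # placeholder model
    outq.put((rank, llm.generate("Hello, my name is")))


if __name__ == "__main__":
    from vllm.utils import get_open_port

    dp_size, outq = 2, Queue()
    port = get_open_port()
    procs = [
        Process(target=rank_worker, args=(r, dp_size, port, outq))
        for r in range(dp_size)
    ]
    for p in procs:
        p.start()
    # Drain the queue before joining so large result payloads cannot deadlock.
    results = dict(outq.get() for _ in procs)
    for p in procs:
        p.join()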
import copy
import gc
import inspect
import logging
import os
from importlib.metadata import version
from importlib.util import find_spec
from multiprocessing import Process, Queue
from queue import Empty
from time import sleep
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union

from more_itertools import distribute

@@ -29,6 +34,7 @@ try:
    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest
    from vllm.transformers_utils.tokenizer import get_tokenizer
    from vllm.utils import get_open_port

    if parse_version(version("vllm")) >= parse_version("0.8.3"):
        from vllm.entrypoints.chat_utils import resolve_hf_chat_template

@@ -41,6 +47,63 @@ if TYPE_CHECKING:

eval_logger = logging.getLogger(__name__)


def _vllm_mp_worker(
    model_args: dict,
    sampling_params: "SamplingParams",
    requests: list[list[int]],
    lora_request: "LoRARequest",
    result_queue: "Queue",
    dp_size: int,
    local_dp_rank: int,
    dp_master_port: int,
    dp_master_ip: str = "127.0.0.1",
) -> None:
    """
    Worker process for vLLM multiprocessing.

    Initializes a vLLM engine, processes its shard of requests, and puts
    results or errors onto the result_queue.
    """
    if not requests:
        result_queue.put((local_dp_rank, []))
        return None
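
    # The engine discovers its data-parallel rank and the coordinator address
    # through these variables, so they are set before the LLM is constructed.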
os.environ["VLLM_DP_RANK"] = os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
os.environ["VLLM_DP_SIZE"] = str(dp_size)
os.environ["VLLM_DP_MASTER_IP"] = str(dp_master_ip)
os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)
llm = None
try:
llm = LLM(**model_args)
res = llm.generate(
prompt_token_ids=requests,
sampling_params=sampling_params,
lora_request=lora_request,
)
# Give engines time to pause their processing loops before exiting."
sleep(1)
result_queue.put((local_dp_rank, res))
except Exception as e:
error_message = f"Worker {local_dp_rank} failed during generation: {type(e).__name__}: {str(e)}"
eval_logger.error(error_message, exc_info=True)
result_queue.put((local_dp_rank, {"error": error_message}))
finally:
if llm is not None:
try:
del llm
gc.collect()
except Exception as e_cleanup:
eval_logger.warning(
f"Worker {local_dp_rank} encountered an error during LLM cleanup: {type(e_cleanup).__name__}: {str(e_cleanup)}",
exc_info=True,
)
return None


@register_model("vllm")
class VLLM(TemplateLM):
    _DEFAULT_MAX_LENGTH = 2048
@@ -83,7 +146,7 @@ class VLLM(TemplateLM):
        assert max_length is None or max_model_len is None, (
            "Either max_length or max_model_len may be provided, but not both"
        )
        self.V1 = os.environ.get("VLLM_USE_V1", "1") != "0"
        self._max_length = max_model_len if max_model_len is not None else max_length
        self.tensor_parallel_size = int(tensor_parallel_size)
        self.data_parallel_size = int(data_parallel_size)
@@ -98,6 +161,7 @@ class VLLM(TemplateLM):
            "trust_remote_code": trust_remote_code,
            "tensor_parallel_size": int(tensor_parallel_size),
            "max_model_len": int(self._max_length) if self._max_length else None,
            "max_num_seqs": kwargs.get("max_num_seqs", max_batch_size),
            "swap_space": int(swap_space),
            "quantization": quantization,
            "seed": int(seed),
@@ -115,7 +179,11 @@ class VLLM(TemplateLM):
            eval_logger.warning(
                "You might experience occasional issues with model weight downloading when data_parallel is in use. To ensure stable performance, run with data_parallel_size=1 until the weights are downloaded and cached."
            )
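            # V0 data parallelism drives engines through ray actors; V1 spawns
            # one process per rank, so any user-provided backend is kept as-is.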
self.model_args["distributed_executor_backend"] = "ray" self.model_args["distributed_executor_backend"] = (
"ray"
if not self.V1
else self.model_args.get("distributed_executor_backend", None)
)
            self.batch_size = "auto"
            eval_logger.info("Manual batching is not compatible with data parallelism.")
@@ -279,7 +347,7 @@ class VLLM(TemplateLM):
        sampling_params = SamplingParams(
            temperature=0, prompt_logprobs=1, max_tokens=1, detokenize=False
        )
        if self.data_parallel_size > 1 and not self.V1:
            # vLLM hangs if resources are set in ray.remote
            # also seems to only work with decorator and not with ray.remote() fn
            # see https://github.com/vllm-project/vllm/issues/973
@@ -310,14 +378,83 @@ class VLLM(TemplateLM):
            ray.shutdown()
            # flatten results
            return undistribute(results)
        elif self.data_parallel_size > 1:
            # based on https://github.com/vllm-project/vllm/blob/a04720bc36401d831cb048c3917b9e58173d9c1d/examples/offline_inference/data_parallel.py
            dp_size = self.data_parallel_size
            dp_master_ip = os.environ.get("VLLM_DP_MASTER_IP", "127.0.0.1")
            dp_master_port = os.environ.get("VLLM_DP_MASTER_PORT") or get_open_port()
            requests = (list(x) for x in distribute(self.data_parallel_size, requests))
            procs, resq = [], Queue()
            # We use Process as it is non-daemonic
            try:
                for rank, req in enumerate(requests):
                    proc = Process(
                        target=_vllm_mp_worker,
                        args=(
                            self.model_args.copy(),
                            sampling_params,
                            req,
                            self.lora_request,
                            resq,
                            dp_size,
                            rank,
                            dp_master_port,
                            dp_master_ip,
                        ),
                    )
                    proc.start()
                    procs.append(proc)

                # Collect results
                rank_res = {}
                while len(rank_res) < len(procs):
                    try:
                        rank, result = resq.get(timeout=30)
                        if isinstance(result, dict) and "error" in result:
                            raise RuntimeError(result["error"])
                        rank_res[rank] = result
                    except Empty:
                        dead_procs = [
                            idx
                            for idx, p in enumerate(procs)
                            if not p.is_alive() and idx not in rank_res
                        ]
                        if dead_procs:
                            raise RuntimeError(
                                f"Worker processes {dead_procs} died unexpectedly"
                            )
                        continue

                results = [rank_res[i] for i in range(len(procs))]
                return undistribute(results)
            # cleanup
            finally:
                try:
                    resq.close()
                    resq.join_thread()
                except Exception:
                    eval_logger.debug(
                        "Failed to close vllm DP results queue", exc_info=True
                    )
                for proc in procs:
                    proc.join(timeout=10)
                    if proc.is_alive():
                        proc.terminate()
                        proc.join(timeout=5)
                        if proc.is_alive():
                            proc.kill()
        else:
            outputs = self.model.generate(
                prompt_token_ids=requests,
                sampling_params=sampling_params,
                use_tqdm=True if self.batch_size == "auto" else False,
                lora_request=self.lora_request,
            )
            return outputs

    def loglikelihood_rolling(
        self, requests: List[Instance], disable_tqdm: bool = False

@@ -507,8 +644,7 @@ class VLLM(TemplateLM):
            for cache_key, context_enc, continuation_enc in chunk:
                if (
                    full_length := len(context_enc + continuation_enc)
                ) > self.max_length:
                    eval_logger.warning(
                        f"Context length {full_length} exceeds max length ({self.max_length}). Truncating context."
                    )
......
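
For reference on how results keep their order: requests are sharded with more_itertools.distribute, which deals items round-robin across the ranks, and the harness's undistribute reverses that interleaving so the flattened results come back in the original request order. A quick illustration of the sharding half:

from more_itertools import distribute

reqs = ["r0", "r1", "r2", "r3", "r4"]
shards = [list(s) for s in distribute(2, reqs)]
# shards == [["r0", "r2", "r4"], ["r1", "r3"]]
# undistribute(shards) re-interleaves them to ["r0", "r1", "r2", "r3", "r4"]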