Unverified Commit e5e35fca authored by Baber Abbasi, committed by GitHub

Vllm update DP+TP (#1508)

* use `@ray.remote` with distributed vLLM

* update versions

* bugfix

* unpin vllm

* fix pre-commit

* added version assertion error

* Revert "added version assertion error"

This reverts commit 8041e9b78e95eea9f4f4d0dc260115ba8698e9cc.

* added version assertion for DP

* expand DP note

* add warning

* nit

* pin vllm

* fix typos
parent ae79b121
@@ -2,7 +2,7 @@
 exclude: ^tests/testdata/
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.1.0
+    rev: v4.5.0
     hooks:
       - id: check-added-large-files
       - id: check-ast
@@ -29,7 +29,7 @@ repos:
         args: [--fix=lf]
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: v0.1.8
+    rev: v0.2.2
     hooks:
       # Run the linter.
       - id: ruff
@@ -38,7 +38,7 @@ repos:
       # Run the formatter.
       - id: ruff-format
   - repo: https://github.com/codespell-project/codespell
-    rev: v2.1.0
+    rev: v2.2.6
     hooks:
       - id: codespell
         exclude: >
......
@@ -887,7 +887,7 @@ class HFLM(TemplateLM):
         def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
             """Defines the key to group and lookup one-token continuations"""
             # Use with group_by="contexts" (optional)"
-            # allows for the creation of a lookup, so we can re-use logits in case of one-token continuations.
+            # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations.
             # speeds up some multiple-choice tasks proportionally to the number of choices.
             # groups requests by context+continuation[:-1] and infer on one request/group.
             return req[-2] + req[-1][:-1]
......
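The grouping key above is what enables the logit reuse: any requests whose model input (`context + continuation[:-1]`) is identical can share a single forward pass, and each candidate's final token is scored from the same logits. A minimal sketch of that idea, using hypothetical helper and variable names rather than the harness's own internals:

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def group_one_token_continuations(
    requests: List[Tuple[List[int], List[int]]],
) -> Dict[Tuple[int, ...], List[Tuple[List[int], List[int]]]]:
    """Group (context_tokens, continuation_tokens) pairs by the tokens that
    actually have to be fed to the model: context + continuation[:-1]."""
    groups: Dict[Tuple[int, ...], List[Tuple[List[int], List[int]]]] = defaultdict(list)
    for ctx, cont in requests:
        groups[tuple(ctx + cont[:-1])].append((ctx, cont))
    return groups


# Four answer choices that differ only in their final token collapse into one
# group, so a single forward pass can score all four candidates.
reqs = [([1, 2, 3], [7]), ([1, 2, 3], [8]), ([1, 2, 3], [9]), ([1, 2, 3], [10])]
assert len(group_one_token_continuations(reqs)) == 1
```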
 import copy
+from importlib.metadata import version
 from importlib.util import find_spec
 from typing import List, Literal, Optional, Tuple, Union
 from more_itertools import distribute
+from packaging.version import parse as parse_version
 from tqdm import tqdm
 from lm_eval.api.instance import Instance
@@ -18,7 +20,6 @@ from lm_eval.utils import (
 try:
     import ray
-    from ray.util.multiprocessing import Pool
     from vllm import LLM, SamplingParams
     from vllm.transformers_utils.tokenizer import get_tokenizer
 except ModuleNotFoundError:
@@ -27,14 +28,6 @@ except ModuleNotFoundError:
 eval_logger = eval_logger
-# adapted from https://github.com/vllm-project/vllm/issues/367#issuecomment-1788341727
-def run_inference_one_model(
-    model_args: dict, sampling_params, requests: List[List[int]]
-):
-    llm = LLM(**model_args)
-    return llm.generate(prompt_token_ids=requests, sampling_params=sampling_params)
 @register_model("vllm")
 class VLLM(TemplateLM):
     _DEFAULT_MAX_LENGTH = 2048
@@ -61,6 +54,7 @@ class VLLM(TemplateLM):
         gpu_memory_utilization: float = 0.9,
         device: str = "cuda",
         data_parallel_size: int = 1,
+        **kwargs,
     ):
         super().__init__()
@@ -93,6 +87,7 @@ class VLLM(TemplateLM):
             "quantization": quantization,
             "seed": int(seed),
         }
+        self.model_args.update(kwargs)
         self.batch_size = (
             "auto"
             if isinstance(batch_size, str) and "auto" in batch_size
@@ -101,6 +96,12 @@ class VLLM(TemplateLM):
         if self.data_parallel_size <= 1:
             self.model = LLM(**self.model_args)
         else:
+            assert parse_version(version("vllm")) < parse_version(
+                "0.3.3"
+            ), "data_parallel is only compatible with vllm < v0.3.3."
+            eval_logger.warning(
+                "You might experience occasional issues with model weight downloading when data_parallel is in use. To ensure stable performance, run with data_parallel_size=1 until the weights are downloaded and cached."
+            )
             self.model_args["worker_use_ray"] = True
             self.batch_size = "auto"
             eval_logger.info("Manual batching is not compatible with data parallelism.")
@@ -182,13 +183,26 @@ class VLLM(TemplateLM):
             temperature=0, prompt_logprobs=1, max_tokens=1
         )
         if self.data_parallel_size > 1:
+            # vLLM hangs if tensor_parallel > 1 and resources are set in ray.remote
+            # also seems to only work with decorator and not with ray.remote() fn
+            # see https://github.com/vllm-project/vllm/issues/973
+            # note: this has changed on 0.3.3, and it only works now if num_gpus are set.
+            # but then tensor_parallel breaks
+            @ray.remote
+            def run_inference_one_model(
+                model_args: dict, sampling_params, requests: List[List[int]]
+            ):
+                llm = LLM(**model_args)
+                return llm.generate(
+                    prompt_token_ids=requests, sampling_params=sampling_params
+                )
             # dispatch requests to all self.data_parallel_size workers, in interleaved fashion
             # interleaved important to balance context lengths across workers
             requests = [list(x) for x in distribute(self.data_parallel_size, requests)]
-            inputs = [(self.model_args, sampling_params, req) for req in requests]
-            with Pool(self.data_parallel_size) as pool:
-                results = pool.starmap(run_inference_one_model, inputs)
+            inputs = ((self.model_args, sampling_params, req) for req in requests)
+            object_refs = [run_inference_one_model.remote(*x) for x in inputs]
+            results = ray.get(object_refs)
             # Invoke ray.shutdown() to prevent hang-ups if subsequent calls required.
             ray.shutdown()
             # flatten results
@@ -286,7 +300,7 @@ class VLLM(TemplateLM):
                     f"Expected `kwargs` to be of type `dict` but got {gen_kwargs}"
                 )
             # add EOS token to stop sequences
-            eos = self.tok_decode(self.eot_token_id)
+            eos = self.tokenizer.decode(self.eot_token_id)
             if not until:
                 until = [eos]
             else:
......
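The change above swaps the `ray.util.multiprocessing` Pool for plain Ray tasks: each data-parallel replica builds its own `LLM` inside an `@ray.remote` function, requests are interleaved across replicas to balance context lengths, and results are gathered with `ray.get`. Below is a self-contained sketch of that pattern outside the harness, assuming `ray`, `more_itertools`, and a vLLM version below 0.3.3 (per the new assertion); the model name and token-ID prompts are placeholders:

```python
import ray
from more_itertools import distribute
from vllm import LLM, SamplingParams


@ray.remote  # decorator form; setting resources here can hang when tensor_parallel > 1
def run_inference_one_model(model_args: dict, sampling_params, requests):
    # each Ray task builds its own engine and generates on its shard of requests
    llm = LLM(**model_args)
    return llm.generate(prompt_token_ids=requests, sampling_params=sampling_params)


data_parallel_size = 2
model_args = {"model": "facebook/opt-125m", "worker_use_ray": True}  # placeholder model
sampling_params = SamplingParams(temperature=0, max_tokens=16)
requests = [[1, 2, 3], [4, 5], [6, 7, 8, 9], [10]]  # placeholder token-ID prompts

# interleave requests across replicas so context lengths stay balanced
shards = [list(x) for x in distribute(data_parallel_size, requests)]
object_refs = [
    run_inference_one_model.remote(model_args, sampling_params, shard)
    for shard in shards
]
results = ray.get(object_refs)
ray.shutdown()  # prevents hang-ups if further calls follow

# flatten per-replica results back into a single list of outputs
outputs = [out for shard_results in results for out in shard_results]
```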
@@ -71,7 +71,7 @@ optimum = ["optimum[openvino]"]
 promptsource = ["promptsource>=0.2.3"]
 sentencepiece = ["sentencepiece>=0.1.98", "protobuf>=4.22.1"]
 testing = ["pytest", "pytest-cov", "pytest-xdist"]
-vllm = ["vllm<=0.2.5"]
+vllm = ["vllm==0.3.2"]
 zeno = ["pandas", "zeno-client"]
 wandb = ["wandb>=0.16.3", "pandas", "numpy"]
 all = [
......
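With the extra now pinned to `vllm==0.3.2`, the runtime guard (`vllm < 0.3.3`) is satisfied and the data-parallel path stays usable. A hedged usage sketch of driving this backend from Python follows; `simple_evaluate` and the `model_args` keys are assumptions based on the harness's public API and this diff, and the checkpoint name is a placeholder:

```python
from lm_eval import simple_evaluate  # assumed public entry point

results = simple_evaluate(
    model="vllm",
    model_args=(
        "pretrained=mistralai/Mistral-7B-v0.1,"  # placeholder checkpoint
        "tensor_parallel_size=1,"
        "data_parallel_size=2,"  # takes the Ray path guarded by the vllm < 0.3.3 assertion
        "gpu_memory_utilization=0.9"
    ),
    tasks=["hellaswag"],
    batch_size="auto",  # manual batching is not compatible with data parallelism
)
```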
@@ -75,7 +75,7 @@ std::vector<std::string> clean_ngram(std::string const &input,
       gram_lengths.erase(gram_lengths.begin());
       gram_lengths.push_back(0);
-      // Otherwise, continute building
+      // Otherwise, continue building
     } else {
       current_ngram += ' ';
       gram_lengths.push_back(0);
@@ -165,7 +165,7 @@ clean_ngram_with_indices(std::string const &input, std::string const &ignore,
       gram_start_indices.erase(gram_start_indices.begin());
       gram_start_indices.push_back(i + 1);
-      // Otherwise, continute building
+      // Otherwise, continue building
     } else {
       current_ngram += ' ';
       gram_lengths.push_back(0);
......