Unverified commit eca6926b authored by Baber Abbasi, committed by GitHub

vllm: handle max_length better and substitute Collator (#1241)



* copies max_length from huggingface

* handle max_length properly

* get tokens from inputs

* substitute Collator for Reorderer (see the Collator sketch below, after the commit metadata)

* `batch=auto` if using data_parallel

* nit

* cleanup

* update code comments

* `ray.shutdown()` after calling method if data_parallel_size > 1

---------
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>
parent 25a15379
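The central refactor replaces utils.Grouper / utils.Reorderer / utils.chunks with the single Collator helper. A minimal sketch of that usage pattern, based only on the call sites visible in the diff below (the toy requests and the lambda sort key are illustrative, not part of the commit):

# Hedged sketch: sort requests by descending token length, batch them,
# then map the results back to the original order with Collator.
from lm_eval.utils import Collator

requests = [("a b c", [11, 12, 13]), ("a", [11]), ("a b", [11, 12])]  # toy data

# sort key mirrors _collate in the diff: longest token list first
re_ord = Collator(requests, sort_fn=lambda x: (-len(x[1]), tuple(x[1])))

res = []
for batch in re_ord.get_batched(n=2, batch_fn=None):  # the diff passes n=0 when batch_size="auto"
    # a real model call would go here; we just record the token count
    res.extend(len(toks) for _, toks in batch)

print(re_ord.get_original(res))  # results restored to the original request order

When requests must also be grouped by generation kwargs, the diff constructs the helper as Collator(requests, _collate_gen, grouping=True) instead.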
 import copy
-from collections import defaultdict
 from importlib.util import find_spec
 from typing import List, Literal, Optional, Tuple, Union
 
 from tqdm import tqdm
 
-from lm_eval import utils
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
+from lm_eval.utils import (
+    Collator,
+    divide,
+    eval_logger,
+    get_rolling_token_windows,
+    make_disjoint_window,
+)
 
 try:
+    import ray
     from ray.util.multiprocessing import Pool
     from vllm import LLM, SamplingParams
     from vllm.transformers_utils.tokenizer import get_tokenizer
 except ModuleNotFoundError:
     pass
 
-eval_logger = utils.eval_logger
+eval_logger = eval_logger
 
 
 # adapted from https://github.com/vllm-project/vllm/issues/367#issuecomment-1788341727
-def run_inference_one_model(model_args: dict, sampling_params, requests: List[int]):
-    # gpu_id = [x for x in gpu_id]
-    # os.environ["CUDA_VISIBLE_DEVICES"]= str(gpu_id)
+def run_inference_one_model(
+    model_args: dict, sampling_params, requests: List[List[int]]
+):
     llm = LLM(**model_args)
     return llm.generate(prompt_token_ids=requests, sampling_params=sampling_params)
@@ -43,7 +49,7 @@ class VLLM(LM):
         tokenizer_mode: Literal["auto", "slow"] = "auto",
         tokenizer_revision: Optional[str] = None,
         tensor_parallel_size: int = 1,
-        quantization: Optional[Literal["awq"]] = None,
+        quantization: Optional[str] = None,
         max_gen_toks: int = 256,
         swap_space: int = 4,
         batch_size: Union[str, int] = 1,
@@ -86,10 +92,23 @@ class VLLM(LM):
             "quantization": quantization,
             "seed": int(seed),
         }
+        self.batch_size = (
+            "auto"
+            if isinstance(batch_size, str) and "auto" in batch_size
+            else batch_size
+        )
         if self.data_parallel_size <= 1:
             self.model = LLM(**self.model_args)
         else:
             self.model_args["worker_use_ray"] = True
+            self.batch_size = "auto"
+            eval_logger.info("Manual batching is not compatible with data parallelism.")
+
+        from transformers import AutoConfig
+
+        self._config = AutoConfig.from_pretrained(
+            pretrained, trust_remote_code=trust_remote_code, revision=revision
+        )
         self.tokenizer = get_tokenizer(
             tokenizer if tokenizer else pretrained,
             tokenizer_mode=tokenizer_mode,
@@ -97,7 +116,6 @@ class VLLM(LM):
             tokenizer_revision=tokenizer_revision,
         )
 
-        self.batch_size = "auto" if batch_size.startswith("auto:") else batch_size
         self._max_gen_toks = max_gen_toks
 
     @property
@@ -109,9 +127,18 @@ class VLLM(LM):
     def max_length(self):
         if self._max_length:  # if max length manually set, return it
            return self._max_length
-        if hasattr(self.tokenizer, "model_max_length"):
-            return self.tokenizer.model_max_length
-        return self._DEFAULT_MAX_LENGTH
+        if self.data_parallel_size <= 1:
+            return self.model.llm_engine.model_config.max_model_len
+        else:
+            seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
+            for attr in seqlen_config_attrs:
+                if hasattr(self._config, attr):
+                    return getattr(self._config, attr)
+            if hasattr(self.tokenizer, "model_max_length"):
+                if self.tokenizer.model_max_length == 1000000000000000019884624838656:
+                    return self._DEFAULT_MAX_LENGTH
+                return self.tokenizer.model_max_length
+            return self._DEFAULT_MAX_LENGTH
 
     @property
     def max_gen_toks(self):
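Aside on the sentinel checked above: 1000000000000000019884624838656 is int(1e30), the VERY_LARGE_INTEGER that Hugging Face tokenizers report as model_max_length when no real limit is configured. A minimal standalone sketch of the data-parallel fallback order (the helper name and its arguments are illustrative, not part of the commit):

# Hedged sketch of the max_length fallback used in the data-parallel branch above.
# resolve_max_length and its arguments are hypothetical names for illustration.
VERY_LARGE_INTEGER = int(1e30)  # == 1000000000000000019884624838656

def resolve_max_length(hf_config, tokenizer, default_max_length=2048):
    # 1) model config attributes, in the order the diff checks them
    for attr in ("n_positions", "max_position_embeddings", "n_ctx"):
        if hasattr(hf_config, attr):
            return getattr(hf_config, attr)
    # 2) the tokenizer's limit, unless it is the "unset" sentinel
    model_max = getattr(tokenizer, "model_max_length", VERY_LARGE_INTEGER)
    if model_max != VERY_LARGE_INTEGER:
        return model_max
    # 3) a hard-coded default
    return default_max_length

In the single-process path the engine's own limit is used directly via llm_engine.model_config.max_model_len, so this chain only applies when the config must be read before any engine exists.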
@@ -157,13 +184,13 @@ class VLLM(LM):
                 temperature=0, prompt_logprobs=2, max_tokens=1
             )
 
         if self.data_parallel_size > 1:
-            requests = [
-                list(x) for x in utils.divide(requests, self.data_parallel_size)
-            ]
+            requests = [list(x) for x in divide(requests, self.data_parallel_size)]
             inputs = [(self.model_args, sampling_params, req) for req in requests]
 
             with Pool(self.data_parallel_size) as pool:
                 results = pool.starmap(run_inference_one_model, inputs)
+            # Invoke ray.shutdown() to prevent hang-ups if subsequent calls required.
+            ray.shutdown()
             # flatten results
             return [item for sublist in results for item in sublist]
@@ -172,7 +199,6 @@ class VLLM(LM):
             sampling_params=sampling_params,
             use_tqdm=True if self.batch_size == "auto" else False,
         )
-
         return outputs
 
     def _encode_pair(
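A condensed, hedged sketch of the data-parallel dispatch shown in this hunk, pulled out of the class for illustration (generate_data_parallel and its arguments are placeholder names; the Pool / starmap / ray.shutdown() sequence mirrors the diff):

# Hedged standalone sketch of the Pool + ray.shutdown() pattern above.
import ray
from ray.util.multiprocessing import Pool
from vllm import LLM

from lm_eval.utils import divide

def run_inference_one_model(model_args, sampling_params, requests):
    # one vLLM engine per worker process
    llm = LLM(**model_args)
    return llm.generate(prompt_token_ids=requests, sampling_params=sampling_params)

def generate_data_parallel(model_args, sampling_params, requests, dp_size):
    # split the token-id lists into dp_size shards and run one engine per shard
    shards = [list(x) for x in divide(requests, dp_size)]
    inputs = [(model_args, sampling_params, shard) for shard in shards]
    with Pool(dp_size) as pool:
        results = pool.starmap(run_inference_one_model, inputs)
    ray.shutdown()  # prevent hang-ups if this is called again later
    # flatten per-shard outputs back into a single list
    return [item for shard in results for item in shard]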
@@ -212,8 +238,8 @@ class VLLM(LM):
         for (string,) in tqdm([req.args for req in requests]):
             rolling_token_windows = list(
                 map(
-                    utils.make_disjoint_window,
-                    utils.get_rolling_token_windows(
+                    make_disjoint_window,
+                    get_rolling_token_windows(
                         token_list=self.tok_encode(string),
                         prefix_token=self.eot_token_id,
                         max_seq_len=self.max_length - 1,
@@ -236,8 +262,7 @@ class VLLM(LM):
         return loglikelihoods
 
     def generate_until(self, requests: List[Instance]) -> List[str]:
-        res = defaultdict(list)
-        re_ords = {}
+        res = []
 
         # batch tokenize contexts
         context, all_gen_kwargs = zip(*(req.args for req in requests))
@@ -253,84 +278,73 @@ class VLLM(LM):
            # padded context length. this is useful to simplify the batching logic and more importantly to make
            # automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
-            return -len(_requests[0][1]), tuple(_requests[0][1])
+            return -len(_requests[0][1]), _requests[0][0]
 
         # we group requests by their generation_kwargs,
         # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
         # in the same batch.
-        grouper = utils.Grouper(requests, lambda x: str(x[1]))
-        for key, reqs in grouper.get_grouped().items():
-            # within each set of reqs for given kwargs, we reorder by token length, descending.
-            re_ords[key] = utils.Reorderer(requests, _collate_gen)
+        re_ords = Collator(requests, _collate_gen, grouping=True)
+        chunks = re_ords.get_batched(
+            n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
+        )
 
         pbar = tqdm(total=len(requests), disable=(self.rank != 0))
 
         # for each different set of kwargs, we execute all requests, by batch.
-        for key, re_ord in re_ords.items():
-            chunks = utils.chunks(
-                re_ord.get_reordered(),
-                n=int(self.batch_size) if self.batch_size != "auto" else 0,
-                fn=None,
-            )
-            for chunk in chunks:
+        for chunk in chunks:
             context_and_encoding, all_gen_kwargs = zip(*chunk)
             context, context_encoding = zip(*context_and_encoding)
             # we assume all gen kwargs in the batch are the same
             # this is safe to assume because the `grouper` object ensures it.
             gen_kwargs = all_gen_kwargs[0]
             # unpack our keyword arguments.
             until = None
             if isinstance(gen_kwargs, dict):
                 kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
                 if "until" in kwargs.keys():
                     until = kwargs.pop("until")
                     if isinstance(until, str):
                         until = [until]
                     elif not isinstance(until, list):
                         raise ValueError(
                             f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
                         )
             else:
                 raise ValueError(
                     f"Expected `kwargs` to be of type `dict` but got {gen_kwargs}"
                 )
             if not until:
                 until = [self.tokenizer.decode(self.eot_token_id)]
             if "max_gen_toks" in kwargs.keys():
                 max_gen_toks = kwargs.pop("max_gen_toks")
             else:
                 max_gen_toks = self.max_gen_toks
             # set the max length in tokens of inputs ("context_enc")
             # max len for inputs = max length, minus room to generate the max new tokens
             max_ctx_len = self.max_length - max_gen_toks
             context_encoding = [x[-max_ctx_len:] for x in context_encoding]
-            # TODO: max_length in kwargs
             # perform batched generation
             cont = self._model_generate(
                 requests=context_encoding,
                 generate=True,
                 max_tokens=max_gen_toks,
                 stop=until,
                 **kwargs,
             )
 
             # cache generations
             for output, context in zip(cont, context):
                 generated_text = output.outputs[0].text
-                res[key].append(generated_text)
+                res.append(generated_text)
                 self.cache_hook.add_partial(
                     "generate_until", (context, gen_kwargs), generated_text
                 )
                 pbar.update(1)
-            # reorder this group of results back to original unsorted form
-            res[key] = re_ord.get_original(res[key])
 
         pbar.close()
-        return grouper.get_original(res)
+        # reorder all group of results back to original unsorted form
+        return re_ords.get_original(res)
@@ -343,16 +357,15 @@ class VLLM(LM):
             toks = x[1] + x[2]
             return -len(toks), tuple(toks)
 
-        re_ord = utils.Reorderer(requests, _collate)
-
-        chunks = utils.chunks(
-            re_ord.get_reordered(),
-            n=int(self.batch_size) if self.batch_size != "auto" else 0,
-            fn=None,
+        # Reorder requests by length and batch
+        re_ord = Collator(requests, sort_fn=_collate)
+        chunks = re_ord.get_batched(
+            n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
         )
 
         pbar = tqdm(total=len(requests), disable=disable_tqdm)
         for chunk in chunks:
-            inps = []
+            inputs = []
             ctxlens = []
             for cache_key, context_enc, continuation_enc in chunk:
                 inp = (context_enc + continuation_enc)[-(self.max_length) :]
@@ -360,18 +373,18 @@ class VLLM(LM):
                     0, len(context_enc) + len(continuation_enc) - (self.max_length)
                 )
-                inps.append(inp)
+                inputs.append(inp)
                 ctxlens.append(ctxlen)
 
-            outputs = self._model_generate(requests=inps, generate=False)
+            outputs = self._model_generate(requests=inputs, generate=False)
 
-            for output, ctxlen, (cache_key, context_enc, continuation_enc) in zip(
-                outputs, ctxlens, chunk
+            for output, ctxlen, (cache_key, _, _), inp in zip(
+                outputs, ctxlens, chunk, inputs
             ):
                 answer = self._parse_logprobs(
-                    (context_enc + continuation_enc),
-                    output,
-                    ctxlen,
+                    tokens=inp,
+                    outputs=output,
+                    ctxlen=ctxlen,
                 )
 
                 res.append(answer)
@@ -379,7 +392,7 @@
                 # partial caching
                 if cache_key is not None:
                     self.cache_hook.add_partial("loglikelihood", cache_key, answer)
                 pbar.update(1)
 
         pbar.close()
         return re_ord.get_original(res)
@@ -388,9 +401,9 @@
         """Process logprobs and tokens.
 
         :param tokens: list
-            Tokens from context+continuations
+            Input tokens (potentially left-truncated)
         :param outputs: RequestOutput
-            Contains prompt
+            Contains prompt_logprobs
         :param ctxlen: int
             Length of context (so we can slice them away and only keep the predictions)
         :return:
@@ -400,11 +413,11 @@
             Whether argmax matches given continuation exactly
         """
 
-        # prompt_logprobs = [None, {}*len(context-1)]
+        # The first entry of prompt_logprobs is None because the model has no previous tokens to condition on.
         continuation_logprobs_dicts = outputs.prompt_logprobs
 
         # Calculate continuation_logprobs
-        # assume ctxlen always > 1
+        # assume ctxlen always >= 1
        continuation_logprobs = sum(
             logprob_dict.get(token)
             for token, logprob_dict in zip(
......
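The remainder of _parse_logprobs is collapsed above. Purely as an illustrative sketch, not the commit's code, the continuation score can be accumulated from vLLM's prompt_logprobs roughly like this, skipping the entries that belong to the context:

# Illustrative sketch only: sum continuation log-probabilities from vLLM's
# prompt_logprobs (prompt_logprobs[0] is None; later entries map token id -> logprob).
def continuation_logprob(tokens, prompt_logprobs, ctxlen):
    return sum(
        logprob_dict[token]
        for token, logprob_dict in zip(tokens[ctxlen:], prompt_logprobs[ctxlen:])
    )

At the time of this commit the per-position entries were plain token-id-to-float dicts; newer vLLM versions wrap the values in Logprob objects, in which case the float would be logprob_dict[token].logprob.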