Commit 6d63c2ce authored by Baber

types

parent 0087929e
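The diff modernizes the module's type hints: `typing.Optional`, `Union`, `List`, `Dict`, and `Tuple` are replaced with PEP 604 unions and builtin generics (`str | None`, `list[list[int]]`), and `from __future__ import annotations` at the top of the file keeps the new syntax harmless on interpreters older than Python 3.10, since annotations are then never evaluated at runtime. A minimal sketch of the before/after pattern, using a toy `encode` function that is not part of this module:

# Toy example of the annotation style adopted in this commit; the names below
# are illustrative only and do not appear in the diff.
from __future__ import annotations


def encode(
    text: str | list[str],            # was: Union[str, List[str]]
    max_len: int | None = None,       # was: Optional[int]
) -> list[int] | list[list[int]]:     # was: Union[List[int], List[List[int]]]
    """Map characters to code points, truncating each result to max_len."""

    def one(s: str) -> list[int]:
        ids = [ord(ch) for ch in s]
        return ids[:max_len] if max_len is not None else ids

    return one(text) if isinstance(text, str) else [one(s) for s in text]


print(encode(["ab", "cde"], max_len=2))  # [[97, 98], [99, 100]]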
+from __future__ import annotations
import copy
import gc
import inspect
@@ -8,7 +10,7 @@ from importlib.util import find_spec
from multiprocessing import Process, Queue
from queue import Empty
from time import sleep
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Literal
import jinja2
from more_itertools import distribute
@@ -51,10 +53,10 @@ eval_logger = logging.getLogger(__name__)
def _vllm_mp_worker(
model_args: dict,
sampling_params: "SamplingParams",
sampling_params: SamplingParams,
requests: list[list[int]],
lora_request: "LoRARequest",
result_queue: "Queue",
lora_request: LoRARequest,
result_queue: Queue,
dp_size: int,
local_dp_rank: int,
dp_master_port: int,
@@ -114,30 +116,30 @@ class VLLM(TemplateLM):
self,
pretrained: str,
dtype: Literal["float16", "bfloat16", "float32", "auto"] = "auto",
-revision: Optional[str] = None,
-trust_remote_code: Optional[bool] = False,
-tokenizer: Optional[str] = None,
+revision: str | None = None,
+trust_remote_code: bool | None = False,
+tokenizer: str | None = None,
tokenizer_mode: Literal["auto", "slow"] = "auto",
-tokenizer_revision: Optional[str] = None,
-add_bos_token: Optional[bool] = False,
-prefix_token_id: Optional[int] = None,
+tokenizer_revision: str | None = None,
+add_bos_token: bool | None = False,
+prefix_token_id: int | None = None,
tensor_parallel_size: int = 1,
-quantization: Optional[str] = None,
+quantization: str | None = None,
max_gen_toks: int = 256,
swap_space: int = 4,
-batch_size: Union[str, int] = 1,
-max_batch_size=None,
-max_length: int = None,
-max_model_len: int = None,
+batch_size: str | int = 1,
+max_batch_size: int | None = None,
+max_length: int | None = None,
+max_model_len: int | None = None,
seed: int = 1234,
gpu_memory_utilization: float = 0.9,
device: str = "cuda",
data_parallel_size: int = 1,
-lora_local_path: str = None,
+lora_local_path: str | None = None,
# VLLM: enable thinking tags in the prompt.
enable_thinking: bool = True,
# End marker for thinking tags - splits to get response after this token (if provided).
-think_end_token: Optional[str] = None,
+think_end_token: str | None = None,
max_lora_rank: int = 16,
**kwargs,
):
@@ -173,7 +175,7 @@ class VLLM(TemplateLM):
"quantization": quantization,
"seed": int(seed),
"device": str(device),
"enable_lora": True if lora_local_path else False,
"enable_lora": bool(lora_local_path),
"max_lora_rank": int(max_lora_rank),
}
self.model_args.update(kwargs)
@@ -304,7 +306,7 @@ class VLLM(TemplateLM):
return self._max_gen_toks
def apply_chat_template(
-self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
+self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
) -> str:
"""
Method to apply a chat template to a list of chat history between user and model.
@@ -339,14 +341,14 @@ class VLLM(TemplateLM):
def tok_encode(
self,
-string: Union[str, List[str]],
+string: str | list[str],
left_truncate_len: int = None,
add_special_tokens: bool = False,
truncation: bool = False,
-) -> Union[List[int], List[List[int]]]:
+) -> list[int] | list[list[int]]:
if not add_special_tokens:
add_special_tokens = False or self.add_bos_token
-encoding: Union[List[List[int]], List[int]] = self.tokenizer(
+encoding: list[list[int]] | list[int] = self.tokenizer(
string,
add_special_tokens=add_special_tokens,
truncation=truncation,
@@ -364,10 +366,10 @@ class VLLM(TemplateLM):
def _model_generate(
self,
-requests: List[List[int]] = None,
+requests: list[list[int]] = None,
generate: bool = False,
max_tokens: int = None,
-stop: Optional[List[str]] = None,
+stop: list[str] | None = None,
**kwargs,
):
if generate:
@@ -385,7 +387,7 @@ class VLLM(TemplateLM):
def run_inference_one_model(
model_args: dict,
sampling_params: SamplingParams,
-requests: List[List[int]],
+requests: list[list[int]],
lora_request: LoRARequest,
):
llm = LLM(**model_args)
@@ -454,7 +456,7 @@ class VLLM(TemplateLM):
if dead_procs:
raise RuntimeError(
f"Worker processes {dead_procs} died unexpectedly"
-)
+) from None
continue
results = [rank_res[i] for i in range(len(procs))]
@@ -481,14 +483,14 @@ class VLLM(TemplateLM):
outputs = self.model.generate(
prompt_token_ids=requests,
sampling_params=sampling_params,
-use_tqdm=True if self.batch_size == "auto" else False,
+use_tqdm=self.batch_size == "auto",
lora_request=self.lora_request,
)
return outputs
def loglikelihood_rolling(
-self, requests: List[Instance], disable_tqdm: bool = False
-) -> List[float]:
+self, requests: list[Instance], disable_tqdm: bool = False
+) -> list[float]:
adaptive_batch_size = None
if self.batch_size == "auto":
adaptive_batch_size = len(requests)
@@ -503,7 +505,7 @@ class VLLM(TemplateLM):
disable=(disable_tqdm or (self.rank != 0)),
)
):
-rolling_token_windows: List[Tuple[List[int], List[int]]] = list(
+rolling_token_windows: list[tuple[list[int], list[int]]] = list(
map(
make_disjoint_window,
get_rolling_token_windows(
@@ -556,13 +558,13 @@ class VLLM(TemplateLM):
return loglikelihoods
def generate_until(
-self, requests: List[Instance], disable_tqdm: bool = False
-) -> List[str]:
+self, requests: list[Instance], disable_tqdm: bool = False
+) -> list[str]:
res = []
# batch tokenize contexts
context, all_gen_kwargs = zip(*(req.args for req in requests))
-context_encoding: List[List[int]] = self.tok_encode(
+context_encoding: list[list[int]] = self.tok_encode(
context, add_special_tokens=self.add_bos_token
)
requests = [
@@ -608,7 +610,7 @@ class VLLM(TemplateLM):
raise ValueError(
f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
)
if "max_gen_toks" in kwargs.keys():
if "max_gen_toks" in kwargs:
max_gen_toks = kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks
@@ -634,7 +636,7 @@ class VLLM(TemplateLM):
)
# cache generations
-for output, context in zip(cont, context):
+for output, context_ in zip(cont, context):
generated_text: str = output.outputs[0].text
# use secondary stop seqs to cut off should-have-been-stopped content post-hoc
generated_text = postprocess_generated_text(
@@ -642,7 +644,7 @@ class VLLM(TemplateLM):
)
res.append(generated_text)
self.cache_hook.add_partial(
"generate_until", (context, gen_kwargs), generated_text
"generate_until", (context_, gen_kwargs), generated_text
)
pbar.update(1)
@@ -652,9 +654,9 @@ class VLLM(TemplateLM):
def _loglikelihood_tokens(
self,
-requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
+requests: list[tuple[tuple[str, str], list[int], list[int]]],
disable_tqdm: bool = False,
-) -> List[Tuple[float, bool]]:
+) -> list[tuple[float, bool]]:
res = []
def _collate(x):
@@ -675,7 +677,7 @@ class VLLM(TemplateLM):
for chunk in chunks:
inputs = []
ctxlens = []
-for cache_key, context_enc, continuation_enc in chunk:
+for _cache_key, context_enc, continuation_enc in chunk:
if (
full_length := len(context_enc + continuation_enc)
) > self.max_length:
@@ -713,7 +715,7 @@ class VLLM(TemplateLM):
return re_ord.get_original(res)
@staticmethod
-def _parse_logprobs(tokens: List, outputs, ctxlen: int) -> Tuple[float, bool]:
+def _parse_logprobs(tokens: list, outputs, ctxlen: int) -> tuple[float, bool]:
"""Process logprobs and tokens.
:param tokens: list
......
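Beyond the annotation changes, one small behavioral tweak is visible in the hunk around line 456: the re-raise after detecting dead worker processes becomes `raise RuntimeError(...) from None`, which suppresses Python's implicit exception chaining when the raise happens inside an exception handler (the module imports `queue.Empty`, presumably for polling the result queue). A self-contained sketch of that idiom, using a toy `pop_result` helper that is not part of the module:

# Toy illustration of `raise ... from None`: it hides the implicit
# "During handling of the above exception, another exception occurred"
# context when re-raising from inside an except block.
from queue import Empty, Queue


def pop_result(q: Queue) -> object:
    try:
        return q.get_nowait()
    except Empty:
        # Without `from None`, the traceback would also show the Empty error.
        raise RuntimeError("worker processes died unexpectedly") from None


try:
    pop_result(Queue())
except RuntimeError as err:
    print(err)  # worker processes died unexpectedly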