Commit 6d63c2ce authored by Baber

types

parent 0087929e
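
The commit message "types" refers to modernizing the type annotations: the diff below adds `from __future__ import annotations` and replaces `typing.Optional`, `Union`, `List`, `Dict`, and `Tuple` with PEP 604 `X | Y` unions and PEP 585 builtin generics. A minimal sketch of the pattern follows; the function and parameter names are illustrative, not taken from the file.

# Illustrative only - not code from this repository.
# With `from __future__ import annotations` (PEP 563), annotations are stored as
# strings, so PEP 604 unions (`str | None`) and PEP 585 builtin generics
# (`list[int]`, `dict[str, str]`) work in signatures even on Python 3.8/3.9,
# and the typing.Optional/Union/List/Dict/Tuple imports can be dropped.
from __future__ import annotations


def encode(text: str | list[str], max_len: int | None = None) -> list[int] | list[list[int]]:
    """Hypothetical helper mirroring the annotation style adopted in the diff."""
    items = [text] if isinstance(text, str) else text
    encoded = [[ord(ch) for ch in item][:max_len] for item in items]
    return encoded[0] if isinstance(text, str) else encoded

Only annotations get this treatment; anything evaluated at runtime (for example an `isinstance` check) still needs syntax that is valid on the interpreter in use, which is why the diff only touches annotations.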
+from __future__ import annotations
+
 import copy
 import gc
 import inspect
@@ -8,7 +10,7 @@ from importlib.util import find_spec
 from multiprocessing import Process, Queue
 from queue import Empty
 from time import sleep
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Literal
 
 import jinja2
 from more_itertools import distribute
@@ -51,10 +53,10 @@ eval_logger = logging.getLogger(__name__)
 
 def _vllm_mp_worker(
     model_args: dict,
-    sampling_params: "SamplingParams",
+    sampling_params: SamplingParams,
     requests: list[list[int]],
-    lora_request: "LoRARequest",
-    result_queue: "Queue",
+    lora_request: LoRARequest,
+    result_queue: Queue,
     dp_size: int,
     local_dp_rank: int,
     dp_master_port: int,
@@ -114,30 +116,30 @@ class VLLM(TemplateLM):
         self,
         pretrained: str,
         dtype: Literal["float16", "bfloat16", "float32", "auto"] = "auto",
-        revision: Optional[str] = None,
-        trust_remote_code: Optional[bool] = False,
-        tokenizer: Optional[str] = None,
+        revision: str | None = None,
+        trust_remote_code: bool | None = False,
+        tokenizer: str | None = None,
         tokenizer_mode: Literal["auto", "slow"] = "auto",
-        tokenizer_revision: Optional[str] = None,
-        add_bos_token: Optional[bool] = False,
-        prefix_token_id: Optional[int] = None,
+        tokenizer_revision: str | None = None,
+        add_bos_token: bool | None = False,
+        prefix_token_id: int | None = None,
         tensor_parallel_size: int = 1,
-        quantization: Optional[str] = None,
+        quantization: str | None = None,
         max_gen_toks: int = 256,
         swap_space: int = 4,
-        batch_size: Union[str, int] = 1,
-        max_batch_size=None,
-        max_length: int = None,
-        max_model_len: int = None,
+        batch_size: str | int = 1,
+        max_batch_size: int | None = None,
+        max_length: int | None = None,
+        max_model_len: int | None = None,
         seed: int = 1234,
         gpu_memory_utilization: float = 0.9,
         device: str = "cuda",
         data_parallel_size: int = 1,
-        lora_local_path: str = None,
+        lora_local_path: str | None = None,
         # VLLM: enable thinking tags in the prompt.
         enable_thinking: bool = True,
         # End marker for thinking tags - splits to get response after this token (if provided).
-        think_end_token: Optional[str] = None,
+        think_end_token: str | None = None,
         max_lora_rank: int = 16,
         **kwargs,
     ):
@@ -173,7 +175,7 @@ class VLLM(TemplateLM):
            "quantization": quantization,
            "seed": int(seed),
            "device": str(device),
-            "enable_lora": True if lora_local_path else False,
+            "enable_lora": bool(lora_local_path),
            "max_lora_rank": int(max_lora_rank),
        }
        self.model_args.update(kwargs)
@@ -304,7 +306,7 @@ class VLLM(TemplateLM):
        return self._max_gen_toks
 
    def apply_chat_template(
-        self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
+        self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
    ) -> str:
        """
        Method to apply a chat template to a list of chat history between user and model.
@@ -339,14 +341,14 @@ class VLLM(TemplateLM):
 
    def tok_encode(
        self,
-        string: Union[str, List[str]],
+        string: str | list[str],
        left_truncate_len: int = None,
        add_special_tokens: bool = False,
        truncation: bool = False,
-    ) -> Union[List[int], List[List[int]]]:
+    ) -> list[int] | list[list[int]]:
        if not add_special_tokens:
            add_special_tokens = False or self.add_bos_token
-        encoding: Union[List[List[int]], List[int]] = self.tokenizer(
+        encoding: list[list[int]] | list[int] = self.tokenizer(
            string,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
@@ -364,10 +366,10 @@ class VLLM(TemplateLM):
 
    def _model_generate(
        self,
-        requests: List[List[int]] = None,
+        requests: list[list[int]] = None,
        generate: bool = False,
        max_tokens: int = None,
-        stop: Optional[List[str]] = None,
+        stop: list[str] | None = None,
        **kwargs,
    ):
        if generate:
@@ -385,7 +387,7 @@ class VLLM(TemplateLM):
            def run_inference_one_model(
                model_args: dict,
                sampling_params: SamplingParams,
-                requests: List[List[int]],
+                requests: list[list[int]],
                lora_request: LoRARequest,
            ):
                llm = LLM(**model_args)
@@ -454,7 +456,7 @@ class VLLM(TemplateLM):
                    if dead_procs:
                        raise RuntimeError(
                            f"Worker processes {dead_procs} died unexpectedly"
-                        )
+                        ) from None
                    continue
            results = [rank_res[i] for i in range(len(procs))]
@@ -481,14 +483,14 @@ class VLLM(TemplateLM):
        outputs = self.model.generate(
            prompt_token_ids=requests,
            sampling_params=sampling_params,
-            use_tqdm=True if self.batch_size == "auto" else False,
+            use_tqdm=self.batch_size == "auto",
            lora_request=self.lora_request,
        )
        return outputs
 
    def loglikelihood_rolling(
-        self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> List[float]:
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[float]:
        adaptive_batch_size = None
        if self.batch_size == "auto":
            adaptive_batch_size = len(requests)
@@ -503,7 +505,7 @@ class VLLM(TemplateLM):
                disable=(disable_tqdm or (self.rank != 0)),
            )
        ):
-            rolling_token_windows: List[Tuple[List[int], List[int]]] = list(
+            rolling_token_windows: list[tuple[list[int], list[int]]] = list(
                map(
                    make_disjoint_window,
                    get_rolling_token_windows(
@@ -556,13 +558,13 @@ class VLLM(TemplateLM):
        return loglikelihoods
 
    def generate_until(
-        self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> List[str]:
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[str]:
        res = []
 
        # batch tokenize contexts
        context, all_gen_kwargs = zip(*(req.args for req in requests))
-        context_encoding: List[List[int]] = self.tok_encode(
+        context_encoding: list[list[int]] = self.tok_encode(
            context, add_special_tokens=self.add_bos_token
        )
        requests = [
@@ -608,7 +610,7 @@ class VLLM(TemplateLM):
                raise ValueError(
                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                )
-            if "max_gen_toks" in kwargs.keys():
+            if "max_gen_toks" in kwargs:
                max_gen_toks = kwargs.pop("max_gen_toks")
            else:
                max_gen_toks = self.max_gen_toks
@@ -634,7 +636,7 @@ class VLLM(TemplateLM):
            )
 
            # cache generations
-            for output, context in zip(cont, context):
+            for output, context_ in zip(cont, context):
                generated_text: str = output.outputs[0].text
                # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
                generated_text = postprocess_generated_text(
@@ -642,7 +644,7 @@ class VLLM(TemplateLM):
                )
                res.append(generated_text)
                self.cache_hook.add_partial(
-                    "generate_until", (context, gen_kwargs), generated_text
+                    "generate_until", (context_, gen_kwargs), generated_text
                )
                pbar.update(1)
@@ -652,9 +654,9 @@ class VLLM(TemplateLM):
 
    def _loglikelihood_tokens(
        self,
-        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
+        requests: list[tuple[tuple[str, str], list[int], list[int]]],
        disable_tqdm: bool = False,
-    ) -> List[Tuple[float, bool]]:
+    ) -> list[tuple[float, bool]]:
        res = []
 
        def _collate(x):
@@ -675,7 +677,7 @@ class VLLM(TemplateLM):
        for chunk in chunks:
            inputs = []
            ctxlens = []
-            for cache_key, context_enc, continuation_enc in chunk:
+            for _cache_key, context_enc, continuation_enc in chunk:
                if (
                    full_length := len(context_enc + continuation_enc)
                ) > self.max_length:
@@ -713,7 +715,7 @@ class VLLM(TemplateLM):
        return re_ord.get_original(res)
 
    @staticmethod
-    def _parse_logprobs(tokens: List, outputs, ctxlen: int) -> Tuple[float, bool]:
+    def _parse_logprobs(tokens: list, outputs, ctxlen: int) -> tuple[float, bool]:
        """Process logprobs and tokens.
 
        :param tokens: list
...
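
For context on the `result_queue` / `dead_procs` portion of the diff, here is a hedged, self-contained sketch of the same data-parallel collection pattern. The worker name, payloads, and the 5-second timeout are placeholders, not values from the repository: each worker process puts a `(rank, result)` pair on a `Queue`, and the parent polls with a timeout so it can detect ranks that exited without reporting and raise, rather than blocking forever.

# Hedged sketch of the worker/result-queue pattern - not code from this repository.
from __future__ import annotations

from multiprocessing import Process, Queue
from queue import Empty


def worker(rank: int, payload: list[int], result_queue: Queue) -> None:
    # Stand-in for running inference on this data-parallel rank.
    result_queue.put((rank, sum(payload)))


if __name__ == "__main__":
    result_queue: Queue = Queue()
    procs = [Process(target=worker, args=(r, [r, r + 1], result_queue)) for r in range(2)]
    for p in procs:
        p.start()

    rank_res: dict[int, int] = {}
    while len(rank_res) < len(procs):
        try:
            rank, res = result_queue.get(timeout=5)
            rank_res[rank] = res
        except Empty:
            # A timeout is only a polling tick; check whether any rank died
            # without reporting before deciding the run has failed.
            dead_procs = [
                i for i, p in enumerate(procs) if not p.is_alive() and i not in rank_res
            ]
            if dead_procs:
                # `from None` (as in the diff) drops the queue.Empty traceback,
                # which is irrelevant to the real failure.
                raise RuntimeError(f"Worker processes {dead_procs} died unexpectedly") from None
            continue

    results = [rank_res[i] for i in range(len(procs))]
    for p in procs:
        p.join()
    print(results)

The `from None` added in the commit matters because the `Empty` exception is just the polling timeout firing; chaining it would bury the actual problem (a dead worker) under an unrelated traceback.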