Commit 223b9488 authored by Baber

types

parent 7cef4d38
+from __future__ import annotations
 import abc
 import asyncio
 import copy
@@ -8,16 +10,9 @@ from functools import cached_property
 from typing import (
     TYPE_CHECKING,
     Any,
-    Awaitable,
     Callable,
-    Dict,
-    Iterable,
-    List,
     Literal,
     NamedTuple,
-    Optional,
-    Tuple,
-    Union,
 )
@@ -36,18 +31,21 @@ from importlib.util import find_spec
 from io import BytesIO
 from lm_eval import utils
-from lm_eval.api.instance import Instance
 from lm_eval.api.model import TemplateLM
 from lm_eval.models.utils import Collator, chunks, configure_pad_token
 if TYPE_CHECKING:
+    from collections.abc import Awaitable, Iterable
     from PIL import Image
+    from lm_eval.api.instance import Instance
 eval_logger = logging.getLogger(__name__)
-LogLikelihoodInputs = Tuple[Tuple[str, str], List[int], List[int]]
+LogLikelihoodInputs = tuple[tuple[str, str], list[int], list[int]]
 # utility class to keep track of json encoded chats
@@ -58,9 +56,7 @@ class JsonChatStr(NamedTuple):
         return self.prompt.encode(encoding)
-def create_image_prompt(
-    imgs: list["Image.Image"], chat: dict, fmt: str = "PNG"
-) -> dict:
+def create_image_prompt(imgs: list[Image.Image], chat: dict, fmt: str = "PNG") -> dict:
     """
     Parameters
@@ -109,33 +105,32 @@ class TemplateAPI(TemplateLM):
         model: str = None,
         pretrained: str = None,  # `model` takes precedence over `pretrained` when passed.
         base_url: str = None,
-        tokenizer: Optional[str] = None,
+        tokenizer: str | None = None,
         # Loglikelihood tasks require a tokenizer to calculate context lengths,
         # however the requests can be sent as a string if the API doesn't support token inputs.
         # use tokenized_requests=False
-        tokenizer_backend: Optional[
-            Literal["tiktoken", "huggingface", "None", "none"]
-        ] = "huggingface",
+        tokenizer_backend: Literal["tiktoken", "huggingface", "None", "none"]
+        | None = "huggingface",
         truncate: bool = False,
         # number of concurrent requests. More useful if not batching
         num_concurrent: int = 1,
         max_retries: int = 3,
         max_gen_toks: int = 256,
-        batch_size: Union[str, int] = 1,
+        batch_size: str | int = 1,
         seed: int = 1234,
-        max_length: Optional[int] = 2048,
+        max_length: int | None = 2048,
         add_bos_token: bool = False,
         custom_prefix_token_id: int = None,
         # send the requests as tokens or strings
         tokenized_requests: bool = True,
         trust_remote_code: bool = False,
-        revision: Optional[str] = "main",
+        revision: str | None = "main",
         use_fast_tokenizer: bool = True,
         verify_certificate: bool = True,
         eos_string: str = None,
         # timeout in seconds
         timeout: int = 300,
-        header: Optional[Dict[str, str]] = None,
+        header: dict[str, str] | None = None,
         max_images: int = 1,
         **kwargs,
     ) -> None:
@@ -232,12 +227,12 @@ class TemplateAPI(TemplateLM):
     @abc.abstractmethod
     def _create_payload(
         self,
-        messages: Union[List[List[int]], List[dict], List[str], str],
+        messages: list[list[int]] | list[dict] | list[str] | str,
         *,
         generate: bool = True,
-        gen_kwargs: Optional[dict] = None,
+        gen_kwargs: dict | None = None,
         seed: int = 1234,
-        eos: str = None,
+        eos: str | None = None,
         **kwargs,
     ) -> dict:
         """This method is responsible for creating the json payload that will be sent to the API."""
@@ -245,9 +240,9 @@ class TemplateAPI(TemplateLM):
     def create_message(
         self,
-        messages: Union[List[List[int]], List[str], List[JsonChatStr]],
+        messages: list[list[int]] | list[str] | list[JsonChatStr],
         generate=False,
-    ) -> Union[List[List[int]], List[dict], List[str], str]:
+    ) -> list[list[int]] | list[dict] | list[str] | str:
         """Helper method to transform the prompt into the expected API input format. messages consist of batched requests"""
         if isinstance(messages[0], JsonChatStr):
             # for chat completions we need to decode the json string to list[dict,...]
@@ -276,17 +271,17 @@ class TemplateAPI(TemplateLM):
     @staticmethod
     @abc.abstractmethod
     def parse_logprobs(
-        outputs: Union[Any, List[Any]],
-        tokens: List[List[int]] = None,
-        ctxlen: List[int] = None,
+        outputs: Any | list[Any],
+        tokens: list[list[int]] | None = None,
+        ctxlen: list[int] | None = None,
         **kwargs,
-    ) -> List[Tuple[float, bool]]:
+    ) -> list[tuple[float, bool]]:
         """Method used to parse the logprobs from the (batched) API response. This method should return a list of tuples"""
         raise NotImplementedError
     @staticmethod
     @abc.abstractmethod
-    def parse_generations(outputs: Union[Any, List[Any]], **kwargs) -> List[str]:
+    def parse_generations(outputs: Any | list[Any], **kwargs) -> list[str]:
         """Method used to parse the generations from the (batched) API response. This method should return a list of str"""
         raise NotImplementedError
@@ -303,14 +298,15 @@ class TemplateAPI(TemplateLM):
     @property
     def tokenizer_name(self) -> str:
         """Must be defined for LM subclasses which implement Chat Templating.
         Should return the name of the tokenizer or chat template used.
         Used only to properly fingerprint caches when requests are being cached with `--cache_requests`, otherwise not used.
         """
         return ""
     def apply_chat_template(
-        self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
-    ) -> Union[str, JsonChatStr]:
+        self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
+    ) -> str | JsonChatStr:
         """Applies a chat template to a list of chat history between user and model."""
         if self.tokenizer_backend == "huggingface" and self.tokenized_requests:
             return self.tokenizer.apply_chat_template(
@@ -319,33 +315,32 @@ class TemplateAPI(TemplateLM):
                 add_generation_prompt=add_generation_prompt,
                 continue_final_message=not add_generation_prompt,
             )
-        else:
-            # bit of a hack. We'll load back before sending to the API
-            return JsonChatStr(
-                json.dumps(
-                    [{**item, "type": "text"} for item in chat_history],
-                    ensure_ascii=False,
-                )
-            )
+        # bit of a hack. We'll load back before sending to the API
+        return JsonChatStr(
+            json.dumps(
+                [{**item, "type": "text"} for item in chat_history],
+                ensure_ascii=False,
+            )
+        )
     @cached_property
-    def eot_token_id(self) -> Optional[int]:
+    def eot_token_id(self) -> int | None:
         if self.tokenizer is None:
             return None
         else:
             if self.tokenizer_backend == "huggingface":
                 return self.tokenizer.eos_token_id
-            elif self.tokenizer_backend == "tiktoken":
+            if self.tokenizer_backend == "tiktoken":
                 return self.tokenizer.eot_token
     @cached_property
-    def eos_string(self) -> Optional[str]:
+    def eos_string(self) -> str | None:
         if self._eos_string:
             return self._eos_string
-        elif self.tokenizer is not None:
+        if self.tokenizer is not None:
             if self.tokenizer_backend == "huggingface":
                 return self.tokenizer.eos_token
-            elif self.tokenizer_backend == "tiktoken":
+            if self.tokenizer_backend == "tiktoken":
                 return self.tokenizer.decode([self.tokenizer.eot_token])
         else:
             eval_logger.warning(
@@ -354,7 +349,7 @@ class TemplateAPI(TemplateLM):
             return None
     @cached_property
-    def prefix_token_id(self) -> Optional[int]:
+    def prefix_token_id(self) -> int | None:
         if self.tokenizer is None:
             return None
         else:
@@ -364,24 +359,24 @@ class TemplateAPI(TemplateLM):
                 if self.tokenizer.bos_token_id is not None:
                     return self.tokenizer.bos_token_id
                 return self.tokenizer.eos_token_id
-            else:
-                return self.tokenizer.eot_token
+            return self.tokenizer.eot_token
     def tok_encode(
         self,
         string: str,
-        left_truncate_len: int = None,
+        left_truncate_len: int | None = None,
        add_special_tokens: bool = False,
        truncation: bool = False,
        **kwargs,
-    ) -> Union[List[List[int]], List[int], List[str]]:
+    ) -> list[list[int]] | list[int] | list[str]:
         if self.tokenizer_backend is None:
             return [string]
-        elif self.tokenizer_backend == "huggingface":
+        if self.tokenizer_backend == "huggingface":
             # by default for CausalLM - false or self.add_bos_token is set
             if not add_special_tokens:
                 add_special_tokens = False or self.add_bos_token
-            encoding: Union[List[List[int]], List[int]] = self.tokenizer(
+            encoding: list[list[int]] | list[int] = self.tokenizer(
                 string,
                 add_special_tokens=add_special_tokens,
                 truncation=truncation,
@@ -404,20 +399,20 @@ class TemplateAPI(TemplateLM):
             encoding = self.tokenizer.encode_batch(string)
             return encoding
-    def decode_batch(self, tokens: List[List[int]]) -> List[str]:
+    def decode_batch(self, tokens: list[list[int]]) -> list[str] | None:
         if self.tokenizer_backend == "huggingface":
             return self.tokenizer.batch_decode(tokens)
-        elif self.tokenizer_backend == "tiktoken":
+        if self.tokenizer_backend == "tiktoken":
             return self.tokenizer.decode_batch(tokens)
     def model_call(
         self,
-        messages: Union[List[List[int]], List[str], List[JsonChatStr]],
+        messages: list[list[int]] | list[str] | list[JsonChatStr],
         *,
         generate: bool = True,
-        gen_kwargs: Optional[Dict] = None,
+        gen_kwargs: dict | None = None,
         **kwargs,
-    ) -> Optional[dict]:
+    ) -> dict | None:
         # !!! Copy: shared dict for each request, need new object !!!
         gen_kwargs = copy.deepcopy(gen_kwargs)
         try:
@@ -441,7 +436,7 @@ class TemplateAPI(TemplateLM):
             response.raise_for_status()
             return response.json()
         except RetryError:
-            eval_logger.error(
+            eval_logger.exception(
                 "API request failed after multiple retries. Please check the API status."
             )
             return None
@@ -450,14 +445,14 @@ class TemplateAPI(TemplateLM):
         self,
         session: ClientSession,
         sem: asyncio.Semaphore,
-        messages: Union[List[List[int]], List[str], List[JsonChatStr]],
+        messages: list[list[int]] | list[str] | list[JsonChatStr],
         *,
         generate: bool = True,
-        cache_keys: list = None,
-        ctxlens: Optional[List[int]] = None,
-        gen_kwargs: Optional[Dict] = None,
+        cache_keys: list | None = None,
+        ctxlens: list[int] | None = None,
+        gen_kwargs: dict | None = None,
         **kwargs,
-    ) -> Union[List[str], List[Tuple[float, bool]], None]:
+    ) -> list[str] | list[tuple[float, bool]] | None:
         # !!! Copy: shared dict for each request, need new object !!!
         gen_kwargs = copy.deepcopy(gen_kwargs)
         payload = self._create_payload(
@@ -508,8 +503,8 @@ class TemplateAPI(TemplateLM):
             sem.release()
     def batch_loglikelihood_requests(
-        self, chunks: Iterable[List[LogLikelihoodInputs]]
-    ) -> Tuple[List[List[int]], List[int], List[Tuple[str, str]]]:
+        self, chunks: Iterable[list[LogLikelihoodInputs]]
+    ) -> tuple[list[list[int]], list[int], list[tuple[str, str]]]:
         inputs = []
         ctxlens = []
         cache_keys = []
@@ -536,9 +531,9 @@ class TemplateAPI(TemplateLM):
         cache_keys: list,
         *,
         generate: bool = True,
-        ctxlens: List[int] = None,
+        ctxlens: list[int] | None = None,
         **kwargs,
-    ) -> Union[List[List[str]], List[List[Tuple[float, bool]]]]:
+    ) -> list[list[str]] | list[list[tuple[float, bool]]]:
         ctxlens = ctxlens if ctxlens else [None] * len(requests)
         conn = TCPConnector(limit=self._concurrent, ssl=self.verify_certificate)
         sem = asyncio.Semaphore(self._concurrent)
@@ -575,14 +570,14 @@ class TemplateAPI(TemplateLM):
            return await tqdm_asyncio.gather(*tasks, desc="Requesting API")
-    def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
+    def _loglikelihood_tokens(self, requests, **kwargs) -> list[tuple[float, bool]]:
         assert self.tokenizer is not None, (
             "Tokenizer is required for loglikelihood tasks to compute context lengths."
         )
         res = []
         def _collate(req: LogLikelihoodInputs):
-            """Defines the key for the sorted method"""
+            """Defines the key for the sorted method."""
             # the negative sign on len(toks) sorts descending - this has a few advantages:
             # - time estimates will always be over not underestimates, which is more useful for planning
             # - to know the size of a batch when going through the list, you know the first one is always the batch
@@ -639,8 +634,8 @@ class TemplateAPI(TemplateLM):
         return re_ord.get_original(res)
     def generate_until(
-        self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> List[str]:
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[str]:
         res = []
         def _collate_gen(_requests):
@@ -773,8 +768,8 @@ class TemplateAPI(TemplateLM):
         return re_ord.get_original(res)
     def loglikelihood_rolling(
-        self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> List[float]:
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[float]:
         loglikelihoods = []
         for (string,) in tqdm([req.args for req in requests], disable=disable_tqdm):
...
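The pattern applied throughout this file follows PEP 585/604, enabled by the new `from __future__ import annotations` line: builtin generics (`list`, `dict`, `tuple`) and `X | Y` unions replace `typing.List`, `typing.Dict`, `typing.Tuple`, `typing.Optional`, and `typing.Union`, and typing-only imports move under `TYPE_CHECKING`. A minimal sketch of the before and after, using a hypothetical helper rather than code from this module:

from __future__ import annotations  # annotations are stored as strings, so the new syntax also works on older interpreters

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # imported only for type checkers; no runtime import cost
    from collections.abc import Iterable


# Old style:  def score(pairs: Optional[List[Tuple[str, str]]] = None) -> Dict[str, float]:
# New style, as adopted in this commit:
def score(pairs: list[tuple[str, str]] | None = None) -> dict[str, float]:
    """Hypothetical helper, not part of lm-eval; it only illustrates the annotation style."""
    return {} if pairs is None else {choice: 0.0 for choice, _ in pairs}


def mean_length(lengths: Iterable[int]) -> float:
    values = list(lengths)
    return sum(values) / len(values) if values else 0.0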
@@ -682,8 +682,7 @@ class HFLM(TemplateLM):
         )
         if peft:
-            from peft import PeftModel
-            from peft import __version__ as PEFT_VERSION
+            from peft import PeftModel, __version__ as PEFT_VERSION
             if model_kwargs.get("load_in_4bit") and vparse(PEFT_VERSION) < vparse(
                 "0.4.0"
...
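This hunk only merges the two `peft` imports into one statement; the surrounding version gate is unchanged. A self-contained sketch of that gate, assuming `vparse` is `packaging.version.parse` (however the file actually aliases it) and assuming the collapsed body raises on an unsupported combination:

from packaging.version import parse as vparse

PEFT_VERSION = "0.3.0"   # illustrative value; the real code reads peft.__version__
load_in_4bit = True      # stand-in for model_kwargs.get("load_in_4bit")

if load_in_4bit and vparse(PEFT_VERSION) < vparse("0.4.0"):
    # assumed outcome; the body of this `if` is collapsed in the diff above
    raise AssertionError("load_in_4bit requires peft >= 0.4.0")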
+from __future__ import annotations
 import logging
 import os
 from functools import cached_property
 from operator import itemgetter
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any
 from lm_eval.api.registry import register_model
 from lm_eval.models.api_models import TemplateAPI
@@ -26,9 +28,9 @@ class LocalCompletionsAPI(TemplateAPI):
     def _create_payload(
         self,
-        messages: Union[List[List[int]], List[dict], List[str], str],
+        messages: list[list[int]] | list[dict] | list[str] | str,
         generate=False,
-        gen_kwargs: Optional[dict] = None,
+        gen_kwargs: dict | None = None,
         seed: int = 1234,
         eos=None,
         **kwargs,
@@ -50,24 +52,23 @@ class LocalCompletionsAPI(TemplateAPI):
                 "seed": seed,
                 **gen_kwargs,
             }
-        else:
-            return {
-                "model": self.model,
-                "prompt": messages,
-                "temperature": 0,
-                "max_tokens": 1,
-                "logprobs": 1,
-                "seed": seed,
-                "echo": True,
-            }
+        return {
+            "model": self.model,
+            "prompt": messages,
+            "temperature": 0,
+            "max_tokens": 1,
+            "logprobs": 1,
+            "seed": seed,
+            "echo": True,
+        }
     @staticmethod
     def parse_logprobs(
-        outputs: Union[Dict, List[Dict]],
-        tokens: List[List[int]] = None,
-        ctxlens: List[int] = None,
+        outputs: dict | list[dict],
+        tokens: list[list[int]] = None,
+        ctxlens: list[int] = None,
         **kwargs,
-    ) -> List[Tuple[float, bool]]:
+    ) -> list[tuple[float, bool]]:
         res = []
         if not isinstance(outputs, list):
             outputs = [outputs]
@@ -88,7 +89,7 @@ class LocalCompletionsAPI(TemplateAPI):
         return res
     @staticmethod
-    def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]:
+    def parse_generations(outputs: dict | list[dict], **kwargs) -> list[str]:
         res = []
         if not isinstance(outputs, list):
             outputs = [outputs]
@@ -130,9 +131,9 @@ class LocalChatCompletion(LocalCompletionsAPI):
     def _create_payload(
         self,
-        messages: List[Dict],
+        messages: list[dict],
         generate=False,
-        gen_kwargs: dict = None,
+        gen_kwargs: dict | None = None,
         seed=1234,
         eos=None,
         **kwargs,
@@ -160,7 +161,7 @@ class LocalChatCompletion(LocalCompletionsAPI):
         }
     @staticmethod
-    def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]:
+    def parse_generations(outputs: dict | list[dict], **kwargs) -> list[str]:
         res = []
         if not isinstance(outputs, list):
             outputs = [outputs]
@@ -173,11 +174,11 @@ class LocalChatCompletion(LocalCompletionsAPI):
     def tok_encode(
         self,
-        string: Union[str, Any],
+        string: str | Any,
         left_truncate_len=None,
         add_special_tokens=None,
         **kwargs,
-    ) -> Union[List[str], List[int], Any]:
+    ) -> list[str] | list[int] | Any:
         return string
     def loglikelihood(self, requests, **kwargs):
@@ -219,7 +220,7 @@ class OpenAICompletionsAPI(LocalCompletionsAPI):
         )
         return super().loglikelihood(requests, **kwargs)
-    def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
+    def chat_template(self, chat_template: bool | str = False) -> str | None:
         return ""
@@ -261,7 +262,7 @@ class OpenAIChatCompletion(LocalChatCompletion):
     def _create_payload(
         self,
-        messages: List[Dict],
+        messages: list[dict],
         generate=False,
         gen_kwargs: dict = None,
         seed=1234,
...
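The de-indented `return` above is the loglikelihood branch of `_create_payload`: when not generating, the payload asks a completions-style endpoint to echo the prompt with per-token logprobs instead of sampling new text. A hedged sketch of that payload shape, with field names taken from the diff and a placeholder model name:

from __future__ import annotations


def loglikelihood_payload(prompt: list[int] | str, model: str = "my-model", seed: int = 1234) -> dict:
    # `model` is a placeholder; the real method uses self.model
    return {
        "model": model,
        "prompt": prompt,        # tokenized or plain-text context + continuation
        "temperature": 0,
        "max_tokens": 1,         # no new text is needed, only the echoed prompt
        "logprobs": 1,
        "seed": seed,
        "echo": True,            # return logprobs for the prompt tokens themselves
    }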
+from __future__ import annotations
 import copy
 import gc
 import logging
@@ -7,7 +9,7 @@ from importlib.util import find_spec
 from multiprocessing import Process, Queue
 from queue import Empty
 from time import sleep
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Literal
 import jinja2
 from more_itertools import distribute
@@ -113,30 +115,30 @@ class VLLM(TemplateLM):
         self,
         pretrained: str,
         dtype: Literal["float16", "bfloat16", "float32", "auto"] = "auto",
-        revision: Optional[str] = None,
-        trust_remote_code: Optional[bool] = False,
-        tokenizer: Optional[str] = None,
+        revision: str | None = None,
+        trust_remote_code: bool | None = False,
+        tokenizer: str | None = None,
         tokenizer_mode: Literal["auto", "slow"] = "auto",
-        tokenizer_revision: Optional[str] = None,
-        add_bos_token: Optional[bool] = False,
-        prefix_token_id: Optional[int] = None,
+        tokenizer_revision: str | None = None,
+        add_bos_token: bool | None = False,
+        prefix_token_id: int | None = None,
         tensor_parallel_size: int = 1,
-        quantization: Optional[str] = None,
+        quantization: str | None = None,
         max_gen_toks: int = 256,
         swap_space: int = 4,
-        batch_size: Union[str, int] = 1,
-        max_batch_size=None,
-        max_length: int = None,
-        max_model_len: int = None,
+        batch_size: str | int = 1,
+        max_batch_size: int | None = None,
+        max_length: int | None = None,
+        max_model_len: int | None = None,
         seed: int = 1234,
         gpu_memory_utilization: float = 0.9,
         data_parallel_size: int = 1,
-        lora_local_path: str = None,
+        lora_local_path: str | None = None,
         # VLLM: enable thinking tags in the prompt.
         enable_thinking: bool = True,
         chat_template_args: Optional[dict] = None,
         # End marker for thinking tags - splits to get response after this token (if provided).
-        think_end_token: Optional[str] = None,
+        think_end_token: str | None = None,
         max_lora_rank: int = 16,
         **kwargs,
     ):
@@ -172,7 +174,7 @@ class VLLM(TemplateLM):
             "swap_space": int(swap_space),
             "quantization": quantization,
             "seed": int(seed),
-            "enable_lora": True if lora_local_path else False,
+            "enable_lora": bool(lora_local_path),
             "max_lora_rank": int(max_lora_rank),
         }
         self.model_args.update(kwargs)
@@ -300,7 +302,7 @@ class VLLM(TemplateLM):
         return self._max_gen_toks
     def apply_chat_template(
-        self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
+        self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
     ) -> str:
         """
         Method to apply a chat template to a list of chat history between user and model.
@@ -337,14 +339,14 @@ class VLLM(TemplateLM):
     def tok_encode(
         self,
-        string: Union[str, List[str]],
+        string: str | list[str],
         left_truncate_len: int = None,
         add_special_tokens: bool = False,
         truncation: bool = False,
-    ) -> Union[List[int], List[List[int]]]:
+    ) -> list[int] | list[list[int]]:
         if not add_special_tokens:
             add_special_tokens = False or self.add_bos_token
-        encoding: Union[List[List[int]], List[int]] = self.tokenizer(
+        encoding: list[list[int]] | list[int] = self.tokenizer(
             string,
             add_special_tokens=add_special_tokens,
             truncation=truncation,
@@ -362,7 +364,7 @@ class VLLM(TemplateLM):
     def _model_generate(
         self,
-        requests: List[List[int]] = None,
+        requests: list[list[int]] = None,
         generate: bool = False,
         sampling_params: Union[List["SamplingParams"], "SamplingParams", None] = None,
     ):
@@ -379,8 +381,8 @@ class VLLM(TemplateLM):
         @ray.remote
         def run_inference_one_model(
             model_args: dict,
-            sampling_params: List["SamplingParams"],
-            requests: List[List[int]],
+            sampling_params: list["SamplingParams"],
+            requests: list[list[int]],
             lora_request: "LoRARequest",
         ):
             llm = LLM(**model_args)
@@ -454,7 +456,7 @@ class VLLM(TemplateLM):
                 if dead_procs:
                     raise RuntimeError(
                         f"Worker processes {dead_procs} died unexpectedly"
-                    )
+                    ) from None
                 continue
             results = [rank_res[i] for i in range(len(procs))]
@@ -481,14 +483,14 @@ class VLLM(TemplateLM):
         outputs = self.model.generate(
             [TokensPrompt(prompt_token_ids=request) for request in requests],
             sampling_params=sampling_params,
-            use_tqdm=True if self.batch_size == "auto" else False,
+            use_tqdm=self.batch_size == "auto",
             lora_request=self.lora_request,
         )
         return outputs
     def loglikelihood_rolling(
-        self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> List[float]:
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[float]:
         adaptive_batch_size = None
         if self.batch_size == "auto":
             adaptive_batch_size = len(requests)
@@ -503,7 +505,7 @@ class VLLM(TemplateLM):
                 disable=(disable_tqdm or (self.rank != 0)),
             )
         ):
-            rolling_token_windows: List[Tuple[List[int], List[int]]] = list(
+            rolling_token_windows: list[tuple[list[int], list[int]]] = list(
                 map(
                     make_disjoint_window,
                     get_rolling_token_windows(
@@ -556,13 +558,13 @@ class VLLM(TemplateLM):
         return loglikelihoods
     def generate_until(
-        self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> List[str]:
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[str]:
         res = []
         # batch tokenize contexts
         context, all_gen_kwargs = zip(*(req.args for req in requests))
-        context_encoding: List[List[int]] = self.tok_encode(
+        context_encoding: list[list[int]] = self.tok_encode(
             context, add_special_tokens=self.add_bos_token
         )
         requests = [
@@ -638,7 +640,7 @@ class VLLM(TemplateLM):
            )
            # cache generations
-            for output, context in zip(cont, context):
+            for output, context_ in zip(cont, context):
                generated_text: str = output.outputs[0].text
                # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
                generated_text = postprocess_generated_text(
@@ -646,7 +648,7 @@ class VLLM(TemplateLM):
                )
                res.append(generated_text)
                self.cache_hook.add_partial(
-                    "generate_until", (context, gen_kwargs), generated_text
+                    "generate_until", (context_, gen_kwargs), generated_text
                )
                pbar.update(1)
@@ -656,9 +658,9 @@ class VLLM(TemplateLM):
     def _loglikelihood_tokens(
         self,
-        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
+        requests: list[tuple[tuple[str, str], list[int], list[int]]],
         disable_tqdm: bool = False,
-    ) -> List[Tuple[float, bool]]:
+    ) -> list[tuple[float, bool]]:
         res = []
         def _collate(x):
@@ -679,7 +681,7 @@ class VLLM(TemplateLM):
         for chunk in chunks:
            inputs = []
            ctxlens = []
-            for cache_key, context_enc, continuation_enc in chunk:
+            for _cache_key, context_enc, continuation_enc in chunk:
                if (
                    full_length := len(context_enc + continuation_enc)
                ) > self.max_length:
@@ -717,7 +719,7 @@ class VLLM(TemplateLM):
         return re_ord.get_original(res)
     @staticmethod
-    def _parse_logprobs(tokens: List, outputs, ctxlen: int) -> Tuple[float, bool]:
+    def _parse_logprobs(tokens: list, outputs, ctxlen: int) -> tuple[float, bool]:
         """Process logprobs and tokens.
         :param tokens: list
...
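`_parse_logprobs` keeps its behavior here; only the annotations change and its body is collapsed in the diff. As a generic illustration, not this module's implementation, of what a parser with the `(tokens, outputs, ctxlen) -> tuple[float, bool]` shape computes: the summed logprob of the continuation tokens after `ctxlen`, plus whether each of them was the greedy choice.

from __future__ import annotations


def parse_logprobs_sketch(
    tokens: list[int],
    token_logprobs: list[float],   # logprob assigned to each realized token
    top_token_ids: list[int],      # argmax token id at each position
    ctxlen: int,
) -> tuple[float, bool]:
    # Score only the continuation portion; the context tokens are conditioning, not scored.
    continuation_logprob = sum(token_logprobs[ctxlen:])
    is_greedy = all(t == g for t, g in zip(tokens[ctxlen:], top_token_ids[ctxlen:]))
    return continuation_logprob, is_greedy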