Commit 1f97a945 authored by Baber

types

parent 0087929e
+from __future__ import annotations
 import abc
 import asyncio
 import copy
@@ -8,16 +10,9 @@ from functools import cached_property
 from typing import (
     TYPE_CHECKING,
     Any,
-    Awaitable,
     Callable,
-    Dict,
-    Iterable,
-    List,
     Literal,
     NamedTuple,
-    Optional,
-    Tuple,
-    Union,
 )
@@ -36,18 +31,21 @@ from importlib.util import find_spec
 from io import BytesIO
 from lm_eval import utils
-from lm_eval.api.instance import Instance
 from lm_eval.api.model import TemplateLM
 from lm_eval.models.utils import Collator, chunks, configure_pad_token
 if TYPE_CHECKING:
+    from collections.abc import Awaitable, Iterable
     from PIL import Image
+    from lm_eval.api.instance import Instance
 eval_logger = logging.getLogger(__name__)
-LogLikelihoodInputs = Tuple[Tuple[str, str], List[int], List[int]]
+LogLikelihoodInputs = tuple[tuple[str, str], list[int], list[int]]
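The pattern applied throughout this commit: with `from __future__ import annotations` (PEP 563), annotations are no longer evaluated at runtime, so PEP 604 unions (`str | None`) and built-in generics (`list[int]`, `dict[str, str]`) can replace `Optional`/`Union`/`List`/`Dict` in signatures, while runtime aliases such as `LogLikelihoodInputs` rely on Python 3.9+ builtin subscripting. A minimal, self-contained illustration (hypothetical names, not part of the diff):

```python
from __future__ import annotations  # annotations become lazy strings: new syntax is safe in signatures

# runtime alias: needs Python 3.9+ (builtins are subscriptable there)
LogLikelihoodInputs = tuple[tuple[str, str], list[int], list[int]]


def describe(request: LogLikelihoodInputs, seed: int | None = None) -> dict[str, int]:
    """Toy function showing built-in generics and `X | None` instead of typing.Optional/Union."""
    (_key, ctx_tokens, cont_tokens) = request
    return {"ctx_len": len(ctx_tokens), "cont_len": len(cont_tokens), "seed": seed or 0}


print(describe((("q", "a"), [1, 2, 3], [4]), seed=None))
```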
# utility class to keep track of json encoded chats # utility class to keep track of json encoded chats
...@@ -58,9 +56,7 @@ class JsonChatStr(NamedTuple): ...@@ -58,9 +56,7 @@ class JsonChatStr(NamedTuple):
return self.prompt.encode(encoding) return self.prompt.encode(encoding)
-def create_image_prompt(
-    imgs: list["Image.Image"], chat: dict, fmt: str = "PNG"
-) -> dict:
+def create_image_prompt(imgs: list[Image.Image], chat: dict, fmt: str = "PNG") -> dict:
""" """
Parameters Parameters
...@@ -109,33 +105,32 @@ class TemplateAPI(TemplateLM): ...@@ -109,33 +105,32 @@ class TemplateAPI(TemplateLM):
model: str = None, model: str = None,
pretrained: str = None, # `model` takes precedence over `pretrained` when passed. pretrained: str = None, # `model` takes precedence over `pretrained` when passed.
base_url: str = None, base_url: str = None,
tokenizer: Optional[str] = None, tokenizer: str | None = None,
# Loglikelihood tasks require a tokenizer to calculate context lengths, # Loglikelihood tasks require a tokenizer to calculate context lengths,
# however the requests can be sent as a string if the API doesn't support token inputs. # however the requests can be sent as a string if the API doesn't support token inputs.
# use tokenized_requests=False # use tokenized_requests=False
-        tokenizer_backend: Optional[
-            Literal["tiktoken", "huggingface", "None", "none"]
-        ] = "huggingface",
+        tokenizer_backend: Literal["tiktoken", "huggingface", "None", "none"]
+        | None = "huggingface",
truncate: bool = False, truncate: bool = False,
# number of concurrent requests. More useful if not batching # number of concurrent requests. More useful if not batching
num_concurrent: int = 1, num_concurrent: int = 1,
max_retries: int = 3, max_retries: int = 3,
max_gen_toks: int = 256, max_gen_toks: int = 256,
batch_size: Union[str, int] = 1, batch_size: str | int = 1,
seed: int = 1234, seed: int = 1234,
max_length: Optional[int] = 2048, max_length: int | None = 2048,
add_bos_token: bool = False, add_bos_token: bool = False,
custom_prefix_token_id: int = None, custom_prefix_token_id: int = None,
# send the requests as tokens or strings # send the requests as tokens or strings
tokenized_requests: bool = True, tokenized_requests: bool = True,
trust_remote_code: bool = False, trust_remote_code: bool = False,
revision: Optional[str] = "main", revision: str | None = "main",
use_fast_tokenizer: bool = True, use_fast_tokenizer: bool = True,
verify_certificate: bool = True, verify_certificate: bool = True,
eos_string: str = None, eos_string: str = None,
# timeout in seconds # timeout in seconds
timeout: int = 300, timeout: int = 300,
header: Optional[Dict[str, str]] = None, header: dict[str, str] | None = None,
max_images: int = 1, max_images: int = 1,
**kwargs, **kwargs,
) -> None: ) -> None:
...@@ -232,12 +227,12 @@ class TemplateAPI(TemplateLM): ...@@ -232,12 +227,12 @@ class TemplateAPI(TemplateLM):
@abc.abstractmethod @abc.abstractmethod
def _create_payload( def _create_payload(
self, self,
messages: Union[List[List[int]], List[dict], List[str], str], messages: list[list[int]] | list[dict] | list[str] | str,
*, *,
generate: bool = True, generate: bool = True,
gen_kwargs: Optional[dict] = None, gen_kwargs: dict | None = None,
seed: int = 1234, seed: int = 1234,
eos: str = None, eos: str | None = None,
**kwargs, **kwargs,
) -> dict: ) -> dict:
"""This method is responsible for creating the json payload that will be sent to the API.""" """This method is responsible for creating the json payload that will be sent to the API."""
...@@ -245,9 +240,9 @@ class TemplateAPI(TemplateLM): ...@@ -245,9 +240,9 @@ class TemplateAPI(TemplateLM):
def create_message( def create_message(
self, self,
messages: Union[List[List[int]], List[str], List[JsonChatStr]], messages: list[list[int]] | list[str] | list[JsonChatStr],
generate=False, generate=False,
) -> Union[List[List[int]], List[dict], List[str], str]: ) -> list[list[int]] | list[dict] | list[str] | str:
"""Helper method to transform the prompt into the expected API input format. messages consist of batched requests""" """Helper method to transform the prompt into the expected API input format. messages consist of batched requests"""
if isinstance(messages[0], JsonChatStr): if isinstance(messages[0], JsonChatStr):
# for chat completions we need to decode the json string to list[dict,...] # for chat completions we need to decode the json string to list[dict,...]
...@@ -276,17 +271,17 @@ class TemplateAPI(TemplateLM): ...@@ -276,17 +271,17 @@ class TemplateAPI(TemplateLM):
@staticmethod @staticmethod
@abc.abstractmethod @abc.abstractmethod
def parse_logprobs( def parse_logprobs(
outputs: Union[Any, List[Any]], outputs: Any | list[Any],
tokens: List[List[int]] = None, tokens: list[list[int]] | None = None,
ctxlen: List[int] = None, ctxlen: list[int] | None = None,
**kwargs, **kwargs,
) -> List[Tuple[float, bool]]: ) -> list[tuple[float, bool]]:
"""Method used to parse the logprobs from the (batched) API response. This method should return a list of tuples""" """Method used to parse the logprobs from the (batched) API response. This method should return a list of tuples"""
raise NotImplementedError raise NotImplementedError
@staticmethod @staticmethod
@abc.abstractmethod @abc.abstractmethod
def parse_generations(outputs: Union[Any, List[Any]], **kwargs) -> List[str]: def parse_generations(outputs: Any | list[Any], **kwargs) -> list[str]:
"""Method used to parse the generations from the (batched) API response. This method should return a list of str""" """Method used to parse the generations from the (batched) API response. This method should return a list of str"""
raise NotImplementedError raise NotImplementedError
...@@ -303,14 +298,15 @@ class TemplateAPI(TemplateLM): ...@@ -303,14 +298,15 @@ class TemplateAPI(TemplateLM):
@property @property
def tokenizer_name(self) -> str: def tokenizer_name(self) -> str:
"""Must be defined for LM subclasses which implement Chat Templating. """Must be defined for LM subclasses which implement Chat Templating.
Should return the name of the tokenizer or chat template used. Should return the name of the tokenizer or chat template used.
Used only to properly fingerprint caches when requests are being cached with `--cache_requests`, otherwise not used. Used only to properly fingerprint caches when requests are being cached with `--cache_requests`, otherwise not used.
""" """
return "" return ""
def apply_chat_template( def apply_chat_template(
self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
) -> Union[str, JsonChatStr]: ) -> str | JsonChatStr:
"""Applies a chat template to a list of chat history between user and model.""" """Applies a chat template to a list of chat history between user and model."""
if self.tokenizer_backend == "huggingface" and self.tokenized_requests: if self.tokenizer_backend == "huggingface" and self.tokenized_requests:
return self.tokenizer.apply_chat_template( return self.tokenizer.apply_chat_template(
...@@ -319,33 +315,32 @@ class TemplateAPI(TemplateLM): ...@@ -319,33 +315,32 @@ class TemplateAPI(TemplateLM):
add_generation_prompt=add_generation_prompt, add_generation_prompt=add_generation_prompt,
continue_final_message=not add_generation_prompt, continue_final_message=not add_generation_prompt,
) )
-        else:
-            # bit of a hack. We'll load back before sending to the API
-            return JsonChatStr(
-                json.dumps(
-                    [{**item, "type": "text"} for item in chat_history],
-                    ensure_ascii=False,
-                )
-            )
+        # bit of a hack. We'll load back before sending to the API
+        return JsonChatStr(
+            json.dumps(
+                [{**item, "type": "text"} for item in chat_history],
+                ensure_ascii=False,
+            )
+        )
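Context for the block above: chat-completion histories are serialized to JSON (`JsonChatStr`) so they can travel through the string-oriented request pipeline, then decoded back into a list of message dicts just before the HTTP call. A simplified round-trip sketch (assumed shapes, not the library's exact code):

```python
import json
from typing import NamedTuple


class JsonChatStr(NamedTuple):
    prompt: str


chat_history = [{"role": "user", "content": "What is 2 + 2?"}]

# pack: what apply_chat_template() returns when requests are not tokenized
packed = JsonChatStr(
    json.dumps([{**item, "type": "text"} for item in chat_history], ensure_ascii=False)
)

# unpack: what create_message() recovers before building the API payload
messages = json.loads(packed.prompt)
print(messages[0]["role"], messages[0]["type"])  # user text
```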
@cached_property @cached_property
def eot_token_id(self) -> Optional[int]: def eot_token_id(self) -> int | None:
if self.tokenizer is None: if self.tokenizer is None:
return None return None
else: else:
if self.tokenizer_backend == "huggingface": if self.tokenizer_backend == "huggingface":
return self.tokenizer.eos_token_id return self.tokenizer.eos_token_id
elif self.tokenizer_backend == "tiktoken": if self.tokenizer_backend == "tiktoken":
return self.tokenizer.eot_token return self.tokenizer.eot_token
@cached_property @cached_property
def eos_string(self) -> Optional[str]: def eos_string(self) -> str | None:
if self._eos_string: if self._eos_string:
return self._eos_string return self._eos_string
elif self.tokenizer is not None: if self.tokenizer is not None:
if self.tokenizer_backend == "huggingface": if self.tokenizer_backend == "huggingface":
return self.tokenizer.eos_token return self.tokenizer.eos_token
elif self.tokenizer_backend == "tiktoken": if self.tokenizer_backend == "tiktoken":
return self.tokenizer.decode([self.tokenizer.eot_token]) return self.tokenizer.decode([self.tokenizer.eot_token])
else: else:
eval_logger.warning( eval_logger.warning(
...@@ -354,7 +349,7 @@ class TemplateAPI(TemplateLM): ...@@ -354,7 +349,7 @@ class TemplateAPI(TemplateLM):
return None return None
@cached_property @cached_property
def prefix_token_id(self) -> Optional[int]: def prefix_token_id(self) -> int | None:
if self.tokenizer is None: if self.tokenizer is None:
return None return None
else: else:
...@@ -364,24 +359,24 @@ class TemplateAPI(TemplateLM): ...@@ -364,24 +359,24 @@ class TemplateAPI(TemplateLM):
if self.tokenizer.bos_token_id is not None: if self.tokenizer.bos_token_id is not None:
return self.tokenizer.bos_token_id return self.tokenizer.bos_token_id
return self.tokenizer.eos_token_id return self.tokenizer.eos_token_id
-            else:
-                return self.tokenizer.eot_token
+            return self.tokenizer.eot_token
def tok_encode( def tok_encode(
self, self,
string: str, string: str,
left_truncate_len: int = None, left_truncate_len: int | None = None,
add_special_tokens: bool = False, add_special_tokens: bool = False,
truncation: bool = False, truncation: bool = False,
**kwargs, **kwargs,
) -> Union[List[List[int]], List[int], List[str]]: ) -> list[list[int]] | list[int] | list[str]:
if self.tokenizer_backend is None: if self.tokenizer_backend is None:
return [string] return [string]
elif self.tokenizer_backend == "huggingface": if self.tokenizer_backend == "huggingface":
# by default for CausalLM - false or self.add_bos_token is set # by default for CausalLM - false or self.add_bos_token is set
if not add_special_tokens: if not add_special_tokens:
add_special_tokens = False or self.add_bos_token add_special_tokens = False or self.add_bos_token
encoding: Union[List[List[int]], List[int]] = self.tokenizer( encoding: list[list[int]] | list[int] = self.tokenizer(
string, string,
add_special_tokens=add_special_tokens, add_special_tokens=add_special_tokens,
truncation=truncation, truncation=truncation,
...@@ -404,20 +399,20 @@ class TemplateAPI(TemplateLM): ...@@ -404,20 +399,20 @@ class TemplateAPI(TemplateLM):
encoding = self.tokenizer.encode_batch(string) encoding = self.tokenizer.encode_batch(string)
return encoding return encoding
def decode_batch(self, tokens: List[List[int]]) -> List[str]: def decode_batch(self, tokens: list[list[int]]) -> list[str] | None:
if self.tokenizer_backend == "huggingface": if self.tokenizer_backend == "huggingface":
return self.tokenizer.batch_decode(tokens) return self.tokenizer.batch_decode(tokens)
elif self.tokenizer_backend == "tiktoken": if self.tokenizer_backend == "tiktoken":
return self.tokenizer.decode_batch(tokens) return self.tokenizer.decode_batch(tokens)
def model_call( def model_call(
self, self,
messages: Union[List[List[int]], List[str], List[JsonChatStr]], messages: list[list[int]] | list[str] | list[JsonChatStr],
*, *,
generate: bool = True, generate: bool = True,
gen_kwargs: Optional[Dict] = None, gen_kwargs: dict | None = None,
**kwargs, **kwargs,
) -> Optional[dict]: ) -> dict | None:
# !!! Copy: shared dict for each request, need new object !!! # !!! Copy: shared dict for each request, need new object !!!
gen_kwargs = copy.deepcopy(gen_kwargs) gen_kwargs = copy.deepcopy(gen_kwargs)
try: try:
...@@ -441,7 +436,7 @@ class TemplateAPI(TemplateLM): ...@@ -441,7 +436,7 @@ class TemplateAPI(TemplateLM):
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()
except RetryError: except RetryError:
eval_logger.error( eval_logger.exception(
"API request failed after multiple retries. Please check the API status." "API request failed after multiple retries. Please check the API status."
) )
return None return None
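For context, `RetryError` is what `tenacity` raises once a retrying call has exhausted all of its attempts; the handler above logs the failure and maps it to `None` rather than aborting the run. A small self-contained illustration (assumed retry settings, not the library's actual decorator arguments):

```python
from tenacity import RetryError, retry, stop_after_attempt, wait_fixed


@retry(stop=stop_after_attempt(3), wait=wait_fixed(0))
def flaky_request() -> dict:
    # stands in for the HTTP call that keeps failing
    raise ConnectionError("API unavailable")


try:
    result = flaky_request()
except RetryError:
    result = None  # mirrors model_call(): log the exhausted retries and fall back to None

print(result)  # None
```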
...@@ -450,14 +445,14 @@ class TemplateAPI(TemplateLM): ...@@ -450,14 +445,14 @@ class TemplateAPI(TemplateLM):
self, self,
session: ClientSession, session: ClientSession,
sem: asyncio.Semaphore, sem: asyncio.Semaphore,
messages: Union[List[List[int]], List[str], List[JsonChatStr]], messages: list[list[int]] | list[str] | list[JsonChatStr],
*, *,
generate: bool = True, generate: bool = True,
cache_keys: list = None, cache_keys: list | None = None,
ctxlens: Optional[List[int]] = None, ctxlens: list[int] | None = None,
gen_kwargs: Optional[Dict] = None, gen_kwargs: dict | None = None,
**kwargs, **kwargs,
) -> Union[List[str], List[Tuple[float, bool]], None]: ) -> list[str] | list[tuple[float, bool]] | None:
# !!! Copy: shared dict for each request, need new object !!! # !!! Copy: shared dict for each request, need new object !!!
gen_kwargs = copy.deepcopy(gen_kwargs) gen_kwargs = copy.deepcopy(gen_kwargs)
payload = self._create_payload( payload = self._create_payload(
...@@ -508,8 +503,8 @@ class TemplateAPI(TemplateLM): ...@@ -508,8 +503,8 @@ class TemplateAPI(TemplateLM):
sem.release() sem.release()
def batch_loglikelihood_requests( def batch_loglikelihood_requests(
self, chunks: Iterable[List[LogLikelihoodInputs]] self, chunks: Iterable[list[LogLikelihoodInputs]]
) -> Tuple[List[List[int]], List[int], List[Tuple[str, str]]]: ) -> tuple[list[list[int]], list[int], list[tuple[str, str]]]:
inputs = [] inputs = []
ctxlens = [] ctxlens = []
cache_keys = [] cache_keys = []
...@@ -536,9 +531,9 @@ class TemplateAPI(TemplateLM): ...@@ -536,9 +531,9 @@ class TemplateAPI(TemplateLM):
cache_keys: list, cache_keys: list,
*, *,
generate: bool = True, generate: bool = True,
ctxlens: List[int] = None, ctxlens: list[int] | None = None,
**kwargs, **kwargs,
) -> Union[List[List[str]], List[List[Tuple[float, bool]]]]: ) -> list[list[str]] | list[list[tuple[float, bool]]]:
ctxlens = ctxlens if ctxlens else [None] * len(requests) ctxlens = ctxlens if ctxlens else [None] * len(requests)
conn = TCPConnector(limit=self._concurrent, ssl=self.verify_certificate) conn = TCPConnector(limit=self._concurrent, ssl=self.verify_certificate)
sem = asyncio.Semaphore(self._concurrent) sem = asyncio.Semaphore(self._concurrent)
...@@ -575,14 +570,14 @@ class TemplateAPI(TemplateLM): ...@@ -575,14 +570,14 @@ class TemplateAPI(TemplateLM):
return await tqdm_asyncio.gather(*tasks, desc="Requesting API") return await tqdm_asyncio.gather(*tasks, desc="Requesting API")
def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]: def _loglikelihood_tokens(self, requests, **kwargs) -> list[tuple[float, bool]]:
assert self.tokenizer is not None, ( assert self.tokenizer is not None, (
"Tokenizer is required for loglikelihood tasks to compute context lengths." "Tokenizer is required for loglikelihood tasks to compute context lengths."
) )
res = [] res = []
def _collate(req: LogLikelihoodInputs): def _collate(req: LogLikelihoodInputs):
"""Defines the key for the sorted method""" """Defines the key for the sorted method."""
# the negative sign on len(toks) sorts descending - this has a few advantages: # the negative sign on len(toks) sorts descending - this has a few advantages:
# - time estimates will always be over not underestimates, which is more useful for planning # - time estimates will always be over not underestimates, which is more useful for planning
# - to know the size of a batch when going through the list, you know the first one is always the batch # - to know the size of a batch when going through the list, you know the first one is always the batch
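A tiny worked example (toy requests, illustrative only) of that sort key: negating the combined token length sorts the longest requests first, so the head of each batch bounds the batch's padded length.

```python
# (cache_key, context_tokens, continuation_tokens) -- same shape as LogLikelihoodInputs
reqs = [
    (("short", "x"), [1, 2], [3]),
    (("long", "y"), [1, 2, 3, 4], [5, 6]),
]


def collate_key(req):
    toks = req[1] + req[2]
    return -len(toks), tuple(toks)  # descending length, then token identity as tie-breaker


print([key for (key, _, _) in sorted(reqs, key=collate_key)])  # [('long', 'y'), ('short', 'x')]
```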
...@@ -639,8 +634,8 @@ class TemplateAPI(TemplateLM): ...@@ -639,8 +634,8 @@ class TemplateAPI(TemplateLM):
return re_ord.get_original(res) return re_ord.get_original(res)
def generate_until( def generate_until(
self, requests: List[Instance], disable_tqdm: bool = False self, requests: list[Instance], disable_tqdm: bool = False
) -> List[str]: ) -> list[str]:
res = [] res = []
def _collate_gen(_requests): def _collate_gen(_requests):
...@@ -773,8 +768,8 @@ class TemplateAPI(TemplateLM): ...@@ -773,8 +768,8 @@ class TemplateAPI(TemplateLM):
return re_ord.get_original(res) return re_ord.get_original(res)
def loglikelihood_rolling( def loglikelihood_rolling(
self, requests: List[Instance], disable_tqdm: bool = False self, requests: list[Instance], disable_tqdm: bool = False
) -> List[float]: ) -> list[float]:
loglikelihoods = [] loglikelihoods = []
for (string,) in tqdm([req.args for req in requests], disable=disable_tqdm): for (string,) in tqdm([req.args for req in requests], disable=disable_tqdm):
......
from __future__ import annotations
import copy import copy
import logging import logging
import os import os
from datetime import timedelta from datetime import timedelta
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union from typing import TYPE_CHECKING, Literal
import jinja2 import jinja2
import torch import torch
...@@ -40,7 +42,7 @@ from lm_eval.models.utils import ( ...@@ -40,7 +42,7 @@ from lm_eval.models.utils import (
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers.quantizers import AutoQuantizationConfig from transformers.quantizers.auto import AutoQuantizationConfig
eval_logger = logging.getLogger(__name__) eval_logger = logging.getLogger(__name__)
...@@ -59,46 +61,43 @@ class HFLM(TemplateLM): ...@@ -59,46 +61,43 @@ class HFLM(TemplateLM):
def __init__( def __init__(
self, self,
pretrained: Union[str, transformers.PreTrainedModel], pretrained: str | transformers.PreTrainedModel,
backend: Literal["default", "causal", "seq2seq"] = "default", backend: Literal["default", "causal", "seq2seq"] = "default",
# override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq) # override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq)
revision: Optional[str] = "main", revision: str | None = "main",
subfolder: str = "", subfolder: str = "",
-        tokenizer: Optional[
-            Union[
-                str,
-                transformers.PreTrainedTokenizer,
-                transformers.PreTrainedTokenizerFast,
-            ]
-        ] = None,
-        truncation: Optional[bool] = False,
+        tokenizer: str
+        | transformers.PreTrainedTokenizer
+        | transformers.PreTrainedTokenizerFast
+        | None = None,
+        truncation: bool | None = False,
logits_cache: bool = True, logits_cache: bool = True,
max_length: Optional[int] = None, max_length: int | None = None,
device: Optional[str] = "cuda", device: str | None = "cuda",
dtype: Optional[Union[str, torch.dtype]] = "auto", dtype: str | torch.dtype | None = "auto",
softmax_dtype: Optional[Union[str, torch.dtype]] = None, softmax_dtype: str | torch.dtype | None = None,
mixed_precision_dtype: Optional[Union[str, torch.dtype]] = None, mixed_precision_dtype: str | torch.dtype | None = None,
batch_size: Optional[Union[int, str]] = 1, batch_size: int | str | None = 1,
max_batch_size: Optional[int] = 64, max_batch_size: int | None = 64,
trust_remote_code: Optional[bool] = False, trust_remote_code: bool | None = False,
use_fast_tokenizer: Optional[bool] = True, use_fast_tokenizer: bool | None = True,
add_bos_token: Optional[bool] = False, add_bos_token: bool | None = False,
prefix_token_id: Optional[int] = None, prefix_token_id: int | None = None,
# arguments used for splitting a model across GPUs naively. # arguments used for splitting a model across GPUs naively.
# only used if `parallelize=True`. # only used if `parallelize=True`.
parallelize: Optional[bool] = False, parallelize: bool | None = False,
max_memory_per_gpu: Optional[Union[int, str]] = None, max_memory_per_gpu: int | str | None = None,
max_cpu_memory: Optional[Union[int, str]] = None, max_cpu_memory: int | str | None = None,
offload_folder: Optional[Union[str, os.PathLike]] = "./offload", offload_folder: str | os.PathLike | None = "./offload",
# PEFT, delta weights and quantization options # PEFT, delta weights and quantization options
peft: Optional[str] = None, peft: str | None = None,
delta: Optional[str] = None, delta: str | None = None,
autogptq: Optional[Union[bool, str]] = False, autogptq: bool | str | None = False,
gptqmodel: Optional[bool] = False, gptqmodel: bool | None = False,
gguf_file: Optional[str] = None, gguf_file: str | None = None,
# end token for thinking, either the string or int token id. # end token for thinking, either the string or int token id.
# splits to get response after this token (if provided). # splits to get response after this token (if provided).
think_end_token: Union[str, int, None] = None, think_end_token: str | int | None = None,
**kwargs, **kwargs,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -271,18 +270,19 @@ class HFLM(TemplateLM): ...@@ -271,18 +270,19 @@ class HFLM(TemplateLM):
self.batch_size_per_gpu = int(batch_size) self.batch_size_per_gpu = int(batch_size)
if isinstance(pretrained, str): if isinstance(pretrained, str):
-            if gpus >= 1 or str(self.device) == "mps":
-                # TODO: can remove this whole snippet except in the mps case, perhaps?
-                if not (parallelize or autogptq or hasattr(self, "accelerator")):
-                    # place model onto device requested manually,
-                    # if not using HF Accelerate or device_map
-                    # or any other option that preloads model onto device
-                    try:
-                        self.model.to(self.device)
-                    except ValueError:
-                        eval_logger.debug(
-                            "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided. If the desired GPU is being used, this message is safe to ignore."
-                        )
+            if (gpus >= 1 or str(self.device) == "mps") and not (
+                parallelize or autogptq or hasattr(self, "accelerator")
+            ):
+                # TODO: can remove this whole snippet except in the mps case, perhaps?
+                # place model onto device requested manually,
+                # if not using HF Accelerate or device_map
+                # or any other option that preloads model onto device
+                try:
+                    self.model.to(self.device)
+                except ValueError:
+                    eval_logger.debug(
+                        "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided. If the desired GPU is being used, this message is safe to ignore."
+                    )
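The refactor above hoists the inner guard into the outer condition, removing one level of nesting without changing behaviour. The same pattern in isolation (illustrative names only, simplified to two flags):

```python
def should_move_model(gpus: int, device: str, parallelize: bool, autogptq: bool) -> bool:
    # before:
    #   if gpus >= 1 or device == "mps":
    #       if not (parallelize or autogptq):
    #           return True
    #   return False
    # after: one combined condition, same truth table
    return (gpus >= 1 or device == "mps") and not (parallelize or autogptq)


print(should_move_model(1, "cuda", False, False))  # True
print(should_move_model(1, "cuda", True, False))   # False
```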
# multigpu data-parallel support when launched with accelerate # multigpu data-parallel support when launched with accelerate
if gpus > 1: if gpus > 1:
if accelerator.num_processes > 1: if accelerator.num_processes > 1:
...@@ -327,12 +327,12 @@ class HFLM(TemplateLM): ...@@ -327,12 +327,12 @@ class HFLM(TemplateLM):
def _get_accelerate_args( def _get_accelerate_args(
self, self,
parallelize: Optional[bool] = None, parallelize: bool | None = None,
device_map: Optional[str] = "auto", device_map: str | None = "auto",
max_memory_per_gpu: Optional[Union[int, str]] = None, max_memory_per_gpu: int | str | None = None,
max_cpu_memory: Optional[Union[int, str]] = None, max_cpu_memory: int | str | None = None,
offload_folder: Optional[str] = "./offload", offload_folder: str | None = "./offload",
gpus: Optional[int] = None, gpus: int | None = None,
) -> dict: ) -> dict:
"""Returns the kwargs needed to apply `accelerate` in `AutoModel.from_pretrained`.""" """Returns the kwargs needed to apply `accelerate` in `AutoModel.from_pretrained`."""
num_local_processes = int(os.environ.get("LOCAL_WORLD_SIZE", 1)) num_local_processes = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
...@@ -480,9 +480,9 @@ class HFLM(TemplateLM): ...@@ -480,9 +480,9 @@ class HFLM(TemplateLM):
def _get_backend( def _get_backend(
self, self,
config: Union[transformers.PretrainedConfig, transformers.AutoConfig], config: transformers.PretrainedConfig | transformers.AutoConfig,
backend: Literal["default", "causal", "seq2seq"] = "default", backend: Literal["default", "causal", "seq2seq"] = "default",
trust_remote_code: Optional[bool] = False, trust_remote_code: bool | None = False,
) -> None: ) -> None:
""" """
Helper method during initialization. Helper method during initialization.
...@@ -497,27 +497,20 @@ class HFLM(TemplateLM): ...@@ -497,27 +497,20 @@ class HFLM(TemplateLM):
if backend != "default": if backend != "default":
# if we've settled on non-default backend, use that manually # if we've settled on non-default backend, use that manually
if backend == "causal": if backend in ["causal", "seq2seq"]:
self.backend = backend
elif backend == "seq2seq":
self.backend = backend self.backend = backend
eval_logger.info( eval_logger.info(
f"Overrode HF model backend type, and using type '{self.backend}'" f"Overrode HF model backend type, and using type '{self.backend}'"
) )
else: else:
# determine and use the default HF backend for this model, based on its config + metadata. # determine and use the default HF backend for this model, based on its config + metadata.
-            if (
-                getattr(config, "model_type")
-                in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
-            ):
+            if self.config.model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
# first check if model type is listed under seq2seq models, since some # first check if model type is listed under seq2seq models, since some
# models like MBart are listed in both seq2seq and causal mistakenly in HF transformers. # models like MBart are listed in both seq2seq and causal mistakenly in HF transformers.
# these special cases should be treated as seq2seq models. # these special cases should be treated as seq2seq models.
self.backend = "seq2seq" self.backend = "seq2seq"
eval_logger.debug(f"Using model type '{self.backend}'") eval_logger.debug(f"Using model type '{self.backend}'")
-            elif (
-                getattr(self.config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
-            ):
+            elif self.config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
self.backend = "causal" self.backend = "causal"
eval_logger.debug(f"Using model type '{self.backend}'") eval_logger.debug(f"Using model type '{self.backend}'")
else: else:
...@@ -545,7 +538,7 @@ class HFLM(TemplateLM): ...@@ -545,7 +538,7 @@ class HFLM(TemplateLM):
pretrained: str, pretrained: str,
revision: str = "main", revision: str = "main",
trust_remote_code: bool = False, trust_remote_code: bool = False,
gguf_file: Optional[str] = None, gguf_file: str | None = None,
subfolder: str = "", subfolder: str = "",
) -> None: ) -> None:
"""Return the model config for HuggingFace models""" """Return the model config for HuggingFace models"""
...@@ -560,24 +553,24 @@ class HFLM(TemplateLM): ...@@ -560,24 +553,24 @@ class HFLM(TemplateLM):
def _create_model( def _create_model(
self, self,
pretrained: str, pretrained: str,
revision: Optional[str] = "main", revision: str | None = "main",
dtype: Optional[Union[str, torch.dtype]] = "auto", dtype: str | torch.dtype | None = "auto",
trust_remote_code: Optional[bool] = False, trust_remote_code: bool | None = False,
# arguments used for splitting a model across GPUs naively. # arguments used for splitting a model across GPUs naively.
# only used if `parallelize=True`. # only used if `parallelize=True`.
# (accelerate naive PP (device_map) options) # (accelerate naive PP (device_map) options)
parallelize: Optional[bool] = False, parallelize: bool | None = False,
gpus: Optional[int] = None, gpus: int | None = None,
max_memory_per_gpu: Optional[Union[int, str]] = None, max_memory_per_gpu: int | str | None = None,
max_cpu_memory: Optional[Union[int, str]] = None, max_cpu_memory: int | str | None = None,
offload_folder: Optional[str] = "./offload", offload_folder: str | None = "./offload",
# PEFT, delta weights and quantization options # PEFT, delta weights and quantization options
peft: Optional[str] = None, peft: str | None = None,
delta: Optional[str] = None, delta: str | None = None,
autogptq: Optional[Union[bool, str]] = False, autogptq: bool | str | None = False,
gptqmodel: Optional[bool] = False, gptqmodel: bool | None = False,
gguf_file: Optional[str] = None, gguf_file: str | None = None,
quantization_config: Optional["AutoQuantizationConfig"] = None, quantization_config: AutoQuantizationConfig | None = None,
subfolder: str = "", subfolder: str = "",
**kwargs, **kwargs,
) -> None: ) -> None:
...@@ -598,7 +591,7 @@ class HFLM(TemplateLM): ...@@ -598,7 +591,7 @@ class HFLM(TemplateLM):
model_kwargs.update( model_kwargs.update(
self._get_accelerate_args( self._get_accelerate_args(
parallelize=parallelize, parallelize=parallelize,
device_map=kwargs.get("device_map", None), device_map=kwargs.get("device_map"),
max_memory_per_gpu=max_memory_per_gpu, max_memory_per_gpu=max_memory_per_gpu,
max_cpu_memory=max_cpu_memory, max_cpu_memory=max_cpu_memory,
offload_folder=offload_folder, offload_folder=offload_folder,
...@@ -611,12 +604,11 @@ class HFLM(TemplateLM): ...@@ -611,12 +604,11 @@ class HFLM(TemplateLM):
assert transformers.__version__ >= "4.30.0", ( assert transformers.__version__ >= "4.30.0", (
"load_in_4bit requires transformers >= 4.30.0" "load_in_4bit requires transformers >= 4.30.0"
) )
-        if transformers.__version__ >= "4.30.0":
-            if model_kwargs.get("load_in_4bit", None):
-                if model_kwargs.get("bnb_4bit_compute_dtype", None):
-                    model_kwargs["bnb_4bit_compute_dtype"] = get_dtype(
-                        model_kwargs["bnb_4bit_compute_dtype"]
-                    )
+        if transformers.__version__ >= "4.30.0" and (
+            model_kwargs.get("load_in_4bit")
+            and (compute_dtype := model_kwargs.get("bnb_4bit_compute_dtype"))
+        ):
+            model_kwargs["bnb_4bit_compute_dtype"] = get_dtype(compute_dtype)
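The rewritten condition above uses an assignment expression (walrus operator) to test for the compute dtype and capture it in one step. The shape of that pattern on a stand-alone dict (hypothetical values, `.upper()` standing in for `get_dtype()`):

```python
model_kwargs = {"load_in_4bit": True, "bnb_4bit_compute_dtype": "bfloat16"}

# bind the value and test its truthiness inside the same condition
if model_kwargs.get("load_in_4bit") and (
    compute_dtype := model_kwargs.get("bnb_4bit_compute_dtype")
):
    model_kwargs["bnb_4bit_compute_dtype"] = compute_dtype.upper()

print(model_kwargs["bnb_4bit_compute_dtype"])  # BFLOAT16
```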
self._model = self.AUTO_MODEL_CLASS.from_pretrained( self._model = self.AUTO_MODEL_CLASS.from_pretrained(
pretrained, pretrained,
...@@ -641,7 +633,7 @@ class HFLM(TemplateLM): ...@@ -641,7 +633,7 @@ class HFLM(TemplateLM):
raise type(exception)( raise type(exception)(
"Tried to load auto_gptq, but auto-gptq is not installed ", "Tried to load auto_gptq, but auto-gptq is not installed ",
"please install auto-gptq via pip install lm-eval[gptq] or pip install -e .[gptq]", "please install auto-gptq via pip install lm-eval[gptq] or pip install -e .[gptq]",
) ) from exception
self._model = AutoGPTQForCausalLM.from_quantized( self._model = AutoGPTQForCausalLM.from_quantized(
pretrained, pretrained,
...@@ -660,7 +652,7 @@ class HFLM(TemplateLM): ...@@ -660,7 +652,7 @@ class HFLM(TemplateLM):
raise type(exception)( raise type(exception)(
"Tried to load gptqmodel, but gptqmodel is not installed ", "Tried to load gptqmodel, but gptqmodel is not installed ",
"please install gptqmodel via `pip install gptqmodel --no-build-isolation` or `pip install lm-eval[gptqmodel] --no-build-isolation`", "please install gptqmodel via `pip install gptqmodel --no-build-isolation` or `pip install lm-eval[gptqmodel] --no-build-isolation`",
) ) from exception
self._model = GPTQModel.from_quantized( self._model = GPTQModel.from_quantized(
pretrained, trust_remote_code=trust_remote_code, **model_kwargs pretrained, trust_remote_code=trust_remote_code, **model_kwargs
...@@ -672,12 +664,12 @@ class HFLM(TemplateLM): ...@@ -672,12 +664,12 @@ class HFLM(TemplateLM):
) )
if peft: if peft:
-            from peft import PeftModel
-            from peft import __version__ as PEFT_VERSION
+            from peft import PeftModel, __version__ as PEFT_VERSION
-            if model_kwargs.get("load_in_4bit", None):
-                if version.parse(PEFT_VERSION) < version.parse("0.4.0"):
-                    raise AssertionError("load_in_4bit requires peft >= 0.4.0")
+            if model_kwargs.get("load_in_4bit") and version.parse(
+                PEFT_VERSION
+            ) < version.parse("0.4.0"):
+                raise AssertionError("load_in_4bit requires peft >= 0.4.0")
if self._model.config.vocab_size != len(self.tokenizer): if self._model.config.vocab_size != len(self.tokenizer):
# resize model for LoRAs with added tokens # resize model for LoRAs with added tokens
eval_logger.info( eval_logger.info(
...@@ -703,11 +695,13 @@ class HFLM(TemplateLM): ...@@ -703,11 +695,13 @@ class HFLM(TemplateLM):
try: try:
param.data += _model_delta.state_dict()[name] param.data += _model_delta.state_dict()[name]
except KeyError: except KeyError:
raise KeyError(f"Delta model is missing weights for layer: {name}") raise KeyError(
f"Delta model is missing weights for layer: {name}"
) from None
except Exception as e: except Exception as e:
raise RuntimeError( raise RuntimeError(
f"Failed to add delta weights to layer {name}. Error: {e}" f"Failed to add delta weights to layer {name}. Error: {e}"
) ) from e
del _model_delta del _model_delta
...@@ -715,20 +709,17 @@ class HFLM(TemplateLM): ...@@ -715,20 +709,17 @@ class HFLM(TemplateLM):
def _create_tokenizer( def _create_tokenizer(
self, self,
pretrained: Union[str, transformers.PreTrainedModel], pretrained: str | transformers.PreTrainedModel,
-        tokenizer: Optional[
-            Union[
-                str,
-                transformers.PreTrainedTokenizer,
-                transformers.PreTrainedTokenizerFast,
-            ]
-        ],
-        revision: Optional[str] = "main",
-        trust_remote_code: Optional[bool] = False,
-        use_fast_tokenizer: Optional[bool] = True,
-        gguf_file: Optional[str] = None,
-        add_bos_token: Optional[bool] = False,
-        subfolder: Optional[str] = "",
+        tokenizer: str
+        | transformers.PreTrainedTokenizer
+        | transformers.PreTrainedTokenizerFast
+        | None,
+        revision: str | None = "main",
+        trust_remote_code: bool | None = False,
+        use_fast_tokenizer: bool | None = True,
+        gguf_file: str | None = None,
+        add_bos_token: bool | None = False,
+        subfolder: str | None = "",
) -> None: ) -> None:
""" """
Helper method during initialization. Helper method during initialization.
...@@ -760,8 +751,12 @@ class HFLM(TemplateLM): ...@@ -760,8 +751,12 @@ class HFLM(TemplateLM):
) )
else: else:
assert isinstance( assert isinstance(
-                tokenizer, transformers.PreTrainedTokenizer
-            ) or isinstance(tokenizer, transformers.PreTrainedTokenizerFast)
+                tokenizer,
+                (
+                    transformers.PreTrainedTokenizer,
+                    transformers.PreTrainedTokenizerFast,
+                ),
+            )
self.tokenizer = tokenizer self.tokenizer = tokenizer
else: else:
# Get tokenizer based on 'pretrained' # Get tokenizer based on 'pretrained'
...@@ -838,7 +833,7 @@ class HFLM(TemplateLM): ...@@ -838,7 +833,7 @@ class HFLM(TemplateLM):
def tok_encode( def tok_encode(
self, string: str, left_truncate_len=None, add_special_tokens=None self, string: str, left_truncate_len=None, add_special_tokens=None
) -> List[int]: ) -> list[int]:
""" """ """ """
# default for None - empty dict, use predefined tokenizer param # default for None - empty dict, use predefined tokenizer param
# used for all models except for CausalLM or predefined value # used for all models except for CausalLM or predefined value
...@@ -864,11 +859,11 @@ class HFLM(TemplateLM): ...@@ -864,11 +859,11 @@ class HFLM(TemplateLM):
def tok_batch_encode( def tok_batch_encode(
self, self,
strings: List[str], strings: list[str],
padding_side: str = "left", padding_side: str = "left",
left_truncate_len: int = None, left_truncate_len: int = None,
truncation: bool = False, truncation: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode. # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
old_padding_side = self.tokenizer.padding_side old_padding_side = self.tokenizer.padding_side
self.tokenizer.padding_side = padding_side self.tokenizer.padding_side = padding_side
...@@ -917,24 +912,26 @@ class HFLM(TemplateLM): ...@@ -917,24 +912,26 @@ class HFLM(TemplateLM):
A torch tensor of shape [batch, sequence, vocab] with the A torch tensor of shape [batch, sequence, vocab] with the
logits returned from the model's decoder logits returned from the model's decoder
""" """
-        with torch.no_grad():
-            with torch.autocast(
-                device_type=self.device.type,
-                dtype=self.mixed_precision_dtype,
-                enabled=self.mixed_precision_dtype is not None,
-            ):
-                if attn_mask is not None or labels is not None:
-                    assert attn_mask is not None and labels is not None
-                    assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM
-                    return self.model(
-                        input_ids=inps, attention_mask=attn_mask, labels=labels
-                    ).logits
-                else:
-                    assert self.AUTO_MODEL_CLASS in (
-                        transformers.AutoModelForCausalLM,
-                        transformers.AutoModelForVision2Seq,
-                    )
-                    return self.model(inps).logits
+        with (
+            torch.no_grad(),
+            torch.autocast(
+                device_type=self.device.type,
+                dtype=self.mixed_precision_dtype,
+                enabled=self.mixed_precision_dtype is not None,
+            ),
+        ):
+            if attn_mask is not None or labels is not None:
+                assert attn_mask is not None and labels is not None
+                assert transformers.AutoModelForSeq2SeqLM == self.AUTO_MODEL_CLASS
+                return self.model(
+                    input_ids=inps, attention_mask=attn_mask, labels=labels
+                ).logits
+            else:
+                assert self.AUTO_MODEL_CLASS in (
+                    transformers.AutoModelForCausalLM,
+                    transformers.AutoModelForVision2Seq,
+                )
+                return self.model(inps).logits
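The change above collapses the two nested context managers into a single parenthesized `with` statement (official syntax as of Python 3.10). The same structure in isolation, using the document's own dependency `torch` (CPU-only values chosen so the snippet runs anywhere):

```python
import torch

x = torch.randn(2, 2)
with (
    torch.no_grad(),
    torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=False),
):
    y = x @ x.T  # gradients are disabled; autocast is a no-op here (enabled=False)

print(y.requires_grad)  # False
```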
def _model_generate(self, context, max_length, stop, **generation_kwargs): def _model_generate(self, context, max_length, stop, **generation_kwargs):
# temperature = 0.0 if not set # temperature = 0.0 if not set
...@@ -942,7 +939,7 @@ class HFLM(TemplateLM): ...@@ -942,7 +939,7 @@ class HFLM(TemplateLM):
# remove temperature, as do_sample=False takes care of this # remove temperature, as do_sample=False takes care of this
# and we don't want a warning from HF # and we don't want a warning from HF
generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0) generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
do_sample = generation_kwargs.get("do_sample", None) do_sample = generation_kwargs.get("do_sample")
# The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
if generation_kwargs.get("temperature") == 0.0 and do_sample is None: if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
...@@ -989,8 +986,8 @@ class HFLM(TemplateLM): ...@@ -989,8 +986,8 @@ class HFLM(TemplateLM):
return logits return logits
def loglikelihood_rolling( def loglikelihood_rolling(
self, requests: List[Instance], disable_tqdm: bool = False self, requests: list[Instance], disable_tqdm: bool = False
) -> List[float]: ) -> list[float]:
adaptive_batch_size = None adaptive_batch_size = None
if self.batch_size == "auto": if self.batch_size == "auto":
# using rolling window with maximum context # using rolling window with maximum context
...@@ -1009,7 +1006,7 @@ class HFLM(TemplateLM): ...@@ -1009,7 +1006,7 @@ class HFLM(TemplateLM):
disable=(disable_tqdm or (self.rank != 0)), disable=(disable_tqdm or (self.rank != 0)),
) )
): ):
rolling_token_windows: List[Tuple[List[int], List[int]]] = list( rolling_token_windows: list[tuple[list[int], list[int]]] = list(
map( map(
utils.make_disjoint_window, utils.make_disjoint_window,
utils.get_rolling_token_windows( utils.get_rolling_token_windows(
...@@ -1093,14 +1090,14 @@ class HFLM(TemplateLM): ...@@ -1093,14 +1090,14 @@ class HFLM(TemplateLM):
def _loglikelihood_tokens( def _loglikelihood_tokens(
self, self,
requests: List[Tuple[Tuple[str, str], List[int], List[int]]], requests: list[tuple[tuple[str, str], list[int], list[int]]],
disable_tqdm: bool = False, disable_tqdm: bool = False,
override_bs: int = None, override_bs: int = None,
) -> List[Tuple[float, bool]]: ) -> list[tuple[float, bool]]:
# TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
res = [] res = []
def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]): def _collate(req: tuple[tuple[str, str], list[int], list[int]]):
"""Defines the key for the sorted method""" """Defines the key for the sorted method"""
# the negative sign on len(toks) sorts descending - this has a few advantages: # the negative sign on len(toks) sorts descending - this has a few advantages:
# - time estimates will always be over not underestimates, which is more useful for planning # - time estimates will always be over not underestimates, which is more useful for planning
...@@ -1112,7 +1109,7 @@ class HFLM(TemplateLM): ...@@ -1112,7 +1109,7 @@ class HFLM(TemplateLM):
toks = req[1] + req[2] toks = req[1] + req[2]
return -len(toks), tuple(toks) return -len(toks), tuple(toks)
def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]): def _lookup_one_token_cont(req: tuple[tuple[str, str], list[int], list[int]]):
"""Defines the key to group and lookup one-token continuations""" """Defines the key to group and lookup one-token continuations"""
# Use with group_by="contexts" (optional)" # Use with group_by="contexts" (optional)"
# allows for the creation of a lookup, so we can reuse logits in case of one-token continuations. # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations.
...@@ -1286,7 +1283,7 @@ class HFLM(TemplateLM): ...@@ -1286,7 +1283,7 @@ class HFLM(TemplateLM):
# original args. Otherwise, expands the logits batch dimension and yields each # original args. Otherwise, expands the logits batch dimension and yields each
# batch along with matching continuation tokens and prompt strings. # batch along with matching continuation tokens and prompt strings.
# logits -> [1, seq, vocab] # logits -> [1, seq, vocab]
for request_str, cont_toks, logits in re_ord.get_cache( for request_str, cont_toks, logits in re_ord.get_cache( # noqa
req_str=request_str, req_str=request_str,
cxt_toks=ctx_tokens, cxt_toks=ctx_tokens,
cont_toks=cont_toks, cont_toks=cont_toks,
...@@ -1327,11 +1324,11 @@ class HFLM(TemplateLM): ...@@ -1327,11 +1324,11 @@ class HFLM(TemplateLM):
return re_ord.get_original(res) return re_ord.get_original(res)
def generate_until( def generate_until(
self, requests: List[Instance], disable_tqdm: bool = False self, requests: list[Instance], disable_tqdm: bool = False
) -> List[str]: ) -> list[str]:
res = [] res = []
def _collate(req: Tuple[str, dict]): def _collate(req: tuple[str, dict]):
"""Defines the key for the sorted method""" """Defines the key for the sorted method"""
# the negative sign on len(toks) sorts descending - this has a few advantages: # the negative sign on len(toks) sorts descending - this has a few advantages:
# - time estimates will always be over not underestimates, which is more useful for planning # - time estimates will always be over not underestimates, which is more useful for planning
...@@ -1394,7 +1391,7 @@ class HFLM(TemplateLM): ...@@ -1394,7 +1391,7 @@ class HFLM(TemplateLM):
raise ValueError( raise ValueError(
f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}" f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
) )
if "max_gen_toks" in kwargs.keys(): if "max_gen_toks" in kwargs:
max_gen_toks = kwargs.pop("max_gen_toks") max_gen_toks = kwargs.pop("max_gen_toks")
else: else:
max_gen_toks = self.max_gen_toks max_gen_toks = self.max_gen_toks
...@@ -1472,7 +1469,7 @@ class HFLM(TemplateLM): ...@@ -1472,7 +1469,7 @@ class HFLM(TemplateLM):
return res return res
def apply_chat_template( def apply_chat_template(
self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
) -> str: ) -> str:
""" """
Method to apply a chat template to a list of chat history between user and model. Method to apply a chat template to a list of chat history between user and model.
......
from __future__ import annotations
import logging import logging
import os import os
from functools import cached_property from functools import cached_property
from operator import itemgetter from operator import itemgetter
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any
from lm_eval.api.registry import register_model from lm_eval.api.registry import register_model
from lm_eval.models.api_models import TemplateAPI from lm_eval.models.api_models import TemplateAPI
...@@ -26,9 +28,9 @@ class LocalCompletionsAPI(TemplateAPI): ...@@ -26,9 +28,9 @@ class LocalCompletionsAPI(TemplateAPI):
def _create_payload( def _create_payload(
self, self,
messages: Union[List[List[int]], List[dict], List[str], str], messages: list[list[int]] | list[dict] | list[str] | str,
generate=False, generate=False,
gen_kwargs: Optional[dict] = None, gen_kwargs: dict | None = None,
seed: int = 1234, seed: int = 1234,
eos=None, eos=None,
**kwargs, **kwargs,
...@@ -50,24 +52,23 @@ class LocalCompletionsAPI(TemplateAPI): ...@@ -50,24 +52,23 @@ class LocalCompletionsAPI(TemplateAPI):
"seed": seed, "seed": seed,
**gen_kwargs, **gen_kwargs,
} }
-        else:
-            return {
-                "model": self.model,
-                "prompt": messages,
-                "temperature": 0,
-                "max_tokens": 1,
-                "logprobs": 1,
-                "seed": seed,
-                "echo": True,
-            }
+        return {
+            "model": self.model,
+            "prompt": messages,
+            "temperature": 0,
+            "max_tokens": 1,
+            "logprobs": 1,
+            "seed": seed,
+            "echo": True,
+        }
@staticmethod @staticmethod
def parse_logprobs( def parse_logprobs(
outputs: Union[Dict, List[Dict]], outputs: dict | list[dict],
tokens: List[List[int]] = None, tokens: list[list[int]] = None,
ctxlens: List[int] = None, ctxlens: list[int] = None,
**kwargs, **kwargs,
) -> List[Tuple[float, bool]]: ) -> list[tuple[float, bool]]:
res = [] res = []
if not isinstance(outputs, list): if not isinstance(outputs, list):
outputs = [outputs] outputs = [outputs]
...@@ -88,7 +89,7 @@ class LocalCompletionsAPI(TemplateAPI): ...@@ -88,7 +89,7 @@ class LocalCompletionsAPI(TemplateAPI):
return res return res
@staticmethod @staticmethod
def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]: def parse_generations(outputs: dict | list[dict], **kwargs) -> list[str]:
res = [] res = []
if not isinstance(outputs, list): if not isinstance(outputs, list):
outputs = [outputs] outputs = [outputs]
...@@ -130,9 +131,9 @@ class LocalChatCompletion(LocalCompletionsAPI): ...@@ -130,9 +131,9 @@ class LocalChatCompletion(LocalCompletionsAPI):
def _create_payload( def _create_payload(
self, self,
messages: List[Dict], messages: list[dict],
generate=False, generate=False,
gen_kwargs: dict = None, gen_kwargs: dict | None = None,
seed=1234, seed=1234,
eos=None, eos=None,
**kwargs, **kwargs,
...@@ -160,7 +161,7 @@ class LocalChatCompletion(LocalCompletionsAPI): ...@@ -160,7 +161,7 @@ class LocalChatCompletion(LocalCompletionsAPI):
} }
@staticmethod @staticmethod
def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]: def parse_generations(outputs: dict | list[dict], **kwargs) -> list[str]:
res = [] res = []
if not isinstance(outputs, list): if not isinstance(outputs, list):
outputs = [outputs] outputs = [outputs]
...@@ -173,11 +174,11 @@ class LocalChatCompletion(LocalCompletionsAPI): ...@@ -173,11 +174,11 @@ class LocalChatCompletion(LocalCompletionsAPI):
def tok_encode( def tok_encode(
self, self,
string: Union[str, Any], string: str | Any,
left_truncate_len=None, left_truncate_len=None,
add_special_tokens=None, add_special_tokens=None,
**kwargs, **kwargs,
) -> Union[List[str], List[int], Any]: ) -> list[str] | list[int] | Any:
return string return string
def loglikelihood(self, requests, **kwargs): def loglikelihood(self, requests, **kwargs):
...@@ -219,7 +220,7 @@ class OpenAICompletionsAPI(LocalCompletionsAPI): ...@@ -219,7 +220,7 @@ class OpenAICompletionsAPI(LocalCompletionsAPI):
) )
return super().loglikelihood(requests, **kwargs) return super().loglikelihood(requests, **kwargs)
def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]: def chat_template(self, chat_template: bool | str = False) -> str | None:
return "" return ""
...@@ -261,7 +262,7 @@ class OpenAIChatCompletion(LocalChatCompletion): ...@@ -261,7 +262,7 @@ class OpenAIChatCompletion(LocalChatCompletion):
def _create_payload( def _create_payload(
self, self,
messages: List[Dict], messages: list[dict],
generate=False, generate=False,
gen_kwargs: dict = None, gen_kwargs: dict = None,
seed=1234, seed=1234,
......
from __future__ import annotations
import copy import copy
import gc import gc
import inspect import inspect
...@@ -8,7 +10,7 @@ from importlib.util import find_spec ...@@ -8,7 +10,7 @@ from importlib.util import find_spec
from multiprocessing import Process, Queue from multiprocessing import Process, Queue
from queue import Empty from queue import Empty
from time import sleep from time import sleep
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union from typing import TYPE_CHECKING, Literal
import jinja2 import jinja2
from more_itertools import distribute from more_itertools import distribute
...@@ -51,10 +53,10 @@ eval_logger = logging.getLogger(__name__) ...@@ -51,10 +53,10 @@ eval_logger = logging.getLogger(__name__)
def _vllm_mp_worker( def _vllm_mp_worker(
model_args: dict, model_args: dict,
sampling_params: "SamplingParams", sampling_params: SamplingParams,
requests: list[list[int]], requests: list[list[int]],
lora_request: "LoRARequest", lora_request: LoRARequest,
result_queue: "Queue", result_queue: Queue,
dp_size: int, dp_size: int,
local_dp_rank: int, local_dp_rank: int,
dp_master_port: int, dp_master_port: int,
...@@ -114,30 +116,30 @@ class VLLM(TemplateLM): ...@@ -114,30 +116,30 @@ class VLLM(TemplateLM):
self, self,
pretrained: str, pretrained: str,
dtype: Literal["float16", "bfloat16", "float32", "auto"] = "auto", dtype: Literal["float16", "bfloat16", "float32", "auto"] = "auto",
revision: Optional[str] = None, revision: str | None = None,
trust_remote_code: Optional[bool] = False, trust_remote_code: bool | None = False,
tokenizer: Optional[str] = None, tokenizer: str | None = None,
tokenizer_mode: Literal["auto", "slow"] = "auto", tokenizer_mode: Literal["auto", "slow"] = "auto",
tokenizer_revision: Optional[str] = None, tokenizer_revision: str | None = None,
add_bos_token: Optional[bool] = False, add_bos_token: bool | None = False,
prefix_token_id: Optional[int] = None, prefix_token_id: int | None = None,
tensor_parallel_size: int = 1, tensor_parallel_size: int = 1,
quantization: Optional[str] = None, quantization: str | None = None,
max_gen_toks: int = 256, max_gen_toks: int = 256,
swap_space: int = 4, swap_space: int = 4,
batch_size: Union[str, int] = 1, batch_size: str | int = 1,
max_batch_size=None, max_batch_size: int | None = None,
max_length: int = None, max_length: int | None = None,
max_model_len: int = None, max_model_len: int | None = None,
seed: int = 1234, seed: int = 1234,
gpu_memory_utilization: float = 0.9, gpu_memory_utilization: float = 0.9,
device: str = "cuda", device: str = "cuda",
data_parallel_size: int = 1, data_parallel_size: int = 1,
lora_local_path: str = None, lora_local_path: str | None = None,
# VLLM: enable thinking tags in the prompt. # VLLM: enable thinking tags in the prompt.
enable_thinking: bool = True, enable_thinking: bool = True,
# End marker for thinking tags - splits to get response after this token (if provided). # End marker for thinking tags - splits to get response after this token (if provided).
think_end_token: Optional[str] = None, think_end_token: str | None = None,
max_lora_rank: int = 16, max_lora_rank: int = 16,
**kwargs, **kwargs,
): ):
...@@ -173,7 +175,7 @@ class VLLM(TemplateLM): ...@@ -173,7 +175,7 @@ class VLLM(TemplateLM):
"quantization": quantization, "quantization": quantization,
"seed": int(seed), "seed": int(seed),
"device": str(device), "device": str(device),
"enable_lora": True if lora_local_path else False, "enable_lora": bool(lora_local_path),
"max_lora_rank": int(max_lora_rank), "max_lora_rank": int(max_lora_rank),
} }
self.model_args.update(kwargs) self.model_args.update(kwargs)
...@@ -304,7 +306,7 @@ class VLLM(TemplateLM): ...@@ -304,7 +306,7 @@ class VLLM(TemplateLM):
return self._max_gen_toks return self._max_gen_toks
def apply_chat_template( def apply_chat_template(
self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
) -> str: ) -> str:
""" """
Method to apply a chat template to a list of chat history between user and model. Method to apply a chat template to a list of chat history between user and model.
...@@ -339,14 +341,14 @@ class VLLM(TemplateLM): ...@@ -339,14 +341,14 @@ class VLLM(TemplateLM):
def tok_encode( def tok_encode(
self, self,
string: Union[str, List[str]], string: str | list[str],
left_truncate_len: int = None, left_truncate_len: int = None,
add_special_tokens: bool = False, add_special_tokens: bool = False,
truncation: bool = False, truncation: bool = False,
) -> Union[List[int], List[List[int]]]: ) -> list[int] | list[list[int]]:
if not add_special_tokens: if not add_special_tokens:
add_special_tokens = False or self.add_bos_token add_special_tokens = False or self.add_bos_token
encoding: Union[List[List[int]], List[int]] = self.tokenizer( encoding: list[list[int]] | list[int] = self.tokenizer(
string, string,
add_special_tokens=add_special_tokens, add_special_tokens=add_special_tokens,
truncation=truncation, truncation=truncation,
...@@ -364,10 +366,10 @@ class VLLM(TemplateLM): ...@@ -364,10 +366,10 @@ class VLLM(TemplateLM):
def _model_generate( def _model_generate(
self, self,
requests: List[List[int]] = None, requests: list[list[int]] = None,
generate: bool = False, generate: bool = False,
max_tokens: int = None, max_tokens: int = None,
stop: Optional[List[str]] = None, stop: list[str] | None = None,
**kwargs, **kwargs,
): ):
if generate: if generate:
...@@ -385,7 +387,7 @@ class VLLM(TemplateLM): ...@@ -385,7 +387,7 @@ class VLLM(TemplateLM):
def run_inference_one_model( def run_inference_one_model(
model_args: dict, model_args: dict,
sampling_params: SamplingParams, sampling_params: SamplingParams,
requests: List[List[int]], requests: list[list[int]],
lora_request: LoRARequest, lora_request: LoRARequest,
): ):
llm = LLM(**model_args) llm = LLM(**model_args)
...@@ -454,7 +456,7 @@ class VLLM(TemplateLM): ...@@ -454,7 +456,7 @@ class VLLM(TemplateLM):
if dead_procs: if dead_procs:
raise RuntimeError( raise RuntimeError(
f"Worker processes {dead_procs} died unexpectedly" f"Worker processes {dead_procs} died unexpectedly"
) ) from None
continue continue
results = [rank_res[i] for i in range(len(procs))] results = [rank_res[i] for i in range(len(procs))]
...@@ -481,14 +483,14 @@ class VLLM(TemplateLM): ...@@ -481,14 +483,14 @@ class VLLM(TemplateLM):
outputs = self.model.generate( outputs = self.model.generate(
prompt_token_ids=requests, prompt_token_ids=requests,
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=True if self.batch_size == "auto" else False, use_tqdm=self.batch_size == "auto",
lora_request=self.lora_request, lora_request=self.lora_request,
) )
return outputs return outputs
def loglikelihood_rolling( def loglikelihood_rolling(
self, requests: List[Instance], disable_tqdm: bool = False self, requests: list[Instance], disable_tqdm: bool = False
) -> List[float]: ) -> list[float]:
adaptive_batch_size = None adaptive_batch_size = None
if self.batch_size == "auto": if self.batch_size == "auto":
adaptive_batch_size = len(requests) adaptive_batch_size = len(requests)
...@@ -503,7 +505,7 @@ class VLLM(TemplateLM): ...@@ -503,7 +505,7 @@ class VLLM(TemplateLM):
disable=(disable_tqdm or (self.rank != 0)), disable=(disable_tqdm or (self.rank != 0)),
) )
): ):
rolling_token_windows: List[Tuple[List[int], List[int]]] = list( rolling_token_windows: list[tuple[list[int], list[int]]] = list(
map( map(
make_disjoint_window, make_disjoint_window,
get_rolling_token_windows( get_rolling_token_windows(
...@@ -556,13 +558,13 @@ class VLLM(TemplateLM): ...@@ -556,13 +558,13 @@ class VLLM(TemplateLM):
return loglikelihoods return loglikelihoods
def generate_until( def generate_until(
self, requests: List[Instance], disable_tqdm: bool = False self, requests: list[Instance], disable_tqdm: bool = False
) -> List[str]: ) -> list[str]:
res = [] res = []
# batch tokenize contexts # batch tokenize contexts
context, all_gen_kwargs = zip(*(req.args for req in requests)) context, all_gen_kwargs = zip(*(req.args for req in requests))
context_encoding: List[List[int]] = self.tok_encode( context_encoding: list[list[int]] = self.tok_encode(
context, add_special_tokens=self.add_bos_token context, add_special_tokens=self.add_bos_token
) )
requests = [ requests = [
...@@ -608,7 +610,7 @@ class VLLM(TemplateLM): ...@@ -608,7 +610,7 @@ class VLLM(TemplateLM):
raise ValueError( raise ValueError(
f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}" f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
) )
if "max_gen_toks" in kwargs.keys(): if "max_gen_toks" in kwargs:
max_gen_toks = kwargs.pop("max_gen_toks") max_gen_toks = kwargs.pop("max_gen_toks")
else: else:
max_gen_toks = self.max_gen_toks max_gen_toks = self.max_gen_toks
...@@ -634,7 +636,7 @@ class VLLM(TemplateLM): ...@@ -634,7 +636,7 @@ class VLLM(TemplateLM):
) )
# cache generations # cache generations
for output, context in zip(cont, context): for output, context_ in zip(cont, context):
generated_text: str = output.outputs[0].text generated_text: str = output.outputs[0].text
# use secondary stop seqs to cut off should-have-been-stopped content post-hoc # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
generated_text = postprocess_generated_text( generated_text = postprocess_generated_text(
...@@ -642,7 +644,7 @@ class VLLM(TemplateLM): ...@@ -642,7 +644,7 @@ class VLLM(TemplateLM):
) )
res.append(generated_text) res.append(generated_text)
self.cache_hook.add_partial( self.cache_hook.add_partial(
"generate_until", (context, gen_kwargs), generated_text "generate_until", (context_, gen_kwargs), generated_text
) )
pbar.update(1) pbar.update(1)
...@@ -652,9 +654,9 @@ class VLLM(TemplateLM): ...@@ -652,9 +654,9 @@ class VLLM(TemplateLM):
def _loglikelihood_tokens( def _loglikelihood_tokens(
self, self,
requests: List[Tuple[Tuple[str, str], List[int], List[int]]], requests: list[tuple[tuple[str, str], list[int], list[int]]],
disable_tqdm: bool = False, disable_tqdm: bool = False,
) -> List[Tuple[float, bool]]: ) -> list[tuple[float, bool]]:
res = [] res = []
def _collate(x): def _collate(x):
...@@ -675,7 +677,7 @@ class VLLM(TemplateLM): ...@@ -675,7 +677,7 @@ class VLLM(TemplateLM):
for chunk in chunks: for chunk in chunks:
inputs = [] inputs = []
ctxlens = [] ctxlens = []
for cache_key, context_enc, continuation_enc in chunk: for _cache_key, context_enc, continuation_enc in chunk:
if ( if (
full_length := len(context_enc + continuation_enc) full_length := len(context_enc + continuation_enc)
) > self.max_length: ) > self.max_length:
...@@ -713,7 +715,7 @@ class VLLM(TemplateLM): ...@@ -713,7 +715,7 @@ class VLLM(TemplateLM):
return re_ord.get_original(res) return re_ord.get_original(res)
@staticmethod @staticmethod
def _parse_logprobs(tokens: List, outputs, ctxlen: int) -> Tuple[float, bool]: def _parse_logprobs(tokens: list, outputs, ctxlen: int) -> tuple[float, bool]:
"""Process logprobs and tokens. """Process logprobs and tokens.
:param tokens: list :param tokens: list
......