"git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "f30abd090a1d02377a1211a8c8f5b10deac0e763"
Commit 223b9488 authored by Baber

types

parent 7cef4d38
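
The hunks below all follow the same pattern: `from __future__ import annotations` is added at the top of each module, and the `typing` aliases (`Optional`, `Union`, `List`, `Dict`, `Tuple`) are replaced with PEP 604 unions and PEP 585 builtin generics. A minimal sketch of the style being adopted (the function below is illustrative, not part of the diff); with the future import, annotations are stored as strings and never evaluated at runtime, so the new syntax is safe in signatures even on interpreters older than 3.10:

from __future__ import annotations


def pick_tokens(
    batch: list[list[int]],              # was List[List[int]]
    limit: int | None = None,            # was Optional[int]
    backend: str | dict[str, str] = "huggingface",  # was Union[str, Dict[str, str]]
) -> tuple[list[int], bool]:             # was Tuple[List[int], bool]
    rows = batch if limit is None else batch[:limit]
    flat = [tok for row in rows for tok in row]
    return flat, bool(flat)
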
from __future__ import annotations
import abc
import asyncio
import copy
@@ -8,16 +10,9 @@ from functools import cached_property
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
Iterable,
List,
Literal,
NamedTuple,
Optional,
Tuple,
Union,
)
@@ -36,18 +31,21 @@ from importlib.util import find_spec
from io import BytesIO
from lm_eval import utils
from lm_eval.api.instance import Instance
from lm_eval.api.model import TemplateLM
from lm_eval.models.utils import Collator, chunks, configure_pad_token
if TYPE_CHECKING:
from collections.abc import Awaitable, Iterable
from PIL import Image
from lm_eval.api.instance import Instance
eval_logger = logging.getLogger(__name__)
LogLikelihoodInputs = Tuple[Tuple[str, str], List[int], List[int]]
LogLikelihoodInputs = tuple[tuple[str, str], list[int], list[int]]
# utility class to keep track of json encoded chats
@@ -58,9 +56,7 @@ class JsonChatStr(NamedTuple):
return self.prompt.encode(encoding)
def create_image_prompt(
imgs: list["Image.Image"], chat: dict, fmt: str = "PNG"
) -> dict:
def create_image_prompt(imgs: list[Image.Image], chat: dict, fmt: str = "PNG") -> dict:
"""
Parameters
@@ -109,33 +105,32 @@ class TemplateAPI(TemplateLM):
model: str = None,
pretrained: str = None, # `model` takes precedence over `pretrained` when passed.
base_url: str = None,
tokenizer: Optional[str] = None,
tokenizer: str | None = None,
# Loglikelihood tasks require a tokenizer to calculate context lengths,
# however the requests can be sent as a string if the API doesn't support token inputs.
# use tokenized_requests=False
tokenizer_backend: Optional[
Literal["tiktoken", "huggingface", "None", "none"]
] = "huggingface",
tokenizer_backend: Literal["tiktoken", "huggingface", "None", "none"]
| None = "huggingface",
truncate: bool = False,
# number of concurrent requests. More useful if not batching
num_concurrent: int = 1,
max_retries: int = 3,
max_gen_toks: int = 256,
batch_size: Union[str, int] = 1,
batch_size: str | int = 1,
seed: int = 1234,
max_length: Optional[int] = 2048,
max_length: int | None = 2048,
add_bos_token: bool = False,
custom_prefix_token_id: int = None,
# send the requests as tokens or strings
tokenized_requests: bool = True,
trust_remote_code: bool = False,
revision: Optional[str] = "main",
revision: str | None = "main",
use_fast_tokenizer: bool = True,
verify_certificate: bool = True,
eos_string: str = None,
# timeout in seconds
timeout: int = 300,
header: Optional[Dict[str, str]] = None,
header: dict[str, str] | None = None,
max_images: int = 1,
**kwargs,
) -> None:
@@ -232,12 +227,12 @@ class TemplateAPI(TemplateLM):
@abc.abstractmethod
def _create_payload(
self,
messages: Union[List[List[int]], List[dict], List[str], str],
messages: list[list[int]] | list[dict] | list[str] | str,
*,
generate: bool = True,
gen_kwargs: Optional[dict] = None,
gen_kwargs: dict | None = None,
seed: int = 1234,
eos: str = None,
eos: str | None = None,
**kwargs,
) -> dict:
"""This method is responsible for creating the json payload that will be sent to the API."""
@@ -245,9 +240,9 @@ class TemplateAPI(TemplateLM):
def create_message(
self,
messages: Union[List[List[int]], List[str], List[JsonChatStr]],
messages: list[list[int]] | list[str] | list[JsonChatStr],
generate=False,
) -> Union[List[List[int]], List[dict], List[str], str]:
) -> list[list[int]] | list[dict] | list[str] | str:
"""Helper method to transform the prompt into the expected API input format. messages consist of batched requests"""
if isinstance(messages[0], JsonChatStr):
# for chat completions we need to decode the json string to list[dict,...]
@@ -276,17 +271,17 @@ class TemplateAPI(TemplateLM):
@staticmethod
@abc.abstractmethod
def parse_logprobs(
outputs: Union[Any, List[Any]],
tokens: List[List[int]] = None,
ctxlen: List[int] = None,
outputs: Any | list[Any],
tokens: list[list[int]] | None = None,
ctxlen: list[int] | None = None,
**kwargs,
) -> List[Tuple[float, bool]]:
) -> list[tuple[float, bool]]:
"""Method used to parse the logprobs from the (batched) API response. This method should return a list of tuples"""
raise NotImplementedError
@staticmethod
@abc.abstractmethod
def parse_generations(outputs: Union[Any, List[Any]], **kwargs) -> List[str]:
def parse_generations(outputs: Any | list[Any], **kwargs) -> list[str]:
"""Method used to parse the generations from the (batched) API response. This method should return a list of str"""
raise NotImplementedError
@@ -303,14 +298,15 @@ class TemplateAPI(TemplateLM):
@property
def tokenizer_name(self) -> str:
"""Must be defined for LM subclasses which implement Chat Templating.
Should return the name of the tokenizer or chat template used.
Used only to properly fingerprint caches when requests are being cached with `--cache_requests`, otherwise not used.
"""
return ""
def apply_chat_template(
self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
) -> Union[str, JsonChatStr]:
self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
) -> str | JsonChatStr:
"""Applies a chat template to a list of chat history between user and model."""
if self.tokenizer_backend == "huggingface" and self.tokenized_requests:
return self.tokenizer.apply_chat_template(
@@ -319,33 +315,32 @@ class TemplateAPI(TemplateLM):
add_generation_prompt=add_generation_prompt,
continue_final_message=not add_generation_prompt,
)
else:
# bit of a hack. We'll load back before sending to the API
return JsonChatStr(
json.dumps(
[{**item, "type": "text"} for item in chat_history],
ensure_ascii=False,
)
# bit of a hack. We'll load back before sending to the API
return JsonChatStr(
json.dumps(
[{**item, "type": "text"} for item in chat_history],
ensure_ascii=False,
)
)
@cached_property
def eot_token_id(self) -> Optional[int]:
def eot_token_id(self) -> int | None:
if self.tokenizer is None:
return None
else:
if self.tokenizer_backend == "huggingface":
return self.tokenizer.eos_token_id
elif self.tokenizer_backend == "tiktoken":
if self.tokenizer_backend == "tiktoken":
return self.tokenizer.eot_token
@cached_property
def eos_string(self) -> Optional[str]:
def eos_string(self) -> str | None:
if self._eos_string:
return self._eos_string
elif self.tokenizer is not None:
if self.tokenizer is not None:
if self.tokenizer_backend == "huggingface":
return self.tokenizer.eos_token
elif self.tokenizer_backend == "tiktoken":
if self.tokenizer_backend == "tiktoken":
return self.tokenizer.decode([self.tokenizer.eot_token])
else:
eval_logger.warning(
@@ -354,7 +349,7 @@ class TemplateAPI(TemplateLM):
return None
@cached_property
def prefix_token_id(self) -> Optional[int]:
def prefix_token_id(self) -> int | None:
if self.tokenizer is None:
return None
else:
@@ -364,24 +359,24 @@ class TemplateAPI(TemplateLM):
if self.tokenizer.bos_token_id is not None:
return self.tokenizer.bos_token_id
return self.tokenizer.eos_token_id
else:
return self.tokenizer.eot_token
return self.tokenizer.eot_token
def tok_encode(
self,
string: str,
left_truncate_len: int = None,
left_truncate_len: int | None = None,
add_special_tokens: bool = False,
truncation: bool = False,
**kwargs,
) -> Union[List[List[int]], List[int], List[str]]:
) -> list[list[int]] | list[int] | list[str]:
if self.tokenizer_backend is None:
return [string]
elif self.tokenizer_backend == "huggingface":
if self.tokenizer_backend == "huggingface":
# by default for CausalLM - false or self.add_bos_token is set
if not add_special_tokens:
add_special_tokens = False or self.add_bos_token
encoding: Union[List[List[int]], List[int]] = self.tokenizer(
encoding: list[list[int]] | list[int] = self.tokenizer(
string,
add_special_tokens=add_special_tokens,
truncation=truncation,
@@ -404,20 +399,20 @@ class TemplateAPI(TemplateLM):
encoding = self.tokenizer.encode_batch(string)
return encoding
def decode_batch(self, tokens: List[List[int]]) -> List[str]:
def decode_batch(self, tokens: list[list[int]]) -> list[str] | None:
if self.tokenizer_backend == "huggingface":
return self.tokenizer.batch_decode(tokens)
elif self.tokenizer_backend == "tiktoken":
if self.tokenizer_backend == "tiktoken":
return self.tokenizer.decode_batch(tokens)
def model_call(
self,
messages: Union[List[List[int]], List[str], List[JsonChatStr]],
messages: list[list[int]] | list[str] | list[JsonChatStr],
*,
generate: bool = True,
gen_kwargs: Optional[Dict] = None,
gen_kwargs: dict | None = None,
**kwargs,
) -> Optional[dict]:
) -> dict | None:
# !!! Copy: shared dict for each request, need new object !!!
gen_kwargs = copy.deepcopy(gen_kwargs)
try:
@@ -441,7 +436,7 @@ class TemplateAPI(TemplateLM):
response.raise_for_status()
return response.json()
except RetryError:
eval_logger.error(
eval_logger.exception(
"API request failed after multiple retries. Please check the API status."
)
return None
@@ -450,14 +445,14 @@ class TemplateAPI(TemplateLM):
self,
session: ClientSession,
sem: asyncio.Semaphore,
messages: Union[List[List[int]], List[str], List[JsonChatStr]],
messages: list[list[int]] | list[str] | list[JsonChatStr],
*,
generate: bool = True,
cache_keys: list = None,
ctxlens: Optional[List[int]] = None,
gen_kwargs: Optional[Dict] = None,
cache_keys: list | None = None,
ctxlens: list[int] | None = None,
gen_kwargs: dict | None = None,
**kwargs,
) -> Union[List[str], List[Tuple[float, bool]], None]:
) -> list[str] | list[tuple[float, bool]] | None:
# !!! Copy: shared dict for each request, need new object !!!
gen_kwargs = copy.deepcopy(gen_kwargs)
payload = self._create_payload(
@@ -508,8 +503,8 @@ class TemplateAPI(TemplateLM):
sem.release()
def batch_loglikelihood_requests(
self, chunks: Iterable[List[LogLikelihoodInputs]]
) -> Tuple[List[List[int]], List[int], List[Tuple[str, str]]]:
self, chunks: Iterable[list[LogLikelihoodInputs]]
) -> tuple[list[list[int]], list[int], list[tuple[str, str]]]:
inputs = []
ctxlens = []
cache_keys = []
@@ -536,9 +531,9 @@ class TemplateAPI(TemplateLM):
cache_keys: list,
*,
generate: bool = True,
ctxlens: List[int] = None,
ctxlens: list[int] | None = None,
**kwargs,
) -> Union[List[List[str]], List[List[Tuple[float, bool]]]]:
) -> list[list[str]] | list[list[tuple[float, bool]]]:
ctxlens = ctxlens if ctxlens else [None] * len(requests)
conn = TCPConnector(limit=self._concurrent, ssl=self.verify_certificate)
sem = asyncio.Semaphore(self._concurrent)
@@ -575,14 +570,14 @@ class TemplateAPI(TemplateLM):
return await tqdm_asyncio.gather(*tasks, desc="Requesting API")
def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
def _loglikelihood_tokens(self, requests, **kwargs) -> list[tuple[float, bool]]:
assert self.tokenizer is not None, (
"Tokenizer is required for loglikelihood tasks to compute context lengths."
)
res = []
def _collate(req: LogLikelihoodInputs):
"""Defines the key for the sorted method"""
"""Defines the key for the sorted method."""
# the negative sign on len(toks) sorts descending - this has a few advantages:
# - time estimates will always be over not underestimates, which is more useful for planning
# - to know the size of a batch when going through the list, you know the first one is always the batch
@@ -639,8 +634,8 @@ class TemplateAPI(TemplateLM):
return re_ord.get_original(res)
def generate_until(
self, requests: List[Instance], disable_tqdm: bool = False
) -> List[str]:
self, requests: list[Instance], disable_tqdm: bool = False
) -> list[str]:
res = []
def _collate_gen(_requests):
@@ -773,8 +768,8 @@ class TemplateAPI(TemplateLM):
return re_ord.get_original(res)
def loglikelihood_rolling(
self, requests: List[Instance], disable_tqdm: bool = False
) -> List[float]:
self, requests: list[Instance], disable_tqdm: bool = False
) -> list[float]:
loglikelihoods = []
for (string,) in tqdm([req.args for req in requests], disable=disable_tqdm):
......
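
The import hunk above also moves `Awaitable`/`Iterable` (now taken from `collections.abc`), `PIL.Image`, and `Instance` under `if TYPE_CHECKING:`, so they are only imported while a type checker runs. A minimal sketch of that pattern, using an illustrative function rather than code from the diff:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers; skipped at runtime, which avoids
    # import cost and potential circular imports.
    from collections.abc import Iterable


def total(xs: Iterable[int]) -> int:
    # Works at runtime because the future import keeps the annotation as a
    # string, so `Iterable` is never looked up when the function is defined.
    return sum(xs)
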
@@ -682,8 +682,7 @@ class HFLM(TemplateLM):
)
if peft:
from peft import PeftModel
from peft import __version__ as PEFT_VERSION
from peft import PeftModel, __version__ as PEFT_VERSION
if model_kwargs.get("load_in_4bit") and vparse(PEFT_VERSION) < vparse(
"0.4.0"
......
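
The model_call hunk further up also swaps `eval_logger.error` for `eval_logger.exception` in the RetryError handler. A small sketch of the difference (logger name and exception type here are illustrative):

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("api")

try:
    raise TimeoutError("API did not answer")
except TimeoutError:
    # logger.error() records only the message; logger.exception() logs at
    # ERROR level and appends the active traceback, which is the useful form
    # inside an except block.
    logger.exception("API request failed after multiple retries.")
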
from __future__ import annotations
import logging
import os
from functools import cached_property
from operator import itemgetter
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any
from lm_eval.api.registry import register_model
from lm_eval.models.api_models import TemplateAPI
@@ -26,9 +28,9 @@ class LocalCompletionsAPI(TemplateAPI):
def _create_payload(
self,
messages: Union[List[List[int]], List[dict], List[str], str],
messages: list[list[int]] | list[dict] | list[str] | str,
generate=False,
gen_kwargs: Optional[dict] = None,
gen_kwargs: dict | None = None,
seed: int = 1234,
eos=None,
**kwargs,
@@ -50,24 +52,23 @@ class LocalCompletionsAPI(TemplateAPI):
"seed": seed,
**gen_kwargs,
}
else:
return {
"model": self.model,
"prompt": messages,
"temperature": 0,
"max_tokens": 1,
"logprobs": 1,
"seed": seed,
"echo": True,
}
return {
"model": self.model,
"prompt": messages,
"temperature": 0,
"max_tokens": 1,
"logprobs": 1,
"seed": seed,
"echo": True,
}
@staticmethod
def parse_logprobs(
outputs: Union[Dict, List[Dict]],
tokens: List[List[int]] = None,
ctxlens: List[int] = None,
outputs: dict | list[dict],
tokens: list[list[int]] = None,
ctxlens: list[int] = None,
**kwargs,
) -> List[Tuple[float, bool]]:
) -> list[tuple[float, bool]]:
res = []
if not isinstance(outputs, list):
outputs = [outputs]
@@ -88,7 +89,7 @@ class LocalCompletionsAPI(TemplateAPI):
return res
@staticmethod
def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]:
def parse_generations(outputs: dict | list[dict], **kwargs) -> list[str]:
res = []
if not isinstance(outputs, list):
outputs = [outputs]
@@ -130,9 +131,9 @@ class LocalChatCompletion(LocalCompletionsAPI):
def _create_payload(
self,
messages: List[Dict],
messages: list[dict],
generate=False,
gen_kwargs: dict = None,
gen_kwargs: dict | None = None,
seed=1234,
eos=None,
**kwargs,
@@ -160,7 +161,7 @@ class LocalChatCompletion(LocalCompletionsAPI):
}
@staticmethod
def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]:
def parse_generations(outputs: dict | list[dict], **kwargs) -> list[str]:
res = []
if not isinstance(outputs, list):
outputs = [outputs]
@@ -173,11 +174,11 @@ class LocalChatCompletion(LocalCompletionsAPI):
def tok_encode(
self,
string: Union[str, Any],
string: str | Any,
left_truncate_len=None,
add_special_tokens=None,
**kwargs,
) -> Union[List[str], List[int], Any]:
) -> list[str] | list[int] | Any:
return string
def loglikelihood(self, requests, **kwargs):
@@ -219,7 +220,7 @@ class OpenAICompletionsAPI(LocalCompletionsAPI):
)
return super().loglikelihood(requests, **kwargs)
def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
def chat_template(self, chat_template: bool | str = False) -> str | None:
return ""
@@ -261,7 +262,7 @@ class OpenAIChatCompletion(LocalChatCompletion):
def _create_payload(
self,
messages: List[Dict],
messages: list[dict],
generate=False,
gen_kwargs: dict = None,
seed=1234,
......
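
Several of the hunks above (apply_chat_template in api_models.py, _create_payload here) also drop the `else:` around the fallback branch once the first branch ends in `return`, flattening the indentation without changing behaviour. A tiny illustrative sketch:

def payload(generate: bool) -> dict:
    if generate:
        return {"mode": "generate"}
    # No `else:` needed: the early return above already ended that branch, so
    # the fallback sits one level shallower with identical behaviour.
    return {"mode": "loglikelihood"}
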
from __future__ import annotations
import copy
import gc
import logging
@@ -7,7 +9,7 @@ from importlib.util import find_spec
from multiprocessing import Process, Queue
from queue import Empty
from time import sleep
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
from typing import TYPE_CHECKING, Literal
import jinja2
from more_itertools import distribute
@@ -113,30 +115,30 @@ class VLLM(TemplateLM):
self,
pretrained: str,
dtype: Literal["float16", "bfloat16", "float32", "auto"] = "auto",
revision: Optional[str] = None,
trust_remote_code: Optional[bool] = False,
tokenizer: Optional[str] = None,
revision: str | None = None,
trust_remote_code: bool | None = False,
tokenizer: str | None = None,
tokenizer_mode: Literal["auto", "slow"] = "auto",
tokenizer_revision: Optional[str] = None,
add_bos_token: Optional[bool] = False,
prefix_token_id: Optional[int] = None,
tokenizer_revision: str | None = None,
add_bos_token: bool | None = False,
prefix_token_id: int | None = None,
tensor_parallel_size: int = 1,
quantization: Optional[str] = None,
quantization: str | None = None,
max_gen_toks: int = 256,
swap_space: int = 4,
batch_size: Union[str, int] = 1,
max_batch_size=None,
max_length: int = None,
max_model_len: int = None,
batch_size: str | int = 1,
max_batch_size: int | None = None,
max_length: int | None = None,
max_model_len: int | None = None,
seed: int = 1234,
gpu_memory_utilization: float = 0.9,
data_parallel_size: int = 1,
lora_local_path: str = None,
lora_local_path: str | None = None,
# VLLM: enable thinking tags in the prompt.
enable_thinking: bool = True,
chat_template_args: Optional[dict] = None,
# End marker for thinking tags - splits to get response after this token (if provided).
think_end_token: Optional[str] = None,
think_end_token: str | None = None,
max_lora_rank: int = 16,
**kwargs,
):
@@ -172,7 +174,7 @@ class VLLM(TemplateLM):
"swap_space": int(swap_space),
"quantization": quantization,
"seed": int(seed),
"enable_lora": True if lora_local_path else False,
"enable_lora": bool(lora_local_path),
"max_lora_rank": int(max_lora_rank),
}
self.model_args.update(kwargs)
@@ -300,7 +302,7 @@ class VLLM(TemplateLM):
return self._max_gen_toks
def apply_chat_template(
self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
) -> str:
"""
Method to apply a chat template to a list of chat history between user and model.
@@ -337,14 +339,14 @@ class VLLM(TemplateLM):
def tok_encode(
self,
string: Union[str, List[str]],
string: str | list[str],
left_truncate_len: int = None,
add_special_tokens: bool = False,
truncation: bool = False,
) -> Union[List[int], List[List[int]]]:
) -> list[int] | list[list[int]]:
if not add_special_tokens:
add_special_tokens = False or self.add_bos_token
encoding: Union[List[List[int]], List[int]] = self.tokenizer(
encoding: list[list[int]] | list[int] = self.tokenizer(
string,
add_special_tokens=add_special_tokens,
truncation=truncation,
@@ -362,7 +364,7 @@ class VLLM(TemplateLM):
def _model_generate(
self,
requests: List[List[int]] = None,
requests: list[list[int]] = None,
generate: bool = False,
sampling_params: Union[List["SamplingParams"], "SamplingParams", None] = None,
):
@@ -379,8 +381,8 @@ class VLLM(TemplateLM):
@ray.remote
def run_inference_one_model(
model_args: dict,
sampling_params: List["SamplingParams"],
requests: List[List[int]],
sampling_params: list["SamplingParams"],
requests: list[list[int]],
lora_request: "LoRARequest",
):
llm = LLM(**model_args)
@@ -454,7 +456,7 @@ class VLLM(TemplateLM):
if dead_procs:
raise RuntimeError(
f"Worker processes {dead_procs} died unexpectedly"
)
) from None
continue
results = [rank_res[i] for i in range(len(procs))]
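
The worker-monitoring hunk above also adds `from None` to the RuntimeError raised when data-parallel workers die. A minimal illustration of what that suppresses (the surrounding code here is hypothetical):

def ensure_alive(alive: bool) -> None:
    try:
        if not alive:
            raise KeyError("rank 0 exited")
    except KeyError:
        # `from None` clears the implicit exception context, so the log shows
        # only the RuntimeError instead of "During handling of the above
        # exception, another exception occurred".
        raise RuntimeError("Worker processes died unexpectedly") from None

ensure_alive(True)
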
@@ -481,14 +483,14 @@ class VLLM(TemplateLM):
outputs = self.model.generate(
[TokensPrompt(prompt_token_ids=request) for request in requests],
sampling_params=sampling_params,
use_tqdm=True if self.batch_size == "auto" else False,
use_tqdm=self.batch_size == "auto",
lora_request=self.lora_request,
)
return outputs
def loglikelihood_rolling(
self, requests: List[Instance], disable_tqdm: bool = False
) -> List[float]:
self, requests: list[Instance], disable_tqdm: bool = False
) -> list[float]:
adaptive_batch_size = None
if self.batch_size == "auto":
adaptive_batch_size = len(requests)
@@ -503,7 +505,7 @@ class VLLM(TemplateLM):
disable=(disable_tqdm or (self.rank != 0)),
)
):
rolling_token_windows: List[Tuple[List[int], List[int]]] = list(
rolling_token_windows: list[tuple[list[int], list[int]]] = list(
map(
make_disjoint_window,
get_rolling_token_windows(
@@ -556,13 +558,13 @@ class VLLM(TemplateLM):
return loglikelihoods
def generate_until(
self, requests: List[Instance], disable_tqdm: bool = False
) -> List[str]:
self, requests: list[Instance], disable_tqdm: bool = False
) -> list[str]:
res = []
# batch tokenize contexts
context, all_gen_kwargs = zip(*(req.args for req in requests))
context_encoding: List[List[int]] = self.tok_encode(
context_encoding: list[list[int]] = self.tok_encode(
context, add_special_tokens=self.add_bos_token
)
requests = [
@@ -638,7 +640,7 @@ class VLLM(TemplateLM):
)
# cache generations
for output, context in zip(cont, context):
for output, context_ in zip(cont, context):
generated_text: str = output.outputs[0].text
# use secondary stop seqs to cut off should-have-been-stopped content post-hoc
generated_text = postprocess_generated_text(
@@ -646,7 +648,7 @@ class VLLM(TemplateLM):
)
res.append(generated_text)
self.cache_hook.add_partial(
"generate_until", (context, gen_kwargs), generated_text
"generate_until", (context_, gen_kwargs), generated_text
)
pbar.update(1)
@@ -656,9 +658,9 @@ class VLLM(TemplateLM):
def _loglikelihood_tokens(
self,
requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
requests: list[tuple[tuple[str, str], list[int], list[int]]],
disable_tqdm: bool = False,
) -> List[Tuple[float, bool]]:
) -> list[tuple[float, bool]]:
res = []
def _collate(x):
@@ -679,7 +681,7 @@ class VLLM(TemplateLM):
for chunk in chunks:
inputs = []
ctxlens = []
for cache_key, context_enc, continuation_enc in chunk:
for _cache_key, context_enc, continuation_enc in chunk:
if (
full_length := len(context_enc + continuation_enc)
) > self.max_length:
@@ -717,7 +719,7 @@ class VLLM(TemplateLM):
return re_ord.get_original(res)
@staticmethod
def _parse_logprobs(tokens: List, outputs, ctxlen: int) -> Tuple[float, bool]:
def _parse_logprobs(tokens: list, outputs, ctxlen: int) -> tuple[float, bool]:
"""Process logprobs and tokens.
:param tokens: list
......
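
The generate_until hunk earlier in this file also renames the inner loop variable from `context` to `context_` so it no longer shadows the `context` sequence being zipped over. A short sketch of why the rename helps (names here are illustrative):

outs = [10, 20]
ctxs = ["ctx-a", "ctx-b"]

# Writing `for out, ctxs in zip(outs, ctxs)` still iterates correctly, because
# zip() captures the original list before the name is rebound, but afterwards
# `ctxs` refers to the last element only, and linters flag the shadowing.
# Renaming the loop target keeps the two meanings distinct.
for out, ctx_ in zip(outs, ctxs):
    print(out, ctx_)

print(ctxs)  # still the full list
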