Commit 6d63c2ce authored by Baber

types

parent 0087929e
from __future__ import annotations
import copy
import logging
import os
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Literal

import jinja2
import torch
@@ -40,7 +42,7 @@ from lm_eval.models.utils import (
if TYPE_CHECKING:
    from transformers.quantizers.auto import AutoQuantizationConfig

eval_logger = logging.getLogger(__name__)
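
# Illustrative sketch (not part of this diff): with `from __future__ import
# annotations` in place, annotations are stored as strings (PEP 563), so the
# PEP 604 `X | None` unions used throughout this commit parse even on
# Python < 3.10.
from __future__ import annotations


def greet(name: str | None = None) -> list[str]:
    # The annotation is never evaluated at runtime, so `str | None` is safe here.
    return ["hello", name or "world"]
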
@@ -59,46 +61,43 @@ class HFLM(TemplateLM):
    def __init__(
        self,
        pretrained: str | transformers.PreTrainedModel,
        backend: Literal["default", "causal", "seq2seq"] = "default",
        # override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq)
        revision: str | None = "main",
        subfolder: str = "",
        tokenizer: str
        | transformers.PreTrainedTokenizer
        | transformers.PreTrainedTokenizerFast
        | None = None,
        truncation: bool | None = False,
        logits_cache: bool = True,
        max_length: int | None = None,
        device: str | None = "cuda",
        dtype: str | torch.dtype | None = "auto",
        softmax_dtype: str | torch.dtype | None = None,
        mixed_precision_dtype: str | torch.dtype | None = None,
        batch_size: int | str | None = 1,
        max_batch_size: int | None = 64,
        trust_remote_code: bool | None = False,
        use_fast_tokenizer: bool | None = True,
        add_bos_token: bool | None = False,
        prefix_token_id: int | None = None,
        # arguments used for splitting a model across GPUs naively.
        # only used if `parallelize=True`.
        parallelize: bool | None = False,
        max_memory_per_gpu: int | str | None = None,
        max_cpu_memory: int | str | None = None,
        offload_folder: str | os.PathLike | None = "./offload",
        # PEFT, delta weights and quantization options
        peft: str | None = None,
        delta: str | None = None,
        autogptq: bool | str | None = False,
        gptqmodel: bool | None = False,
        gguf_file: str | None = None,
        # end token for thinking, either the string or int token id.
        # splits to get response after this token (if provided).
        think_end_token: str | int | None = None,
        **kwargs,
    ) -> None:
        super().__init__()
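
# Minimal usage sketch (assumption, not shown in this diff): the updated
# signature is called exactly as before; only the annotation style changed.
# The checkpoint name below is an arbitrary small example.
lm = HFLM(
    pretrained="EleutherAI/pythia-70m",
    backend="causal",
    dtype="auto",
    batch_size=8,
    device="cuda",
)
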
@@ -271,18 +270,19 @@ class HFLM(TemplateLM):
            self.batch_size_per_gpu = int(batch_size)

        if isinstance(pretrained, str):
            if (gpus >= 1 or str(self.device) == "mps") and not (
                parallelize or autogptq or hasattr(self, "accelerator")
            ):
                # TODO: can remove this whole snippet except in the mps case, perhaps?
                # place model onto device requested manually,
                # if not using HF Accelerate or device_map
                # or any other option that preloads model onto device
                try:
                    self.model.to(self.device)
                except ValueError:
                    eval_logger.debug(
                        "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided. If the desired GPU is being used, this message is safe to ignore."
                    )
            # multigpu data-parallel support when launched with accelerate
            if gpus > 1:
                if accelerator.num_processes > 1:
@@ -327,12 +327,12 @@ class HFLM(TemplateLM):
    def _get_accelerate_args(
        self,
        parallelize: bool | None = None,
        device_map: str | None = "auto",
        max_memory_per_gpu: int | str | None = None,
        max_cpu_memory: int | str | None = None,
        offload_folder: str | None = "./offload",
        gpus: int | None = None,
    ) -> dict:
        """Returns the kwargs needed to apply `accelerate` in `AutoModel.from_pretrained`."""
        num_local_processes = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
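
# Sketch of the shape of the returned kwargs (example values are assumptions):
# when `parallelize=True`, the dict is passed straight to
# `AutoModel.from_pretrained` so that accelerate shards and offloads the weights.
example_accelerate_kwargs = {
    "device_map": "auto",
    "max_memory": {0: "20GiB", 1: "20GiB", "cpu": "64GiB"},
    "offload_folder": "./offload",
}
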
@@ -480,9 +480,9 @@ class HFLM(TemplateLM):
    def _get_backend(
        self,
        config: transformers.PretrainedConfig | transformers.AutoConfig,
        backend: Literal["default", "causal", "seq2seq"] = "default",
        trust_remote_code: bool | None = False,
    ) -> None:
        """
        Helper method during initialization.
@@ -497,27 +497,20 @@ class HFLM(TemplateLM):
        if backend != "default":
            # if we've settled on non-default backend, use that manually
            if backend in ["causal", "seq2seq"]:
                self.backend = backend
                eval_logger.info(
                    f"Overrode HF model backend type, and using type '{self.backend}'"
                )
        else:
            # determine and use the default HF backend for this model, based on its config + metadata.
            if self.config.model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
                # first check if model type is listed under seq2seq models, since some
                # models like MBart are listed in both seq2seq and causal mistakenly in HF transformers.
                # these special cases should be treated as seq2seq models.
                self.backend = "seq2seq"
                eval_logger.debug(f"Using model type '{self.backend}'")
            elif self.config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
                self.backend = "causal"
                eval_logger.debug(f"Using model type '{self.backend}'")
            else:
@@ -545,7 +538,7 @@ class HFLM(TemplateLM):
        pretrained: str,
        revision: str = "main",
        trust_remote_code: bool = False,
        gguf_file: str | None = None,
        subfolder: str = "",
    ) -> None:
        """Return the model config for HuggingFace models"""
@@ -560,24 +553,24 @@ class HFLM(TemplateLM):
    def _create_model(
        self,
        pretrained: str,
        revision: str | None = "main",
        dtype: str | torch.dtype | None = "auto",
        trust_remote_code: bool | None = False,
        # arguments used for splitting a model across GPUs naively.
        # only used if `parallelize=True`.
        # (accelerate naive PP (device_map) options)
        parallelize: bool | None = False,
        gpus: int | None = None,
        max_memory_per_gpu: int | str | None = None,
        max_cpu_memory: int | str | None = None,
        offload_folder: str | None = "./offload",
        # PEFT, delta weights and quantization options
        peft: str | None = None,
        delta: str | None = None,
        autogptq: bool | str | None = False,
        gptqmodel: bool | None = False,
        gguf_file: str | None = None,
        quantization_config: AutoQuantizationConfig | None = None,
        subfolder: str = "",
        **kwargs,
    ) -> None:
@@ -598,7 +591,7 @@ class HFLM(TemplateLM):
            model_kwargs.update(
                self._get_accelerate_args(
                    parallelize=parallelize,
                    device_map=kwargs.get("device_map"),
                    max_memory_per_gpu=max_memory_per_gpu,
                    max_cpu_memory=max_cpu_memory,
                    offload_folder=offload_folder,
@@ -611,12 +604,11 @@ class HFLM(TemplateLM):
                assert transformers.__version__ >= "4.30.0", (
                    "load_in_4bit requires transformers >= 4.30.0"
                )
            if transformers.__version__ >= "4.30.0" and (
                model_kwargs.get("load_in_4bit")
                and (compute_dtype := model_kwargs.get("bnb_4bit_compute_dtype"))
            ):
                model_kwargs["bnb_4bit_compute_dtype"] = get_dtype(compute_dtype)

            self._model = self.AUTO_MODEL_CLASS.from_pretrained(
                pretrained,
@@ -641,7 +633,7 @@ class HFLM(TemplateLM):
                raise type(exception)(
                    "Tried to load auto_gptq, but auto-gptq is not installed ",
                    "please install auto-gptq via pip install lm-eval[gptq] or pip install -e .[gptq]",
                ) from exception

            self._model = AutoGPTQForCausalLM.from_quantized(
                pretrained,
@@ -660,7 +652,7 @@ class HFLM(TemplateLM):
                raise type(exception)(
                    "Tried to load gptqmodel, but gptqmodel is not installed ",
                    "please install gptqmodel via `pip install gptqmodel --no-build-isolation` or `pip install lm-eval[gptqmodel] --no-build-isolation`",
                ) from exception

            self._model = GPTQModel.from_quantized(
                pretrained, trust_remote_code=trust_remote_code, **model_kwargs
@@ -672,12 +664,12 @@ class HFLM(TemplateLM):
            )

        if peft:
            from peft import PeftModel, __version__ as PEFT_VERSION

            if model_kwargs.get("load_in_4bit") and version.parse(
                PEFT_VERSION
            ) < version.parse("0.4.0"):
                raise AssertionError("load_in_4bit requires peft >= 0.4.0")
            if self._model.config.vocab_size != len(self.tokenizer):
                # resize model for LoRAs with added tokens
                eval_logger.info(
@@ -703,11 +695,13 @@ class HFLM(TemplateLM):
                    try:
                        param.data += _model_delta.state_dict()[name]
                    except KeyError:
                        raise KeyError(
                            f"Delta model is missing weights for layer: {name}"
                        ) from None
                    except Exception as e:
                        raise RuntimeError(
                            f"Failed to add delta weights to layer {name}. Error: {e}"
                        ) from e

            del _model_delta
@@ -715,20 +709,17 @@ class HFLM(TemplateLM):
    def _create_tokenizer(
        self,
        pretrained: str | transformers.PreTrainedModel,
        tokenizer: str
        | transformers.PreTrainedTokenizer
        | transformers.PreTrainedTokenizerFast
        | None,
        revision: str | None = "main",
        trust_remote_code: bool | None = False,
        use_fast_tokenizer: bool | None = True,
        gguf_file: str | None = None,
        add_bos_token: bool | None = False,
        subfolder: str | None = "",
    ) -> None:
        """
        Helper method during initialization.
@@ -760,8 +751,12 @@ class HFLM(TemplateLM):
                )
            else:
                assert isinstance(
                    tokenizer,
                    (
                        transformers.PreTrainedTokenizer,
                        transformers.PreTrainedTokenizerFast,
                    ),
                )
                self.tokenizer = tokenizer
        else:
            # Get tokenizer based on 'pretrained'
@@ -838,7 +833,7 @@ class HFLM(TemplateLM):
    def tok_encode(
        self, string: str, left_truncate_len=None, add_special_tokens=None
    ) -> list[int]:
        """ """
        # default for None - empty dict, use predefined tokenizer param
        # used for all models except for CausalLM or predefined value
@@ -864,11 +859,11 @@ class HFLM(TemplateLM):
    def tok_batch_encode(
        self,
        strings: list[str],
        padding_side: str = "left",
        left_truncate_len: int = None,
        truncation: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
        old_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = padding_side
@@ -917,24 +912,26 @@ class HFLM(TemplateLM):
            A torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model's decoder
        """
        with (
            torch.no_grad(),
            torch.autocast(
                device_type=self.device.type,
                dtype=self.mixed_precision_dtype,
                enabled=self.mixed_precision_dtype is not None,
            ),
        ):
            if attn_mask is not None or labels is not None:
                assert attn_mask is not None and labels is not None
                assert transformers.AutoModelForSeq2SeqLM == self.AUTO_MODEL_CLASS
                return self.model(
                    input_ids=inps, attention_mask=attn_mask, labels=labels
                ).logits
            else:
                assert self.AUTO_MODEL_CLASS in (
                    transformers.AutoModelForCausalLM,
                    transformers.AutoModelForVision2Seq,
                )
                return self.model(inps).logits
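
# Background sketch (not from this diff): Python 3.10 formalized grouping
# several context managers in one parenthesized `with` statement, which is what
# the rewrite above uses instead of two nested `with` blocks.
import torch

with (
    torch.no_grad(),
    torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=False),
):
    _ = torch.ones(2, 2) @ torch.ones(2, 2)
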
    def _model_generate(self, context, max_length, stop, **generation_kwargs):
        # temperature = 0.0 if not set
@@ -942,7 +939,7 @@ class HFLM(TemplateLM):
        # remove temperature, as do_sample=False takes care of this
        # and we don't want a warning from HF
        generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
        do_sample = generation_kwargs.get("do_sample")
        # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
        if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
@@ -989,8 +986,8 @@ class HFLM(TemplateLM):
        return logits

    def loglikelihood_rolling(
        self, requests: list[Instance], disable_tqdm: bool = False
    ) -> list[float]:
        adaptive_batch_size = None
        if self.batch_size == "auto":
            # using rolling window with maximum context
@@ -1009,7 +1006,7 @@ class HFLM(TemplateLM):
                disable=(disable_tqdm or (self.rank != 0)),
            )
        ):
            rolling_token_windows: list[tuple[list[int], list[int]]] = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
@@ -1093,14 +1090,14 @@ class HFLM(TemplateLM):
    def _loglikelihood_tokens(
        self,
        requests: list[tuple[tuple[str, str], list[int], list[int]]],
        disable_tqdm: bool = False,
        override_bs: int = None,
    ) -> list[tuple[float, bool]]:
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

        def _collate(req: tuple[tuple[str, str], list[int], list[int]]):
            """Defines the key for the sorted method"""
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
@@ -1112,7 +1109,7 @@ class HFLM(TemplateLM):
            toks = req[1] + req[2]
            return -len(toks), tuple(toks)
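
# Tiny illustration (not part of this diff) of the ordering that this key
# produces: requests with more total tokens sort first, so batch-time
# estimates err on the conservative side.
def _collate_example(req):
    toks = req[1] + req[2]
    return -len(toks), tuple(toks)


reqs = [(("ctx", "short"), [1], [2]), (("ctx", "long"), [1, 2, 3], [4, 5])]
print([r[0][1] for r in sorted(reqs, key=_collate_example)])  # ['long', 'short']
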
        def _lookup_one_token_cont(req: tuple[tuple[str, str], list[int], list[int]]):
            """Defines the key to group and lookup one-token continuations"""
            # Use with group_by="contexts" (optional)"
            # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations.
@@ -1286,7 +1283,7 @@ class HFLM(TemplateLM):
                # original args. Otherwise, expands the logits batch dimension and yields each
                # batch along with matching continuation tokens and prompt strings.
                # logits -> [1, seq, vocab]
                for request_str, cont_toks, logits in re_ord.get_cache(  # noqa
                    req_str=request_str,
                    cxt_toks=ctx_tokens,
                    cont_toks=cont_toks,
@@ -1327,11 +1324,11 @@ class HFLM(TemplateLM):
        return re_ord.get_original(res)

    def generate_until(
        self, requests: list[Instance], disable_tqdm: bool = False
    ) -> list[str]:
        res = []

        def _collate(req: tuple[str, dict]):
            """Defines the key for the sorted method"""
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
@@ -1394,7 +1391,7 @@ class HFLM(TemplateLM):
            raise ValueError(
                f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
            )
        if "max_gen_toks" in kwargs:
            max_gen_toks = kwargs.pop("max_gen_toks")
        else:
            max_gen_toks = self.max_gen_toks
@@ -1472,7 +1469,7 @@ class HFLM(TemplateLM):
        return res

    def apply_chat_template(
        self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
    ) -> str:
        """
        Method to apply a chat template to a list of chat history between user and model.
...
from __future__ import annotations
import copy
import gc
import inspect
@@ -8,7 +10,7 @@ from importlib.util import find_spec
from multiprocessing import Process, Queue
from queue import Empty
from time import sleep
from typing import TYPE_CHECKING, Literal

import jinja2
from more_itertools import distribute
@@ -51,10 +53,10 @@ eval_logger = logging.getLogger(__name__)
def _vllm_mp_worker(
    model_args: dict,
    sampling_params: SamplingParams,
    requests: list[list[int]],
    lora_request: LoRARequest,
    result_queue: Queue,
    dp_size: int,
    local_dp_rank: int,
    dp_master_port: int,
@@ -114,30 +116,30 @@ class VLLM(TemplateLM):
        self,
        pretrained: str,
        dtype: Literal["float16", "bfloat16", "float32", "auto"] = "auto",
        revision: str | None = None,
        trust_remote_code: bool | None = False,
        tokenizer: str | None = None,
        tokenizer_mode: Literal["auto", "slow"] = "auto",
        tokenizer_revision: str | None = None,
        add_bos_token: bool | None = False,
        prefix_token_id: int | None = None,
        tensor_parallel_size: int = 1,
        quantization: str | None = None,
        max_gen_toks: int = 256,
        swap_space: int = 4,
        batch_size: str | int = 1,
        max_batch_size: int | None = None,
        max_length: int | None = None,
        max_model_len: int | None = None,
        seed: int = 1234,
        gpu_memory_utilization: float = 0.9,
        device: str = "cuda",
        data_parallel_size: int = 1,
        lora_local_path: str | None = None,
        # VLLM: enable thinking tags in the prompt.
        enable_thinking: bool = True,
        # End marker for thinking tags - splits to get response after this token (if provided).
        think_end_token: str | None = None,
        max_lora_rank: int = 16,
        **kwargs,
    ):
@@ -173,7 +175,7 @@ class VLLM(TemplateLM):
            "quantization": quantization,
            "seed": int(seed),
            "device": str(device),
            "enable_lora": bool(lora_local_path),
            "max_lora_rank": int(max_lora_rank),
        }
        self.model_args.update(kwargs)
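
# Minimal usage sketch (assumption, not shown in this diff): model_args is what
# eventually gets unpacked into vLLM's engine constructor, along the lines of
# `LLM(**self.model_args)`. `bool(lora_local_path)` is just the tidier spelling
# of the old `True if lora_local_path else False`.
from vllm import LLM

llm = LLM(
    model="facebook/opt-125m",  # arbitrary small example checkpoint
    enable_lora=False,
    seed=1234,
    gpu_memory_utilization=0.9,
)
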
@@ -304,7 +306,7 @@ class VLLM(TemplateLM):
        return self._max_gen_toks

    def apply_chat_template(
        self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
    ) -> str:
        """
        Method to apply a chat template to a list of chat history between user and model.
@@ -339,14 +341,14 @@ class VLLM(TemplateLM):
    def tok_encode(
        self,
        string: str | list[str],
        left_truncate_len: int = None,
        add_special_tokens: bool = False,
        truncation: bool = False,
    ) -> list[int] | list[list[int]]:
        if not add_special_tokens:
            add_special_tokens = False or self.add_bos_token
        encoding: list[list[int]] | list[int] = self.tokenizer(
            string,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
@@ -364,10 +366,10 @@ class VLLM(TemplateLM):
    def _model_generate(
        self,
        requests: list[list[int]] = None,
        generate: bool = False,
        max_tokens: int = None,
        stop: list[str] | None = None,
        **kwargs,
    ):
        if generate:
@@ -385,7 +387,7 @@ class VLLM(TemplateLM):
            def run_inference_one_model(
                model_args: dict,
                sampling_params: SamplingParams,
                requests: list[list[int]],
                lora_request: LoRARequest,
            ):
                llm = LLM(**model_args)
@@ -454,7 +456,7 @@ class VLLM(TemplateLM):
                    if dead_procs:
                        raise RuntimeError(
                            f"Worker processes {dead_procs} died unexpectedly"
                        ) from None
                    continue
                results = [rank_res[i] for i in range(len(procs))]
@@ -481,14 +483,14 @@ class VLLM(TemplateLM):
        outputs = self.model.generate(
            prompt_token_ids=requests,
            sampling_params=sampling_params,
            use_tqdm=self.batch_size == "auto",
            lora_request=self.lora_request,
        )
        return outputs
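
# Background sketch (assumption, not from this diff): `use_tqdm` expects a
# plain bool, so the comparison `self.batch_size == "auto"` replaces the old
# `True if ... else False` form. A greedy SamplingParams for this call might
# look like:
from vllm import SamplingParams

greedy = SamplingParams(temperature=0.0, max_tokens=256, stop=["</s>"])
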
    def loglikelihood_rolling(
        self, requests: list[Instance], disable_tqdm: bool = False
    ) -> list[float]:
        adaptive_batch_size = None
        if self.batch_size == "auto":
            adaptive_batch_size = len(requests)
@@ -503,7 +505,7 @@ class VLLM(TemplateLM):
                disable=(disable_tqdm or (self.rank != 0)),
            )
        ):
            rolling_token_windows: list[tuple[list[int], list[int]]] = list(
                map(
                    make_disjoint_window,
                    get_rolling_token_windows(
@@ -556,13 +558,13 @@ class VLLM(TemplateLM):
        return loglikelihoods

    def generate_until(
        self, requests: list[Instance], disable_tqdm: bool = False
    ) -> list[str]:
        res = []

        # batch tokenize contexts
        context, all_gen_kwargs = zip(*(req.args for req in requests))
        context_encoding: list[list[int]] = self.tok_encode(
            context, add_special_tokens=self.add_bos_token
        )
        requests = [
@@ -608,7 +610,7 @@ class VLLM(TemplateLM):
            raise ValueError(
                f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
            )
        if "max_gen_toks" in kwargs:
            max_gen_toks = kwargs.pop("max_gen_toks")
        else:
            max_gen_toks = self.max_gen_toks
@@ -634,7 +636,7 @@ class VLLM(TemplateLM):
        )

        # cache generations
        for output, context_ in zip(cont, context):
            generated_text: str = output.outputs[0].text
            # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
            generated_text = postprocess_generated_text(
@@ -642,7 +644,7 @@ class VLLM(TemplateLM):
            )
            res.append(generated_text)
            self.cache_hook.add_partial(
                "generate_until", (context_, gen_kwargs), generated_text
            )
            pbar.update(1)
@@ -652,9 +654,9 @@ class VLLM(TemplateLM):
    def _loglikelihood_tokens(
        self,
        requests: list[tuple[tuple[str, str], list[int], list[int]]],
        disable_tqdm: bool = False,
    ) -> list[tuple[float, bool]]:
        res = []

        def _collate(x):
@@ -675,7 +677,7 @@ class VLLM(TemplateLM):
        for chunk in chunks:
            inputs = []
            ctxlens = []
            for _cache_key, context_enc, continuation_enc in chunk:
                if (
                    full_length := len(context_enc + continuation_enc)
                ) > self.max_length:
@@ -713,7 +715,7 @@ class VLLM(TemplateLM):
        return re_ord.get_original(res)

    @staticmethod
    def _parse_logprobs(tokens: list, outputs, ctxlen: int) -> tuple[float, bool]:
        """Process logprobs and tokens.

        :param tokens: list
...