Unverified Commit f724be69 authored by Baber Abbasi, committed by GitHub

update pre-commit (#2632)

* update pre-commit
parent 9dda03d6
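
Most of the diff below is mechanical reformatting applied by the newer ruff formatter. As a minimal illustrative sketch (Python, not taken from the repository), the updated style keeps the asserted condition on the assert statement's first line and wraps only the failure message in parentheses; expressions inside f-strings also gain spaces around binary operators:

    value = 42

    # old formatter output (ruff v0.7.x): the condition is wrapped in parentheses
    assert (
        value is not None
    ), "value must be provided"

    # new formatter output (ruff v0.9.x): the condition stays on the assert line
    # and only the message is parenthesized
    assert value is not None, (
        "value must be provided"
    )

    # f-string expressions are also reformatted with spaces around operators
    print(f"half of value: {value / 2}")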
@@ -29,7 +29,7 @@ repos:
       - id: mixed-line-ending
         args: [--fix=lf]
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.7.4
+    rev: v0.9.2
     hooks:
       # Run the linter.
       - id: ruff
@@ -112,6 +112,4 @@ class ConfigurableGroup(abc.ABC):
         return self._config.group
     def __repr__(self):
-        return (
-            f"ConfigurableGroup(group={self.group}," f"group_alias={self.group_alias})"
-        )
+        return f"ConfigurableGroup(group={self.group},group_alias={self.group_alias})"
@@ -527,9 +527,9 @@ def pooled_sample_stderr(stderrs: List[float], sizes: List[int]):
 def combined_sample_stderr(stderrs: List[float], sizes: List[int], metrics=None):
-    assert (
-        metrics is not None
-    ), "Need to pass a list of each subtask's metric for this stderr aggregation"
+    assert metrics is not None, (
+        "Need to pass a list of each subtask's metric for this stderr aggregation"
+    )
     assert len(stderrs) == len(sizes) and len(sizes) == len(metrics)
     # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1390 for more documentation.
@@ -17,13 +17,13 @@ def register_model(*names):
     def decorate(cls):
         for name in names:
-            assert issubclass(
-                cls, LM
-            ), f"Model '{name}' ({cls.__name__}) must extend LM class"
-            assert (
-                name not in MODEL_REGISTRY
-            ), f"Model named '{name}' conflicts with existing model! Please register with a non-conflicting alias instead."
+            assert issubclass(cls, LM), (
+                f"Model '{name}' ({cls.__name__}) must extend LM class"
+            )
+            assert name not in MODEL_REGISTRY, (
+                f"Model named '{name}' conflicts with existing model! Please register with a non-conflicting alias instead."
+            )
             MODEL_REGISTRY[name] = cls
         return cls
@@ -48,9 +48,9 @@ func2task_index = {}
 def register_task(name):
     def decorate(fn):
-        assert (
-            name not in TASK_REGISTRY
-        ), f"task named '{name}' conflicts with existing registered task!"
+        assert name not in TASK_REGISTRY, (
+            f"task named '{name}' conflicts with existing registered task!"
+        )
         TASK_REGISTRY[name] = fn
         ALL_TASKS.add(name)
@@ -104,9 +104,9 @@ def register_metric(**args):
         ]:
             if key in args:
                 value = args[key]
-                assert (
-                    value not in registry
-                ), f"{key} named '{value}' conflicts with existing registered {key}!"
+                assert value not in registry, (
+                    f"{key} named '{value}' conflicts with existing registered {key}!"
+                )
                 if key == "metric":
                     registry[name] = fn
@@ -140,9 +140,9 @@ def get_metric(name: str, hf_evaluate_metric=False) -> Callable:
 def register_aggregation(name: str):
     def decorate(fn):
-        assert (
-            name not in AGGREGATION_REGISTRY
-        ), f"aggregation named '{name}' conflicts with existing registered aggregation!"
+        assert name not in AGGREGATION_REGISTRY, (
+            f"aggregation named '{name}' conflicts with existing registered aggregation!"
+        )
         AGGREGATION_REGISTRY[name] = fn
         return fn
@@ -184,9 +184,9 @@ class FirstNSampler(ContextSampler):
         Draw the first `n` samples in order from the specified split.
         Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
         """
-        assert (
-            n <= len(self.docs)
-        ), f"Error: number of fewshot samples requested exceeds the {len(self.docs)} that are available."
+        assert n <= len(self.docs), (
+            f"Error: number of fewshot samples requested exceeds the {len(self.docs)} that are available."
+        )
         return self.docs[:n]
@@ -151,7 +151,7 @@ def get_train_overlap(docs_by_task_set: dict, ngrams_path: str, limit: int) -> d
         elapsed = time.perf_counter() - start
         print(f"Read took {elapsed:0.5f} seconds.")
-        print(f"Speed: {(os.path.getsize(file)/1000000.0)/elapsed}MB/second")
+        print(f"Speed: {(os.path.getsize(file) / 1000000.0) / elapsed}MB/second")
         print(duplicates)
@@ -34,9 +34,9 @@ class TakeKFilter(Filter):
         # need resp to be subscriptable to check below
         resps = list(resps)
         # check we have at least k responses per doc, else we can't take the first k
-        assert (
-            len(resps[0]) >= self.k
-        ), f"Need at least {self.k} responses per doc to take first {self.k}, but got {len(resps[0])} only! Please increase TaskConfig.repeats ."
+        assert len(resps[0]) >= self.k, (
+            f"Need at least {self.k} responses per doc to take first {self.k}, but got {len(resps[0])} only! Please increase TaskConfig.repeats ."
+        )
         return map(lambda r: r[: self.k], resps)
@@ -43,9 +43,9 @@ class MapFilter(Filter):
         """
         if mapping_dict is None:
             mapping_dict = {}
-        assert isinstance(
-            mapping_dict, dict
-        ), "Provided mapping_dict is not a dictionary"
+        assert isinstance(mapping_dict, dict), (
+            "Provided mapping_dict is not a dictionary"
+        )
         self.mapping_dict = mapping_dict
         self.default_value = default_value
@@ -488,7 +488,7 @@ class EvaluationTracker:
         else:
             dataset_summary += f"{self.general_config_tracker.model_name}\n"
         dataset_summary += (
-            f"The dataset is composed of {len(card_metadata)-1} configuration(s), each one corresponding to one of the evaluated task.\n\n"
+            f"The dataset is composed of {len(card_metadata) - 1} configuration(s), each one corresponding to one of the evaluated task.\n\n"
             f"The dataset has been created from {len(results_files)} run(s). Each run can be found as a specific split in each "
             'configuration, the split being named using the timestamp of the run.The "train" split is always pointing to the latest results.\n\n'
             'An additional configuration "results" store all the aggregated results of the run.\n\n'
@@ -501,7 +501,7 @@ class EvaluationTracker:
         )
         dataset_summary += (
             "## Latest results\n\n"
-            f'These are the [latest results from run {latest_datetime}]({last_results_file_path.replace("/resolve/", "/blob/")}) '
+            f"These are the [latest results from run {latest_datetime}]({last_results_file_path.replace('/resolve/', '/blob/')}) "
             "(note that there might be results for other tasks in the repos if successive evals didn't cover the same tasks. "
             'You find each in the results and the "latest" split for each eval):\n\n'
             f"```python\n{results_string}\n```"
@@ -225,7 +225,7 @@ class WandbLogger:
             instance = [x["arguments"][0][0] for x in data]
             labels = [x["arguments"][0][1] for x in data]
             resps = [
-                f'log probability of continuation is {x["resps"][0][0][0]} '
+                f"log probability of continuation is {x['resps'][0][0][0]} "
                 + "\n\n"
                 + "continuation will {} generated with greedy sampling".format(
                     "not be" if not x["resps"][0][0][1] else "be"
@@ -233,7 +233,7 @@ class WandbLogger:
                 for x in data
             ]
             filtered_resps = [
-                f'log probability of continuation is {x["filtered_resps"][0][0]} '
+                f"log probability of continuation is {x['filtered_resps'][0][0]} "
                 + "\n\n"
                 + "continuation will {} generated with greedy sampling".format(
                     "not be" if not x["filtered_resps"][0][1] else "be"
@@ -195,9 +195,9 @@ class TemplateAPI(TemplateLM):
         """Helper method to transform the prompt into the expected API input format. messages consist of batched requests"""
         if isinstance(messages[0], JsonChatStr):
             # for chat completions we need to decode the json string to list[dict,...]
-            assert (
-                self._batch_size == 1
-            ), "non-tokenized chat requests are only supported with batch_size=1"
+            assert self._batch_size == 1, (
+                "non-tokenized chat requests are only supported with batch_size=1"
+            )
             # list[dict["role":..., "content":...],...]
             return json.loads(messages[0].prompt)
@@ -506,9 +506,9 @@ class TemplateAPI(TemplateLM):
         return await tqdm_asyncio.gather(*tasks, desc="Requesting API")
     def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
-        assert (
-            self.tokenizer is not None
-        ), "Tokenizer is required for loglikelihood tasks to compute context lengths."
+        assert self.tokenizer is not None, (
+            "Tokenizer is required for loglikelihood tasks to compute context lengths."
+        )
         res = []
         def _collate(req: LogLikelihoodInputs):
@@ -51,9 +51,9 @@ class HFMultimodalLM(HFLM):
         # modify init behavior.
         super().__init__(pretrained, **kwargs)
-        assert (
-            self.batch_size != "auto"
-        ), "Batch size 'auto' is not yet supported for hf-multimodal models."
+        assert self.batch_size != "auto", (
+            "Batch size 'auto' is not yet supported for hf-multimodal models."
+        )
         self.chat_applied: bool = False
         # TODO: phi-3.5 "image placeholders" are <image_1>, <image_2>, ... in order. how to handle this case
@@ -73,9 +73,9 @@ class HFMultimodalLM(HFLM):
                 or getattr(self.config, "image_token_index", None)
             )
         )
-        assert (
-            self.image_token_id is not None
-        ), "Must have a non-None image_token_id to evaluate a Hugging Face AutoModelForVision2Seq model. Please pass `image_token_id` in `--model_args` if model's config does not already specify one."
+        assert self.image_token_id is not None, (
+            "Must have a non-None image_token_id to evaluate a Hugging Face AutoModelForVision2Seq model. Please pass `image_token_id` in `--model_args` if model's config does not already specify one."
+        )
         # get the string this token ID corresponds to
         self.image_token = self.tok_decode(
             [self.image_token_id], skip_special_tokens=False
@@ -99,7 +99,9 @@ class HFLM(TemplateLM):
             eval_logger.warning(
                 "`pretrained` model kwarg is not of type `str`. Many other model arguments may be ignored. Please do not launch via accelerate or use `parallelize=True` if passing an existing model this way."
             )
-            assert not parallelize, "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`"
+            assert not parallelize, (
+                "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`"
+            )
             self._model = pretrained
             self._device = self._model.device
             self._config = self._model.config
@@ -571,9 +573,9 @@ class HFLM(TemplateLM):
         if not autogptq and not gptqmodel:
             if model_kwargs.get("load_in_4bit", None):
-                assert (
-                    transformers.__version__ >= "4.30.0"
-                ), "load_in_4bit requires transformers >= 4.30.0"
+                assert transformers.__version__ >= "4.30.0", (
+                    "load_in_4bit requires transformers >= 4.30.0"
+                )
             if transformers.__version__ >= "4.30.0":
                 if model_kwargs.get("load_in_4bit", None):
                     if model_kwargs.get("bnb_4bit_compute_dtype", None):
@@ -905,16 +907,16 @@ class HFLM(TemplateLM):
         self, logits: torch.Tensor, contlen: int = None, inplen: int = None
     ) -> torch.Tensor:
         if self.backend == "causal":
-            assert (
-                contlen and inplen
-            ), "Must pass input len and cont. len to select scored logits for causal LM"
+            assert contlen and inplen, (
+                "Must pass input len and cont. len to select scored logits for causal LM"
+            )
             # discard right-padding.
             # also discard the input/context tokens. we'll only score continuations.
             logits = logits[inplen - contlen : inplen]
         elif self.backend == "seq2seq":
-            assert (
-                contlen and not inplen
-            ), "Selecting scored logits for Seq2SeqLM requires only cont. len"
+            assert contlen and not inplen, (
+                "Selecting scored logits for Seq2SeqLM requires only cont. len"
+            )
             # only discard right-padding.
             # the logits input to this fn only contain decoder-side tokens.
             logits = logits[:contlen]
@@ -1329,9 +1331,9 @@ class HFLM(TemplateLM):
             if self.backend == "causal":
                 # max len for inputs = max length, minus room to generate the max new tokens
                 max_ctx_len = self.max_length - max_gen_toks
-                assert (
-                    max_ctx_len > 0
-                ), f"Invalid configuration: requested max tokens to generate ({max_gen_toks}) must be less than model's maximum sequence length ({self.max_length})."
+                assert max_ctx_len > 0, (
+                    f"Invalid configuration: requested max tokens to generate ({max_gen_toks}) must be less than model's maximum sequence length ({self.max_length})."
+                )
             elif self.backend == "seq2seq":
                 # max len for inputs = encoder's whole max_length
                 max_ctx_len = self.max_length
@@ -206,7 +206,7 @@ class NEURON_HF(TemplateLM):
                 "Only float16/bfloat16/float32 are supported."
             )
-            print(f"{'='*20} \n exporting model to neuron")
+            print(f"{'=' * 20} \n exporting model to neuron")
             self.model = CustomNeuronModelForCausalLM.from_pretrained(
                 pretrained,
                 revision=revision,
@@ -220,19 +220,17 @@ class NEURON_HF(TemplateLM):
             )
             neuron_config = self.model.config.neuron
             print(
-                f"SUCCESS: neuron model exported with config {neuron_config}. \n {'='*20}"
+                f"SUCCESS: neuron model exported with config {neuron_config}. \n {'=' * 20}"
             )
         else:
-            print(
-                f"{'='*20} \n loading neuron model with config" f" {neuron_config}..."
-            )
+            print(f"{'=' * 20} \n loading neuron model with config {neuron_config}...")
             self.model = CustomNeuronModelForCausalLM.from_pretrained(
                 pretrained,
                 revision=revision,
                 trust_remote_code=trust_remote_code,
                 low_cpu_mem_usage=low_cpu_mem_usage,
             )
-            print(f"SUCCESS: neuron model loaded. \n {'='*20}")
+            print(f"SUCCESS: neuron model loaded. \n {'=' * 20}")
         self.truncation = truncation
@@ -353,9 +351,9 @@ class NEURON_HF(TemplateLM):
         )
     def _select_cont_toks(self, logits, contlen=None, inplen=None):
-        assert (
-            contlen and inplen
-        ), "Must pass input len and cont. len to select scored logits for causal LM"
+        assert contlen and inplen, (
+            "Must pass input len and cont. len to select scored logits for causal LM"
+        )
         # discard right-padding.
         # also discard the input/context tokens. we'll only score continuations.
         logits = logits[inplen - contlen : inplen]
@@ -134,9 +134,9 @@ class LocalChatCompletion(LocalCompletionsAPI):
         eos=None,
         **kwargs,
     ) -> dict:
-        assert (
-            type(messages) is not str
-        ), "chat-completions require the --apply_chat_template flag."
+        assert type(messages) is not str, (
+            "chat-completions require the --apply_chat_template flag."
+        )
         gen_kwargs.pop("do_sample", False)
         if "max_tokens" in gen_kwargs:
             max_tokens = gen_kwargs.pop("max_tokens")
@@ -208,13 +208,12 @@ class OpenAICompletionsAPI(LocalCompletionsAPI):
         return key
     def loglikelihood(self, requests, **kwargs):
-        assert (
-            self.model
-            in [
-                "babbage-002",
-                "davinci-002",
-            ]
-        ), f"Prompt loglikelihoods are only supported by OpenAI's API for {['babbage-002', 'davinci-002']}."
+        assert self.model in [
+            "babbage-002",
+            "davinci-002",
+        ], (
+            f"Prompt loglikelihoods are only supported by OpenAI's API for {['babbage-002', 'davinci-002']}."
+        )
         return super().loglikelihood(requests, **kwargs)
     def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
@@ -265,9 +264,9 @@ class OpenAIChatCompletion(LocalChatCompletion):
         eos="<|endoftext|>",
         **kwargs,
     ) -> dict:
-        assert (
-            type(messages) is not str
-        ), "chat-completions require the --apply_chat_template flag."
+        assert type(messages) is not str, (
+            "chat-completions require the --apply_chat_template flag."
+        )
         gen_kwargs.pop("do_sample", False)
         if "max_tokens" in gen_kwargs:
             max_tokens = gen_kwargs.pop("max_tokens")
@@ -21,9 +21,9 @@ class IPEXLM(HFLM):
     ) -> None:
         if "backend" in kwargs:
             # currently only supports causal models
-            assert (
-                kwargs["backend"] == "causal"
-            ), "Currently, only IPEXModelForCausalLM is supported."
+            assert kwargs["backend"] == "causal", (
+                "Currently, only IPEXModelForCausalLM is supported."
+            )
         super().__init__(
             backend=kwargs.pop("backend", "causal"),
@@ -29,9 +29,9 @@ class OptimumLM(HFLM):
     ) -> None:
         if "backend" in kwargs:
             # optimum currently only supports causal models
-            assert (
-                kwargs["backend"] == "causal"
-            ), "Currently, only OVModelForCausalLM is supported."
+            assert kwargs["backend"] == "causal", (
+                "Currently, only OVModelForCausalLM is supported."
+            )
         self.openvino_device = device
@@ -155,9 +155,9 @@ def pad_and_concat(
     length in the batch. Used for batching inputs and continuations in
     seq2seq models.
     """
-    assert (
-        padding_side == "left" or padding_side == "right"
-    ), f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"
+    assert padding_side == "left" or padding_side == "right", (
+        f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"
+    )
     for i, tensor in enumerate(tensors):
         if len(tensor.shape) == 2:
@@ -76,9 +76,9 @@ class VLLM(TemplateLM):
         )
         assert "cuda" in device or device is None, "vLLM only supports CUDA"
-        assert (
-            max_length is None or max_model_len is None
-        ), "Either max_length or max_model_len may be provided, but not both"
+        assert max_length is None or max_model_len is None, (
+            "Either max_length or max_model_len may be provided, but not both"
+        )
         self._max_length = max_model_len if max_model_len is not None else max_length
         self.tensor_parallel_size = int(tensor_parallel_size)
@@ -142,9 +142,9 @@ class VLLM(TemplateLM):
         self._max_gen_toks = max_gen_toks
         if lora_local_path is not None:
-            assert parse_version(version("vllm")) > parse_version(
-                "0.3.0"
-            ), "lora adapters only compatible with vllm > v0.3.0."
+            assert parse_version(version("vllm")) > parse_version("0.3.0"), (
+                "lora adapters only compatible with vllm > v0.3.0."
+            )
             self.lora_request = LoRARequest("finetuned", 1, lora_local_path)
         else:
             self.lora_request = None
@@ -41,4 +41,4 @@ def doc_to_text(doc):
 def doc_to_choice(doc):
-    return [alpa[i][0] for i in range(5) if doc[f"Option {i+1}"]]
+    return [alpa[i][0] for i in range(5) if doc[f"Option {i + 1}"]]