Commit bc4b922c authored by Baber

Merge branch 'main' into llama

# Conflicts:
#	lm_eval/tasks/llama3/README.md
parents 748eb47e b2c090cc
@@ -8,12 +8,17 @@ from lm_eval.api.registry import register_filter
@register_filter("regex")
class RegexFilter(Filter):
-   """ """
+   """A filter that extracts values from text using regex pattern matching.
+   This filter applies a regex pattern to each model response and extracts matched values.
+   If no match is found, returns a fallback value. Useful for extracting structured data
+   (like numbers) from unstructured model outputs.
+   """
    def __init__(
        self,
        regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
-       group_select=0,
+       group_select: int = 0,
        fallback: str = "[invalid]",
    ) -> None:
        """
@@ -25,7 +30,7 @@ class RegexFilter(Filter):
        self.group_select = group_select
        self.fallback = fallback
-   def apply(self, resps, docs):
+   def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
        # here, we assume we have a list, in which each element is
        # a list of model responses for some particular input/target pair.
        # so we process each of these (same input/target response sets)
@@ -55,12 +60,9 @@ class RegexFilter(Filter):
@register_filter("remove_whitespace")
class WhitespaceFilter(Filter):
-   """ """
+   """Filters out leading whitespace from responses."""
-   def __init__(self) -> None:
-       pass
-   def apply(self, resps, docs):
+   def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
        def filter_set(inst):
            filtered_resp = []
            for resp in inst:
@@ -105,7 +107,7 @@ class MultiChoiceRegexFilter(RegexFilter):
        self.ignore_punctuation = ignore_punctuation
        self.regexes_to_ignore = regexes_to_ignore
-   def apply(self, resps, docs):
+   def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
        # here, we assume we have a list, in which each element is
        # a list of model responses for some particular input/target pair.
        # so we process each of these (same input/target response sets)
@@ -164,7 +166,7 @@ class MultiChoiceRegexFilter(RegexFilter):
        fallback_regex = re.compile("|".join(fallback_regexes))
        without_paren_fallback_regex = "|".join(without_paren_fallback_regexes)
        without_paren_fallback_regex = re.compile(
-           f":[\s]*({without_paren_fallback_regex})"
+           rf":[\s]*({without_paren_fallback_regex})"
        )
        filtered = []
...
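Note for readers skimming this hunk: a minimal, self-contained sketch of the extraction behavior the new docstring describes (a hypothetical standalone function, not the `RegexFilter` class itself, which batches over lists of responses):

```python
import re

# Hypothetical sketch of the regex-extraction behavior described above; the real
# RegexFilter in lm_eval/filters/extraction.py differs in structure and detail.
def extract_answer(
    response: str,
    pattern: str = r"#### (\-?[0-9\.\,]+)",
    group_select: int = 0,
    fallback: str = "[invalid]",
) -> str:
    matches = re.findall(pattern, response)
    if not matches:
        return fallback
    match = matches[group_select]
    # findall returns tuples when the pattern has multiple capture groups
    if isinstance(match, tuple):
        match = next((m for m in match if m), fallback)
    return match.strip()

print(extract_answer("The sum is 12 + 30.\n#### 42"))  # -> "42"
print(extract_answer("No final answer given."))        # -> "[invalid]"
```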
@@ -34,9 +34,9 @@ class TakeKFilter(Filter):
        # need resp to be subscriptable to check below
        resps = list(resps)
        # check we have at least k responses per doc, else we can't take the first k
-       assert (
-           len(resps[0]) >= self.k
-       ), f"Need at least {self.k} responses per doc to take first {self.k}, but got {len(resps[0])} only! Please increase TaskConfig.repeats ."
+       assert len(resps[0]) >= self.k, (
+           f"Need at least {self.k} responses per doc to take first {self.k}, but got {len(resps[0])} only! Please increase TaskConfig.repeats ."
+       )
        return map(lambda r: r[: self.k], resps)
...
@@ -43,9 +43,9 @@ class MapFilter(Filter):
        """
        if mapping_dict is None:
            mapping_dict = {}
-       assert isinstance(
-           mapping_dict, dict
-       ), "Provided mapping_dict is not a dictionary"
+       assert isinstance(mapping_dict, dict), (
+           "Provided mapping_dict is not a dictionary"
+       )
        self.mapping_dict = mapping_dict
        self.default_value = default_value
...
@@ -488,7 +488,7 @@ class EvaluationTracker:
            else:
                dataset_summary += f"{self.general_config_tracker.model_name}\n"
            dataset_summary += (
-               f"The dataset is composed of {len(card_metadata)-1} configuration(s), each one corresponding to one of the evaluated task.\n\n"
+               f"The dataset is composed of {len(card_metadata) - 1} configuration(s), each one corresponding to one of the evaluated task.\n\n"
                f"The dataset has been created from {len(results_files)} run(s). Each run can be found as a specific split in each "
                'configuration, the split being named using the timestamp of the run.The "train" split is always pointing to the latest results.\n\n'
                'An additional configuration "results" store all the aggregated results of the run.\n\n'
@@ -501,7 +501,7 @@ class EvaluationTracker:
            )
            dataset_summary += (
                "## Latest results\n\n"
-               f'These are the [latest results from run {latest_datetime}]({last_results_file_path.replace("/resolve/", "/blob/")}) '
+               f"These are the [latest results from run {latest_datetime}]({last_results_file_path.replace('/resolve/', '/blob/')}) "
                "(note that there might be results for other tasks in the repos if successive evals didn't cover the same tasks. "
                'You find each in the results and the "latest" split for each eval):\n\n'
                f"```python\n{results_string}\n```"
...
@@ -48,6 +48,9 @@ class WandbLogger:
        self.wandb_args: Dict[str, Any] = kwargs
+       # pop the step key from the args to save for all logging calls
+       self.step = self.wandb_args.pop("step", None)
        # initialize a W&B run
        if wandb.run is None:
            self.run = wandb.init(**self.wandb_args)
@@ -152,11 +155,11 @@ class WandbLogger:
        # log the complete eval result to W&B Table
        table = make_table(["Tasks"] + columns, "results")
-       self.run.log({"evaluation/eval_results": table})
+       self.run.log({"evaluation/eval_results": table}, step=self.step)
        if "groups" in self.results.keys():
            table = make_table(["Groups"] + columns, "groups")
-           self.run.log({"evaluation/group_eval_results": table})
+           self.run.log({"evaluation/group_eval_results": table}, step=self.step)
    def _log_results_as_artifact(self) -> None:
        """Log results as JSON artifact to W&B."""
@@ -174,13 +177,13 @@ class WandbLogger:
        """Log evaluation results to W&B."""
        # Log configs to wandb
        configs = self._get_config()
-       self.run.config.update(configs)
+       self.run.config.update(configs, allow_val_change=self.step is not None)
        wandb_summary, self.wandb_results = self._sanitize_results_dict()
        # update wandb.run.summary with items that were removed
        self.run.summary.update(wandb_summary)
        # Log the evaluation metrics to wandb
-       self.run.log(self.wandb_results)
+       self.run.log(self.wandb_results, step=self.step)
        # Log the evaluation metrics as W&B Table
        self._log_results_as_table()
        # Log the results dict as json to W&B Artifacts
@@ -222,7 +225,7 @@ class WandbLogger:
        instance = [x["arguments"][0][0] for x in data]
        labels = [x["arguments"][0][1] for x in data]
        resps = [
-           f'log probability of continuation is {x["resps"][0][0][0]} '
+           f"log probability of continuation is {x['resps'][0][0][0]} "
            + "\n\n"
            + "continuation will {} generated with greedy sampling".format(
                "not be" if not x["resps"][0][0][1] else "be"
@@ -230,7 +233,7 @@ class WandbLogger:
            for x in data
        ]
        filtered_resps = [
-           f'log probability of continuation is {x["filtered_resps"][0][0]} '
+           f"log probability of continuation is {x['filtered_resps'][0][0]} "
            + "\n\n"
            + "continuation will {} generated with greedy sampling".format(
                "not be" if not x["filtered_resps"][0][1] else "be"
@@ -329,7 +332,7 @@ class WandbLogger:
        # log the samples as a W&B Table
        df = self._generate_dataset(eval_preds, self.task_configs.get(task_name))
-       self.run.log({f"{task_name}_eval_results": df})
+       self.run.log({f"{task_name}_eval_results": df}, step=self.step)
        # log the samples as a json file as W&B Artifact
        self._log_samples_as_artifact(eval_preds, task_name)
@@ -348,4 +351,4 @@ class WandbLogger:
        # log the samples as a json file as W&B Artifact
        self._log_samples_as_artifact(eval_preds, task_name)
-       self.run.log({f"{group}_eval_results": grouped_df})
+       self.run.log({f"{group}_eval_results": grouped_df}, step=self.step)
@@ -11,6 +11,7 @@ from . import (
    neuralmagic,
    neuron_optimum,
    openai_completions,
+   optimum_ipex,
    optimum_lm,
    textsynth,
    vllm_causallms,
...
@@ -195,9 +195,9 @@ class TemplateAPI(TemplateLM):
        """Helper method to transform the prompt into the expected API input format. messages consist of batched requests"""
        if isinstance(messages[0], JsonChatStr):
            # for chat completions we need to decode the json string to list[dict,...]
-           assert (
-               self._batch_size == 1
-           ), "non-tokenized chat requests are only supported with batch_size=1"
+           assert self._batch_size == 1, (
+               "non-tokenized chat requests are only supported with batch_size=1"
+           )
            # list[dict["role":..., "content":...],...]
            return json.loads(messages[0].prompt)
@@ -253,12 +253,15 @@ class TemplateAPI(TemplateLM):
        return ""
    def apply_chat_template(
-       self, chat_history: List[Dict[str, str]]
+       self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
    ) -> Union[str, JsonChatStr]:
        """Applies a chat template to a list of chat history between user and model."""
        if self.tokenizer_backend == "huggingface" and self.tokenized_requests:
            return self.tokenizer.apply_chat_template(
-               chat_history, tokenize=False, add_generation_prompt=True
+               chat_history,
+               tokenize=False,
+               add_generation_prompt=add_generation_prompt,
+               continue_final_message=not add_generation_prompt,
            )
        else:
            # bit of a hack. We'll load back before sending to the API
@@ -503,9 +506,9 @@ class TemplateAPI(TemplateLM):
        return await tqdm_asyncio.gather(*tasks, desc="Requesting API")
    def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
-       assert (
-           self.tokenizer is not None
-       ), "Tokenizer is required for loglikelihood tasks to compute context lengths."
+       assert self.tokenizer is not None, (
+           "Tokenizer is required for loglikelihood tasks to compute context lengths."
+       )
        res = []
        def _collate(req: LogLikelihoodInputs):
...
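The new `add_generation_prompt` flag toggles between "open a fresh assistant turn" and "continue the final message" when templating. A minimal sketch against the Hugging Face tokenizer API (the checkpoint name is only an example, and `continue_final_message` requires a reasonably recent transformers release):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")  # example model
chat = [
    {"role": "user", "content": "Name a prime number."},
    {"role": "assistant", "content": "A prime number is"},
]

# add_generation_prompt=True: append the template's assistant header, ready for a new turn
fresh = tok.apply_chat_template(chat[:1], tokenize=False, add_generation_prompt=True)

# add_generation_prompt=False + continue_final_message=True: keep the partial assistant
# message open so the model completes it instead of starting a new turn
cont = tok.apply_chat_template(
    chat,
    tokenize=False,
    add_generation_prompt=False,
    continue_final_message=True,
)
print(fresh)
print(cont)
```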
@@ -51,9 +51,9 @@ class HFMultimodalLM(HFLM):
        # modify init behavior.
        super().__init__(pretrained, **kwargs)
-       assert (
-           self.batch_size != "auto"
-       ), "Batch size 'auto' is not yet supported for hf-multimodal models."
+       assert self.batch_size != "auto", (
+           "Batch size 'auto' is not yet supported for hf-multimodal models."
+       )
        self.chat_applied: bool = False
        # TODO: phi-3.5 "image placeholders" are <image_1>, <image_2>, ... in order. how to handle this case
@@ -73,9 +73,9 @@ class HFMultimodalLM(HFLM):
                or getattr(self.config, "image_token_index", None)
            )
        )
-       assert (
-           self.image_token_id is not None
-       ), "Must have a non-None image_token_id to evaluate a Hugging Face AutoModelForVision2Seq model. Please pass `image_token_id` in `--model_args` if model's config does not already specify one."
+       assert self.image_token_id is not None, (
+           "Must have a non-None image_token_id to evaluate a Hugging Face AutoModelForVision2Seq model. Please pass `image_token_id` in `--model_args` if model's config does not already specify one."
+       )
        # get the string this token ID corresponds to
        self.image_token = self.tok_decode(
            [self.image_token_id], skip_special_tokens=False
@@ -200,7 +200,9 @@ class HFMultimodalLM(HFLM):
        return context_enc, continuation_enc, image_enc
-   def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
+   def apply_chat_template(
+       self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
+   ) -> str:
        self.chat_applied = True
        if not self.interleave:
            for content in chat_history:
@@ -250,7 +252,9 @@ class HFMultimodalLM(HFLM):
            )
        return self.processor.apply_chat_template(
-           chat_history, add_generation_prompt=True
+           chat_history,
+           add_generation_prompt=add_generation_prompt,
+           continue_final_message=not add_generation_prompt,
        )
    def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
...
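For context on the `image_token_id` assert above, a rough sketch of the same fallback chain outside the class, under assumed names (the checkpoint is only an example; the real code uses `self.tok_decode` and a user-supplied `image_token_id` from `--model_args`):

```python
from transformers import AutoConfig, AutoTokenizer

model_id = "llava-hf/llava-1.5-7b-hf"  # example vision2seq checkpoint
config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Prefer an explicit override, then common config attribute names used by VLMs.
user_supplied_id = None  # hypothetical value that would come from --model_args
image_token_id = (
    user_supplied_id
    or getattr(config, "image_token_id", None)
    or getattr(config, "image_token_index", None)
)
assert image_token_id is not None, "config does not expose an image token id"
image_token = tokenizer.decode([image_token_id], skip_special_tokens=False)
print(image_token_id, image_token)
```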
@@ -90,6 +90,7 @@ class HFLM(TemplateLM):
        delta: Optional[str] = None,
        autogptq: Optional[Union[bool, str]] = False,
        gptqmodel: Optional[bool] = False,
+       gguf_file: Optional[str] = None,
        **kwargs,
    ) -> None:
        super().__init__()
@@ -98,7 +99,9 @@ class HFLM(TemplateLM):
            eval_logger.warning(
                "`pretrained` model kwarg is not of type `str`. Many other model arguments may be ignored. Please do not launch via accelerate or use `parallelize=True` if passing an existing model this way."
            )
-           assert not parallelize, "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`"
+           assert not parallelize, (
+               "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`"
+           )
            self._model = pretrained
            self._device = self._model.device
            self._config = self._model.config
@@ -164,6 +167,7 @@ class HFLM(TemplateLM):
            pretrained,
            revision=revision,
            trust_remote_code=trust_remote_code,
+           gguf_file=gguf_file,
        )
        # determine which of 'causal' and 'seq2seq' backends to use for HF models
@@ -178,6 +182,7 @@ class HFLM(TemplateLM):
            revision=revision,
            trust_remote_code=trust_remote_code,
            use_fast_tokenizer=use_fast_tokenizer,
+           gguf_file=gguf_file,
        )
        # if we passed `pretrained` as a string, initialize our model now
@@ -196,6 +201,7 @@ class HFLM(TemplateLM):
                delta=delta,
                autogptq=autogptq,
                gptqmodel=gptqmodel,
+               gguf_file=gguf_file,
                **kwargs,
            )
@@ -508,12 +514,14 @@ class HFLM(TemplateLM):
        pretrained: str,
        revision: str = "main",
        trust_remote_code: bool = False,
+       gguf_file: Optional[str] = None,
    ) -> None:
        """Return the model config for HuggingFace models"""
        self._config = transformers.AutoConfig.from_pretrained(
            pretrained,
            revision=revision,
            trust_remote_code=trust_remote_code,
+           gguf_file=gguf_file,
        )
    def _create_model(
@@ -535,6 +543,7 @@ class HFLM(TemplateLM):
        delta: Optional[str] = None,
        autogptq: Optional[Union[bool, str]] = False,
        gptqmodel: Optional[bool] = False,
+       gguf_file: Optional[str] = None,
        **kwargs,
    ) -> None:
        """
@@ -564,9 +573,9 @@ class HFLM(TemplateLM):
        if not autogptq and not gptqmodel:
            if model_kwargs.get("load_in_4bit", None):
-               assert (
-                   transformers.__version__ >= "4.30.0"
-               ), "load_in_4bit requires transformers >= 4.30.0"
+               assert transformers.__version__ >= "4.30.0", (
+                   "load_in_4bit requires transformers >= 4.30.0"
+               )
            if transformers.__version__ >= "4.30.0":
                if model_kwargs.get("load_in_4bit", None):
                    if model_kwargs.get("bnb_4bit_compute_dtype", None):
@@ -579,6 +588,7 @@ class HFLM(TemplateLM):
                revision=revision,
                torch_dtype=get_dtype(dtype),
                trust_remote_code=trust_remote_code,
+               gguf_file=gguf_file,
                **model_kwargs,
            )
        else:
@@ -676,6 +686,7 @@ class HFLM(TemplateLM):
        revision: Optional[str] = "main",
        trust_remote_code: Optional[bool] = False,
        use_fast_tokenizer: Optional[bool] = True,
+       gguf_file: Optional[str] = None,
    ) -> None:
        """
        Helper method during initialization.
@@ -683,14 +694,21 @@ class HFLM(TemplateLM):
        Create a tokenizer object corresponding to the correct
        tokenizer for value of `pretrained`, or use the pre-initialized tokenizer passed.
        """
+       kwargs = {
+           "revision": revision,
+           "trust_remote_code": trust_remote_code,
+       }
+       # gguf format embeds tokenizer and is not compatible with hf tokenizer `use_fast` param
+       if gguf_file is not None:
+           kwargs["gguf_file"] = gguf_file
+       else:
+           kwargs["use_fast"] = use_fast_tokenizer
        if tokenizer:
            if isinstance(tokenizer, str):
                self.tokenizer = transformers.AutoTokenizer.from_pretrained(
-                   tokenizer,
-                   revision=revision,
-                   trust_remote_code=trust_remote_code,
-                   use_fast=use_fast_tokenizer,
+                   tokenizer, **kwargs
                )
            else:
                assert isinstance(
@@ -705,10 +723,7 @@ class HFLM(TemplateLM):
            # get the HF hub name via accessor on model
            model_name = self.model.name_or_path
            self.tokenizer = transformers.AutoTokenizer.from_pretrained(
-               model_name,
-               revision=revision,
-               trust_remote_code=trust_remote_code,
-               use_fast=use_fast_tokenizer,
+               model_name, **kwargs
            )
        return None
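A minimal sketch of the GGUF loading path these kwargs enable. `gguf_file` is a real `from_pretrained` argument in recent transformers releases (it needs the `gguf` extra installed); the repo and file names below are only examples:

```python
from transformers import AutoTokenizer

repo_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"   # example repo
gguf_file = "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf"   # example quant file; None for a normal HF repo

kwargs = {"revision": "main", "trust_remote_code": False}
if gguf_file is not None:
    # GGUF checkpoints embed their own tokenizer, so `use_fast` does not apply
    kwargs["gguf_file"] = gguf_file
else:
    kwargs["use_fast"] = True

tokenizer = AutoTokenizer.from_pretrained(repo_id, **kwargs)
print(type(tokenizer).__name__)
```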
@@ -818,6 +833,12 @@ class HFLM(TemplateLM):
            **add_special_tokens,
        )
        if left_truncate_len:
+           original_lengths = encoding["input_ids"].size(1)
+           if original_lengths > left_truncate_len:
+               eval_logger.warn(
+                   f"Left truncation applied. Original sequence length was {original_lengths}, "
+                   f"truncating to last {left_truncate_len} tokens. Some content will be lost.",
+               )
            encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
            encoding["attention_mask"] = encoding["attention_mask"][
                :, -left_truncate_len:
@@ -886,16 +907,16 @@ class HFLM(TemplateLM):
        self, logits: torch.Tensor, contlen: int = None, inplen: int = None
    ) -> torch.Tensor:
        if self.backend == "causal":
-           assert (
-               contlen and inplen
-           ), "Must pass input len and cont. len to select scored logits for causal LM"
+           assert contlen and inplen, (
+               "Must pass input len and cont. len to select scored logits for causal LM"
+           )
            # discard right-padding.
            # also discard the input/context tokens. we'll only score continuations.
            logits = logits[inplen - contlen : inplen]
        elif self.backend == "seq2seq":
-           assert (
-               contlen and not inplen
-           ), "Selecting scored logits for Seq2SeqLM requires only cont. len"
+           assert contlen and not inplen, (
+               "Selecting scored logits for Seq2SeqLM requires only cont. len"
+           )
            # only discard right-padding.
            # the logits input to this fn only contain decoder-side tokens.
            logits = logits[:contlen]
@@ -905,8 +926,6 @@ class HFLM(TemplateLM):
    def loglikelihood_rolling(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[float]:
-       loglikelihoods = []
        adaptive_batch_size = None
        if self.batch_size == "auto":
            # using rolling window with maximum context
@@ -915,10 +934,17 @@ class HFLM(TemplateLM):
            print(f"Determined Largest batch size: {batch_size}")
            adaptive_batch_size = batch_size
-       for (string,) in tqdm(
-           [req.args for req in requests], disable=(disable_tqdm or (self.rank != 0))
+       # First, collect all windows from all requests
+       all_windows = []  # List of (request_idx, window) tuples
+       request_window_counts = []  # Track number of windows per request
+       for req_idx, (string,) in enumerate(
+           tqdm(
+               [req.args for req in requests],
+               disable=(disable_tqdm or (self.rank != 0)),
+           )
        ):
-           rolling_token_windows = list(
+           rolling_token_windows: List[Tuple[List[int], List[int]]] = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
@@ -931,37 +957,55 @@ class HFLM(TemplateLM):
            )
            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
-           rolling_token_windows = [(None,) + x for x in rolling_token_windows]
-           pad_amnt = 0
-           if self.world_size > 1:
-               # We pad out the external document-level iterator so the inner iterator doesn't hang
-               mytensor = torch.tensor(len(rolling_token_windows), device=self.device)
-               gathered = (
-                   self.accelerator.gather(mytensor).cpu().detach().numpy().tolist()
-               )
-               pad_amnt = max(gathered) - gathered[self.rank]
-               if pad_amnt > 0:
-                   rolling_token_windows += pad_amnt * [rolling_token_windows[0]]
-           string_nll = self._loglikelihood_tokens(
-               requests=rolling_token_windows,
-               disable_tqdm=True,
-               override_bs=adaptive_batch_size,
-           )
-           if (self.world_size > 1) and (pad_amnt > 0):
-               string_nll = [x[0] for x in string_nll[:-pad_amnt]]
-           else:
-               # discard is_greedy
-               string_nll = [x[0] for x in string_nll]
-           string_nll = sum(string_nll)
-           loglikelihoods.append(string_nll)
-           # cache this loglikelihood_rolling request
-           self.cache_hook.add_partial("loglikelihood_rolling", (string,), string_nll)
+           windows = [(None,) + x for x in rolling_token_windows]
+           # Store windows with their request index
+           all_windows.extend((req_idx, window) for window in windows)
+           request_window_counts.append(len(windows))
+       # Handle distributed case padding
+       pad_amnt = 0
+       if self.world_size > 1:
+           mytensor = torch.tensor(len(all_windows), device=self.device)
+           gathered = self.accelerator.gather(mytensor).cpu().detach().numpy().tolist()
+           pad_amnt = max(gathered) - gathered[self.rank]
+           if pad_amnt > 0:
+               all_windows += pad_amnt * [all_windows[0]]
+       all_nlls = []
+       batch_size = adaptive_batch_size or self.batch_size
+       for i in range(0, len(all_windows), batch_size):
+           batch = all_windows[i : i + batch_size]
+           # Extract just the windows for processing, keeping track of request indices
+           batch_indices, batch_windows = zip(*batch)
+           batch_nlls = self._loglikelihood_tokens(
+               requests=batch_windows,
+               disable_tqdm=False,
+               override_bs=len(batch_windows),
+           )
+           # Store results with their request indices
+           all_nlls.extend(zip(batch_indices, batch_nlls))
+       # Remove padding if necessary
+       if (self.world_size > 1) and (pad_amnt > 0):
+           all_nlls = all_nlls[:-pad_amnt]
+       # Reconstruct per-request loglikelihoods
+       loglikelihoods = []
+       current_idx = 0
+       for window_count in request_window_counts:
+           # Get all nlls for this request
+           request_nlls = all_nlls[current_idx : current_idx + window_count]
+           # Sum up the nlls for this request (discarding is_greedy)
+           request_total = sum(nll[0] for _, nll in request_nlls)
+           loglikelihoods.append(request_total)
+           current_idx += window_count
+           string = requests[len(loglikelihoods) - 1].args[0]
+           self.cache_hook.add_partial(
+               "loglikelihood_rolling", (string,), request_total
+           )
        return loglikelihoods
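The restructuring above swaps per-request scoring for one flat, batched pass over every rolling window. The grouping trick is easiest to see in isolation; this sketch uses a dummy scorer in place of `_loglikelihood_tokens` and made-up window lists:

```python
from typing import List, Tuple

def score_windows(windows: List[str]) -> List[Tuple[float, bool]]:
    # stand-in for _loglikelihood_tokens: (log-likelihood, is_greedy) per window
    return [(-float(len(w)), True) for w in windows]

# pretend each request was already split into rolling windows
per_request_windows = [["w0a", "w0b"], ["w1a"], ["w2a", "w2b", "w2c"]]

all_windows, request_window_counts = [], []
for req_idx, windows in enumerate(per_request_windows):
    all_windows.extend((req_idx, w) for w in windows)
    request_window_counts.append(len(windows))

# one flat pass over all windows, in fixed-size batches
batch_size = 2
all_nlls = []
for i in range(0, len(all_windows), batch_size):
    batch = all_windows[i : i + batch_size]
    batch_indices, batch_windows = zip(*batch)
    all_nlls.extend(zip(batch_indices, score_windows(list(batch_windows))))

# regroup: window counts say how many consecutive results belong to each request
loglikelihoods, current_idx = [], 0
for window_count in request_window_counts:
    request_nlls = all_nlls[current_idx : current_idx + window_count]
    loglikelihoods.append(sum(nll for _, (nll, _is_greedy) in request_nlls))
    current_idx += window_count

print(loglikelihoods)  # [-6.0, -3.0, -9.0]
```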
@@ -1073,6 +1117,13 @@ class HFLM(TemplateLM):
        # when too long to fit in context, truncate from the left
        if self.backend == "causal":
+           total_length = len(context_enc) + len(continuation_enc)
+           if total_length > self.max_length + 1:
+               eval_logger.warn(
+                   f"Combined length of context ({len(context_enc)}) and continuation ({len(continuation_enc)}) "
+                   f"exceeds model's maximum length ({self.max_length}). "
+                   f"Truncating {total_length - self.max_length + 1} tokens from the left."
+               )
            inp = torch.tensor(
                (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
                dtype=torch.long,
@@ -1280,6 +1331,9 @@ class HFLM(TemplateLM):
        if self.backend == "causal":
            # max len for inputs = max length, minus room to generate the max new tokens
            max_ctx_len = self.max_length - max_gen_toks
+           assert max_ctx_len > 0, (
+               f"Invalid configuration: requested max tokens to generate ({max_gen_toks}) must be less than model's maximum sequence length ({self.max_length})."
+           )
        elif self.backend == "seq2seq":
            # max len for inputs = encoder's whole max_length
            max_ctx_len = self.max_length
@@ -1330,13 +1384,18 @@ class HFLM(TemplateLM):
        return res
-   def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
+   def apply_chat_template(
+       self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
+   ) -> str:
        """
        Method to apply a chat template to a list of chat history between user and model.
        """
        try:
            chat_templated = self.tokenizer.apply_chat_template(
-               chat_history, tokenize=False, add_generation_prompt=True
+               chat_history,
+               tokenize=False,
+               add_generation_prompt=add_generation_prompt,
+               continue_final_message=not add_generation_prompt,
            )
        except jinja2.exceptions.TemplateError:
            eval_logger.warning(
@@ -1344,7 +1403,10 @@ class HFLM(TemplateLM):
            )
            chat_history = [msg for msg in chat_history if msg["role"] != "system"]
            chat_templated = self.tokenizer.apply_chat_template(
-               chat_history, tokenize=False, add_generation_prompt=True
+               chat_history,
+               tokenize=False,
+               add_generation_prompt=add_generation_prompt,
+               continue_final_message=not add_generation_prompt,
            )
        return chat_templated
...
@@ -206,7 +206,7 @@ class NEURON_HF(TemplateLM):
                "Only float16/bfloat16/float32 are supported."
            )
-           print(f"{'='*20} \n exporting model to neuron")
+           print(f"{'=' * 20} \n exporting model to neuron")
            self.model = CustomNeuronModelForCausalLM.from_pretrained(
                pretrained,
                revision=revision,
@@ -220,19 +220,17 @@ class NEURON_HF(TemplateLM):
            )
            neuron_config = self.model.config.neuron
            print(
-               f"SUCCESS: neuron model exported with config {neuron_config}. \n {'='*20}"
+               f"SUCCESS: neuron model exported with config {neuron_config}. \n {'=' * 20}"
            )
        else:
-           print(
-               f"{'='*20} \n loading neuron model with config" f" {neuron_config}..."
-           )
+           print(f"{'=' * 20} \n loading neuron model with config {neuron_config}...")
            self.model = CustomNeuronModelForCausalLM.from_pretrained(
                pretrained,
                revision=revision,
                trust_remote_code=trust_remote_code,
                low_cpu_mem_usage=low_cpu_mem_usage,
            )
-           print(f"SUCCESS: neuron model loaded. \n {'='*20}")
+           print(f"SUCCESS: neuron model loaded. \n {'=' * 20}")
        self.truncation = truncation
@@ -353,9 +351,9 @@ class NEURON_HF(TemplateLM):
        )
    def _select_cont_toks(self, logits, contlen=None, inplen=None):
-       assert (
-           contlen and inplen
-       ), "Must pass input len and cont. len to select scored logits for causal LM"
+       assert contlen and inplen, (
+           "Must pass input len and cont. len to select scored logits for causal LM"
+       )
        # discard right-padding.
        # also discard the input/context tokens. we'll only score continuations.
        logits = logits[inplen - contlen : inplen]
...
import os
from functools import cached_property
+from operator import itemgetter
from typing import Any, Dict, List, Optional, Tuple, Union
from lm_eval.api.registry import register_model
@@ -68,7 +69,9 @@ class LocalCompletionsAPI(TemplateAPI):
        if not isinstance(outputs, list):
            outputs = [outputs]
        for out in outputs:
-           for choice, ctxlen in zip(out["choices"], ctxlens):
+           for choice, ctxlen in zip(
+               sorted(out["choices"], key=itemgetter("index")), ctxlens
+           ):
                assert ctxlen > 0, "Context length must be greater than 0"
                logprobs = sum(choice["logprobs"]["token_logprobs"][ctxlen:-1])
                tokens_logprobs = choice["logprobs"]["token_logprobs"][ctxlen:-1]
@@ -87,8 +90,10 @@ class LocalCompletionsAPI(TemplateAPI):
        if not isinstance(outputs, list):
            outputs = [outputs]
        for out in outputs:
+           tmp = [None] * len(out["choices"])
            for choices in out["choices"]:
-               res.append(choices["text"])
+               tmp[choices["index"]] = choices["text"]
+           res = res + tmp
        return res
    @property
@@ -129,9 +134,9 @@ class LocalChatCompletion(LocalCompletionsAPI):
        eos=None,
        **kwargs,
    ) -> dict:
-       assert (
-           type(messages) is not str
-       ), "chat-completions require the --apply_chat_template flag."
+       assert type(messages) is not str, (
+           "chat-completions require the --apply_chat_template flag."
+       )
        gen_kwargs.pop("do_sample", False)
        if "max_tokens" in gen_kwargs:
            max_tokens = gen_kwargs.pop("max_tokens")
@@ -157,8 +162,10 @@ class LocalChatCompletion(LocalCompletionsAPI):
        if not isinstance(outputs, list):
            outputs = [outputs]
        for out in outputs:
+           tmp = [None] * len(out["choices"])
            for choices in out["choices"]:
-               res.append(choices["message"]["content"])
+               tmp[choices["index"]] = choices["message"]["content"]
+           res = res + tmp
        return res
    def tok_encode(
@@ -201,13 +208,12 @@ class OpenAICompletionsAPI(LocalCompletionsAPI):
        return key
    def loglikelihood(self, requests, **kwargs):
-       assert (
-           self.model
-           in [
-               "babbage-002",
-               "davinci-002",
-           ]
-       ), f"Prompt loglikelihoods are only supported by OpenAI's API for {['babbage-002', 'davinci-002']}."
+       assert self.model in [
+           "babbage-002",
+           "davinci-002",
+       ], (
+           f"Prompt loglikelihoods are only supported by OpenAI's API for {['babbage-002', 'davinci-002']}."
+       )
        return super().loglikelihood(requests, **kwargs)
    def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
@@ -258,9 +264,9 @@ class OpenAIChatCompletion(LocalChatCompletion):
        eos="<|endoftext|>",
        **kwargs,
    ) -> dict:
-       assert (
-           type(messages) is not str
-       ), "chat-completions require the --apply_chat_template flag."
+       assert type(messages) is not str, (
+           "chat-completions require the --apply_chat_template flag."
+       )
        gen_kwargs.pop("do_sample", False)
        if "max_tokens" in gen_kwargs:
            max_tokens = gen_kwargs.pop("max_tokens")
...
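The `itemgetter("index")` and `tmp[choices["index"]]` changes above guard against completion APIs returning `choices` out of order when several generations are requested. A tiny sketch of the reordering with a mocked response payload:

```python
from operator import itemgetter

# mocked API response: choices arrive out of order relative to the request batch
out = {
    "choices": [
        {"index": 2, "text": " gamma"},
        {"index": 0, "text": " alpha"},
        {"index": 1, "text": " beta"},
    ]
}

# generation path: place each text at its declared index
tmp = [None] * len(out["choices"])
for choice in out["choices"]:
    tmp[choice["index"]] = choice["text"]
print(tmp)  # [' alpha', ' beta', ' gamma']

# logprob path: iterate in index order so choices line up with per-request context lengths
for choice in sorted(out["choices"], key=itemgetter("index")):
    print(choice["index"], choice["text"])
```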
from importlib.util import find_spec

from lm_eval import utils
from lm_eval.api.registry import register_model
from lm_eval.models.huggingface import HFLM
from lm_eval.models.utils import get_dtype

eval_logger = utils.eval_logger


@register_model("ipex")
class IPEXLM(HFLM):
    """
    using the HuggingFace transformers + optimum-intel ipex backend, can run on intel cpu and intel gpu
    """

    def __init__(
        self,
        **kwargs,
    ) -> None:
        if "backend" in kwargs:
            # currently only supports causal models
            assert kwargs["backend"] == "causal", (
                "Currently, only IPEXModelForCausalLM is supported."
            )

        super().__init__(
            backend=kwargs.pop("backend", "causal"),
            **kwargs,
        )

    def _create_model(
        self,
        pretrained: str,
        revision="main",
        dtype="auto",
        trust_remote_code=False,
        # arguments used for splitting a model across GPUs naively.
        # only used if `parallelize=True`.
        # (accelerate naive PP (device_map) options)
        parallelize=False,
        gpus=None,
        max_memory_per_gpu=None,
        max_cpu_memory=None,
        offload_folder="./offload",
        # PEFT, delta weights and quantization options
        peft=None,
        delta=None,
        autogptq=False,
        gptqmodel=False,
        **kwargs,
    ) -> None:
        if not find_spec("optimum"):
            raise ModuleNotFoundError(
                "package `optimum` is not installed. Please install it via `pip install optimum[ipex]`"
            )
        else:
            from optimum.intel import IPEXModelForCausalLM

        model_kwargs = kwargs if kwargs else {}
        model_kwargs.update(
            self._get_accelerate_args(
                parallelize=parallelize,
                device_map=kwargs.get("device_map", None),
                max_memory_per_gpu=max_memory_per_gpu,
                max_cpu_memory=max_cpu_memory,
                offload_folder=offload_folder,
                gpus=gpus,
            )
        )
        self._model = IPEXModelForCausalLM.from_pretrained(
            pretrained,
            revision=revision,
            torch_dtype=get_dtype(dtype),
            trust_remote_code=trust_remote_code,
            **model_kwargs,
        )
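A rough standalone sketch of what the new `ipex` backend wraps: `IPEXModelForCausalLM` is the optimum-intel class imported above (installed via `pip install optimum[ipex]`). The checkpoint name is only an example, and the exact accepted arguments depend on the installed optimum-intel version:

```python
import torch
from optimum.intel import IPEXModelForCausalLM
from transformers import AutoTokenizer

model_id = "gpt2"  # example checkpoint
model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Intel Extension for PyTorch makes", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```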
@@ -29,9 +29,9 @@ class OptimumLM(HFLM):
    ) -> None:
        if "backend" in kwargs:
            # optimum currently only supports causal models
-           assert (
-               kwargs["backend"] == "causal"
-           ), "Currently, only OVModelForCausalLM is supported."
+           assert kwargs["backend"] == "causal", (
+               "Currently, only OVModelForCausalLM is supported."
+           )
        self.openvino_device = device
...
@@ -155,9 +155,9 @@ def pad_and_concat(
    length in the batch. Used for batching inputs and continuations in
    seq2seq models.
    """
-   assert (
-       padding_side == "left" or padding_side == "right"
-   ), f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"
+   assert padding_side == "left" or padding_side == "right", (
+       f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"
+   )
    for i, tensor in enumerate(tensors):
        if len(tensor.shape) == 2:
...
@@ -76,9 +76,9 @@ class VLLM(TemplateLM):
        )
        assert "cuda" in device or device is None, "vLLM only supports CUDA"
-       assert (
-           max_length is None or max_model_len is None
-       ), "Either max_length or max_model_len may be provided, but not both"
+       assert max_length is None or max_model_len is None, (
+           "Either max_length or max_model_len may be provided, but not both"
+       )
        self._max_length = max_model_len if max_model_len is not None else max_length
        self.tensor_parallel_size = int(tensor_parallel_size)
@@ -102,7 +102,7 @@ class VLLM(TemplateLM):
        self.batch_size = (
            "auto"
            if isinstance(batch_size, str) and "auto" in batch_size
-           else batch_size
+           else int(batch_size)
        )
        if self.data_parallel_size <= 1:
            self.model = LLM(**self.model_args)
@@ -142,9 +142,9 @@ class VLLM(TemplateLM):
        self._max_gen_toks = max_gen_toks
        if lora_local_path is not None:
-           assert parse_version(version("vllm")) > parse_version(
-               "0.3.0"
-           ), "lora adapters only compatible with vllm > v0.3.0."
+           assert parse_version(version("vllm")) > parse_version("0.3.0"), (
+               "lora adapters only compatible with vllm > v0.3.0."
+           )
            self.lora_request = LoRARequest("finetuned", 1, lora_local_path)
        else:
            self.lora_request = None
@@ -184,14 +184,21 @@ class VLLM(TemplateLM):
    def max_gen_toks(self):
        return self._max_gen_toks
-   def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
+   def apply_chat_template(
+       self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
+   ) -> str:
        """
        Method to apply a chat template to a list of chat history between user and model.
        """
-       return self.tokenizer.apply_chat_template(
-           chat_history, tokenize=False, add_generation_prompt=True
+       chat_templated = self.tokenizer.apply_chat_template(
+           chat_history,
+           tokenize=False,
+           add_generation_prompt=add_generation_prompt,
+           continue_final_message=not add_generation_prompt,
        )
+       return chat_templated
    @property
    def tokenizer_name(self) -> str:
        return self.tokenizer.name_or_path.replace("/", "__")
@@ -281,10 +288,21 @@ class VLLM(TemplateLM):
    def loglikelihood_rolling(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[float]:
-       loglikelihoods = []
-       for (string,) in tqdm([req.args for req in requests], disable=disable_tqdm):
-           rolling_token_windows = list(
+       adaptive_batch_size = None
+       if self.batch_size == "auto":
+           adaptive_batch_size = len(requests)
+       # First, collect all windows from all requests
+       all_windows = []  # List of (request_idx, window) tuples
+       request_window_counts = []  # Track number of windows per request
+       for req_idx, (string,) in enumerate(
+           tqdm(
+               [req.args for req in requests],
+               disable=(disable_tqdm or (self.rank != 0)),
+           )
+       ):
+           rolling_token_windows: List[Tuple[List[int], List[int]]] = list(
                map(
                    make_disjoint_window,
                    get_rolling_token_windows(
@@ -297,20 +315,42 @@ class VLLM(TemplateLM):
                )
            )
-           rolling_token_windows = [(None,) + x for x in rolling_token_windows]
-           string_nll = self._loglikelihood_tokens(
-               rolling_token_windows,
-           )
-           # discard is_greedy
-           string_nll = [x[0] for x in string_nll]
-           string_nll = sum(string_nll)
-           loglikelihoods.append(string_nll)
-           # cache this loglikelihood_rolling request
-           self.cache_hook.add_partial("loglikelihood_rolling", (string,), string_nll)
+           # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
+           windows = [(None,) + x for x in rolling_token_windows]
+           # Store windows with their request index
+           all_windows.extend((req_idx, window) for window in windows)
+           request_window_counts.append(len(windows))
+       all_nlls = []
+       batch_size = adaptive_batch_size or int(self.batch_size)
+       for i in range(0, len(all_windows), batch_size):
+           batch = all_windows[i : i + batch_size]
+           # Extract just the windows for processing, keeping track of request indices
+           batch_indices, batch_windows = zip(*batch)
+           batch_nlls = self._loglikelihood_tokens(
+               requests=batch_windows,
+               disable_tqdm=False,
+           )
+           # Store results with their request indices
+           all_nlls.extend(zip(batch_indices, batch_nlls))
+       # Reconstruct per-request loglikelihoods
+       loglikelihoods = []
+       current_idx = 0
+       for window_count in request_window_counts:
+           # Get all nlls for this request
+           request_nlls = all_nlls[current_idx : current_idx + window_count]
+           # Sum up the nlls for this request (discarding is_greedy)
+           request_total = sum(nll[0] for _, nll in request_nlls)
+           loglikelihoods.append(request_total)
+           current_idx += window_count
+           string = requests[len(loglikelihoods) - 1].args[0]
+           self.cache_hook.add_partial(
+               "loglikelihood_rolling", (string,), request_total
+           )
        return loglikelihoods
...
@@ -144,7 +144,9 @@ class VLLM_VLM(VLLM):
        )
        return outputs
-   def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
+   def apply_chat_template(
+       self, chat_history: List[Dict[str, str]], add_generation_prompt=True
+   ) -> str:
        self.chat_applied = True
        if not self.interleave:
            for content in chat_history:
@@ -194,7 +196,9 @@ class VLLM_VLM(VLLM):
            )
        return self.processor.apply_chat_template(
-           chat_history, add_generation_prompt=True
+           chat_history,
+           add_generation_prompt=add_generation_prompt,
+           continue_final_message=not add_generation_prompt,
        )
    def generate_until(
...
This diff is collapsed.
@@ -9,4 +9,4 @@ aggregate_metric_list:
  - metric: acc
    weight_by_size: True
metadata:
-  version: 0
+  version: 1
@@ -6,4 +6,4 @@ aggregate_metric_list:
  - metric: acc
    weight_by_size: True
metadata:
-  version: 0
+  version: 1
@@ -6,4 +6,4 @@ aggregate_metric_list:
  - metric: acc
    weight_by_size: True
metadata:
-  version: 0
+  version: 1