Commit c4b0c0cb authored by Baber

Merge branch 'main' into metrics

# Conflicts:
#	lm_eval/models/vllm_causallms.py
#	pyproject.toml
parents 6b20ae8c de496b80
@@ -5,8 +5,9 @@ import traceback
 from typing import Iterator, List, Sequence, Tuple, TypeVar
-# This is a cpp module. Compile janitor_util.cpp with:
-# c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup
+# This is a cpp module.
+# See scripts/clean_training_data/README.md for instructions to compile janitor_util.cpp
 try:
     import janitor_util
......
@@ -71,13 +71,6 @@ class SteeredModel(HFLM):
         """
         HFLM with a steered forward pass.
-        To derive steering vectors from a sparse model loadable with sparsify or sae_lens,
-        provide the path to a CSV file with the following columns (example rows are provided below):
-        loader,action,sparse_model,hookpoint,feature_index,steering_coefficient,sae_id,description,
-        sparsify,add,EleutherAI/sae-pythia-70m-32k,layers.3,30,10.0,,,
-        sae_lens,add,gemma-scope-2b-pt-res-canonical,layers.20,12082,240.0,layer_20/width_16k/canonical,increase dogs,
         To load steering vectors directly, provide the path to a pytorch (.pt) file with content in the following format:
         {
@@ -86,9 +79,17 @@ class SteeredModel(HFLM):
                 "steering_coefficient": <float>,
                 "action": <Literal["add", "clamp"]>,
                 "bias": <torch.Tensor | None>,
+                "head_index": <int | None>,
             },
             ...
         }
+        To derive steering vectors from a sparse model loadable with sparsify or sae_lens,
+        provide the path to a CSV file with the following columns (example rows are provided below):
+        loader,action,sparse_model,hookpoint,feature_index,steering_coefficient,head_index,sae_id,description,
+        sparsify,add,EleutherAI/sae-pythia-70m-32k,layers.3,30,10.0,,,,
+        sae_lens,add,gemma-scope-2b-pt-res-canonical,layers.20,12082,240.0,,layer_20/width_16k/canonical,increase dogs,
         """
         super().__init__(pretrained=pretrained, device=device, **kwargs)
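For readers unfamiliar with the documented .pt layout, a minimal sketch of writing such a file follows; the hookpoint name, feature dimension, and file path are illustrative assumptions, not values from this commit:

# Illustrative sketch: a steering config in the documented .pt format.
import torch

steer_config = {
    "layers.3": {                              # hookpoint name is an example
        "steering_vector": torch.randn(512),   # shape [features]
        "steering_coefficient": 10.0,
        "action": "add",
        "bias": None,
        "head_index": None,
    },
}
torch.save(steer_config, "steer_config.pt")    # pass this path when instantiating the model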
@@ -105,27 +106,31 @@ class SteeredModel(HFLM):
         hook_to_steer = {}
         for hookpoint, steer_info in steer_config.items():
             action = steer_info["action"]
-            steering_coefficient = steer_info["steering_coefficient"]
             steering_vector = (
                 steer_info["steering_vector"].to(self.device).to(self.model.dtype)
             )
-            bias = (
-                steer_info["bias"].to(self.device).to(self.model.dtype)
-                if steer_info["bias"] is not None
-                else None
-            )
+            steering_coefficient = float(steer_info.get("steering_coefficient", 1.0))
+            head_index = steer_info.get("head_index", None)
+            bias = steer_info.get("bias", None)
+            if bias is not None:
+                bias = bias.to(self.device).to(self.model.dtype)
             if action == "add":
-                # Steers the model by adding some multiple of a steering vector to all sequence positions.
-                hook_to_steer[hookpoint] = (
-                    lambda acts: acts + steering_coefficient * steering_vector
+                # Steer the model by adding a multiple of a steering vector to all sequence positions.
+                assert bias is None, "Bias is not supported for the `add` action."
+                hook_to_steer[hookpoint] = partial(
+                    self.add,
+                    vector=steering_vector * steering_coefficient,
+                    head_index=head_index,
                 )
             elif action == "clamp":
+                # Steer the model by clamping the activations to a value in the direction of the steering vector.
                 hook_to_steer[hookpoint] = partial(
                     self.clamp,
-                    steering_vector=steering_vector,
+                    direction=steering_vector / torch.norm(steering_vector),
                     value=steering_coefficient,
                     bias=bias,
+                    head_index=head_index,
                 )
             else:
                 raise ValueError(f"Unknown hook type: {action}")
@@ -195,34 +200,62 @@ class SteeredModel(HFLM):
         return steer_data
+    @classmethod
+    def add(
+        cls,
+        acts: Tensor,
+        vector: Tensor,
+        head_index: Optional[int],
+    ):
+        """Adds the given vector to the activations.
+        Args:
+            acts (Tensor): The activations tensor to edit of shape [batch, pos, ..., features]
+            vector (Tensor): A vector to add of shape [features]
+            head_index (int | None): Optional attention head index to add to
+        """
+        if head_index is not None:
+            acts[:, :, head_index, :] = acts[:, :, head_index, :] + vector
+        else:
+            acts = acts + vector
+        return acts
     @classmethod
     def clamp(
         cls,
         acts: Tensor,
-        steering_vector: Tensor,
+        direction: Tensor,
         value: float,
+        head_index: Optional[int],
         bias: Optional[Tensor] = None,
     ):
-        """Clamps a direction of the activations to be the steering vector * the value.
+        """Clamps the activations to a given value in a specified direction. The direction
+        must be a unit vector.
         Args:
-            acts (Tensor): The activations tensor to edit of shape [batch, pos, features]
-            steering_vector (Tensor): A direction to clamp of shape [features]
+            acts (Tensor): The activations tensor to edit of shape [batch, pos, ..., features]
+            direction (Tensor): A direction to clamp of shape [features]
             value (float): Value to clamp the direction to
+            head_index (int | None): Optional attention head index to clamp
            bias (Tensor | None): Optional bias to add to the activations
         Returns:
             Tensor: The modified activations with the specified direction clamped
         """
         if bias is not None:
             acts = acts - bias
-        direction = steering_vector / torch.norm(steering_vector)
-        proj_magnitude = torch.sum(acts * direction, dim=-1, keepdim=True)
-        orthogonal_component = acts - proj_magnitude * direction
-        clamped = orthogonal_component + direction * value
+        if head_index is not None:
+            x = acts[:, :, head_index, :]
+            proj = (x * direction).sum(dim=-1, keepdim=True)
+            assert proj == acts @ direction
+            clamped = acts.clone()
+            clamped[:, :, head_index, :] = x + direction * (value - proj)
+        else:
+            proj = torch.sum(acts * direction, dim=-1, keepdim=True)
+            clamped = acts + direction * (value - proj)
         if bias is not None:
             return clamped + bias
......
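To make the projection arithmetic in clamp above concrete, here is a self-contained sketch (a standalone helper, not the class method; shapes and values are illustrative):

# Illustrative only: clamp the component of activations along a unit direction to a fixed value.
import torch

def clamp_direction(acts: torch.Tensor, direction: torch.Tensor, value: float) -> torch.Tensor:
    direction = direction / torch.norm(direction)              # unit-normalize, as the caller does
    proj = torch.sum(acts * direction, dim=-1, keepdim=True)   # current component along the direction
    return acts + direction * (value - proj)                   # shift that component to `value`

acts = torch.randn(2, 4, 8)      # [batch, pos, features]
direction = torch.randn(8)
out = clamp_direction(acts, direction, 3.0)
# the component along the direction is now ~3.0 at every position
print(torch.sum(out * (direction / torch.norm(direction)), dim=-1))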
@@ -124,14 +124,22 @@ class HFLM(TemplateLM):
         assert isinstance(pretrained, str)
         assert isinstance(batch_size, (int, str))
-        gpus = torch.cuda.device_count()
         accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
         accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
         if accelerator.num_processes > 1:
             self.accelerator = accelerator
-        if "npu" in accelerator.device.type:
+        # Detect device count based on accelerator device type
+        device_type = accelerator.device.type
+        if "cuda" in device_type:
+            gpus = torch.cuda.device_count()
+        elif "npu" in device_type:
             gpus = torch.npu.device_count()
+        elif "xpu" in device_type:
+            gpus = torch.xpu.device_count()
+        else:
+            # Fallback to CUDA count for compatibility
+            gpus = torch.cuda.device_count()
         # using one process with no model parallelism
         if not (parallelize or accelerator.num_processes > 1):
@@ -141,6 +149,7 @@ class HFLM(TemplateLM):
                 + [f"cuda:{i}" for i in range(gpus)]
                 + ["mps", "mps:0"]
                 + [f"npu:{i}" for i in range(gpus)]
+                + [f"xpu:{i}" for i in range(gpus)]
             )
             if device and device in device_list:
                 self._device = torch.device(device)
@@ -679,10 +688,19 @@ class HFLM(TemplateLM):
                     "0.4.0"
                 ):
                     raise AssertionError("load_in_4bit requires peft >= 0.4.0")
-            if self._model.config.vocab_size != len(self.tokenizer):
+            # Compatible with Gemma3 (multimodal) and old models
+            if hasattr(self._model.config, "text_config") and hasattr(
+                self._model.config.text_config, "vocab_size"
+            ):
+                vocab_size = self._model.config.text_config.vocab_size
+            else:
+                vocab_size = self._model.config.vocab_size
+            if vocab_size != len(self.tokenizer):
                 # resize model for LoRAs with added tokens
                 eval_logger.info(
-                    f"Model config indicates vocab_size='{self._model.config.vocab_size}', but found tokenizer with vocab size '{len(self.tokenizer)}'. Resizing model embedding layer..."
+                    f"Model config indicates vocab_size='{vocab_size}', but found tokenizer with vocab size '{len(self.tokenizer)}'. Resizing model embedding layer..."
                 )
                 self._model.resize_token_embeddings(len(self.tokenizer))
                 self._model = PeftModel.from_pretrained(
......
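The vocab-size lookup above generalizes to any config object that may nest its text settings; a hedged sketch of the same pattern as a standalone helper (the helper name is ours, not the library's):

# Illustrative only: resolve vocab_size from a config that may nest it under text_config.
def resolve_vocab_size(config) -> int:
    text_config = getattr(config, "text_config", None)
    if text_config is not None and hasattr(text_config, "vocab_size"):
        return text_config.vocab_size   # multimodal configs (e.g. Gemma3) nest it here
    return config.vocab_size            # plain text-only configs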
@@ -290,7 +290,7 @@ class OpenAIChatCompletion(LocalChatCompletion):
             "seed": seed,
             **gen_kwargs,
         }
-        if "o1" in self.model:
+        if "o1" in self.model or "5" in self.model:
             output.pop("stop")
             output["temperature"] = 1
         elif "o3" in self.model:
......
@@ -28,9 +28,8 @@ class OptimumLM(HFLM):
         **kwargs,
     ) -> None:
         if "backend" in kwargs:
-            # optimum currently only supports causal models
-            assert kwargs["backend"] == "causal", (
-                "Currently, only OVModelForCausalLM is supported."
+            assert kwargs["backend"] in ["causal", "seq2seq"], (
+                "Currently, only OVModelForCausalLM or OVModelForSeq2SeqLM are supported."
             )
         self.openvino_device = device
@@ -54,7 +53,7 @@ class OptimumLM(HFLM):
                 "package `optimum` is not installed. Please install it via `pip install optimum[openvino]`"
             )
         else:
-            from optimum.intel.openvino import OVModelForCausalLM
+            from optimum.intel.openvino import OVModelForCausalLM, OVModelForSeq2SeqLM
         model_kwargs = kwargs if kwargs else {}
         if "ov_config" in model_kwargs:
@@ -76,17 +75,14 @@ class OptimumLM(HFLM):
             model_kwargs["ov_config"]["MODEL_DISTRIBUTION_POLICY"] = (
                 "PIPELINE_PARALLEL"
             )
-        model_file = Path(pretrained) / "openvino_model.xml"
-        if model_file.exists():
-            export = False
-        else:
-            export = True
-        self._model = OVModelForCausalLM.from_pretrained(
+        model_cls = (
+            OVModelForCausalLM if self.backend == "causal" else OVModelForSeq2SeqLM
+        )
+        self._model = model_cls.from_pretrained(
             pretrained,
             revision=revision,
             trust_remote_code=trust_remote_code,
-            export=export,
             device=self.openvino_device.upper(),
             **model_kwargs,
         )
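For context, a hedged usage sketch of how the backend switch above selects between the two OpenVINO classes; the model id and backend value are illustrative, and export/conversion behavior is left to optimum's defaults:

# Illustrative only: pick the OpenVINO model class from a backend string.
from optimum.intel.openvino import OVModelForCausalLM, OVModelForSeq2SeqLM

backend = "seq2seq"                                   # or "causal"
model_cls = OVModelForCausalLM if backend == "causal" else OVModelForSeq2SeqLM
model = model_cls.from_pretrained("t5-small")         # example checkpoint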
@@ -216,7 +216,7 @@ class SGLangLM(TemplateLM):
         # we group requests by their generation_kwargs,
         # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
         # in the same batch.
-        re_ords = Collator(requests, _collate_gen, group_by="gen_kwargs")
+        re_ords = Collator(requests, _collate_gen, group_by=None)
         chunks = re_ords.get_batched(
             n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
         )
@@ -232,36 +232,41 @@ class SGLangLM(TemplateLM):
             context_and_encoding, all_gen_kwargs = zip(*chunk)
             context, context_encoding = zip(*context_and_encoding)
-            # we assume all gen kwargs in the batch are the same
-            # this is safe to assume because the `grouper` object ensures it.
-            gen_kwargs = all_gen_kwargs[0]
+            context_encoding_truncated = []
+            sampling_params = []
+            for x, gen_kwargs in zip(context_encoding, all_gen_kwargs):
                 # unpack our keyword arguments.
                 if isinstance(gen_kwargs, dict):
                     kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
                     # add EOS token to stop sequences
                     until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
                 else:
                     raise ValueError(
                         f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                     )
-            if "max_gen_toks" in kwargs.keys():
-                max_gen_toks = kwargs.pop("max_gen_toks")
-            else:
-                max_gen_toks = self.max_gen_toks
-            # set the max length in tokens of inputs ("context_enc")
-            # max len for inputs = max length, minus room to generate the max new tokens
-            max_ctx_len = self.max_length - max_gen_toks
-            context_encoding = [x[-max_ctx_len:] for x in context_encoding]
+                if "max_gen_toks" in kwargs.keys():
+                    max_gen_toks = kwargs.pop("max_gen_toks")
+                else:
+                    max_gen_toks = self.max_gen_toks
+                # set the max length in tokens of inputs ("context_enc")
+                # max len for inputs = max length, minus room to generate the max new tokens
+                max_ctx_len = self.max_length - max_gen_toks
+                if len(x) > max_ctx_len:
+                    context_encoding_truncated.append(x[-max_ctx_len:])
+                else:
+                    context_encoding_truncated.append(x)
+                # create sampling params
+                kwargs = self.modify_gen_kwargs(kwargs)
+                sampling_params.append(
+                    kwargs | {"max_tokens": max_gen_toks, "stop": until}
+                )
             # perform batched generation
             # cont is a list of dic. See here https://github.com/sgl-project/sglang/blob/0a6f18f068e4095fc228e798454e8496c9749214/python/sglang/srt/entrypoints/engine.py#L111 .
             cont = self._model_generate(
-                requests=context_encoding,
+                requests=context_encoding_truncated,
                 generate=True,
-                max_tokens=max_gen_toks,
-                stop=until,
-                **kwargs,
+                sampling_params=sampling_params,
             )
             # cache generations
@@ -284,28 +289,22 @@ class SGLangLM(TemplateLM):
         self,
         requests: List[List[int]] = None,
         generate: bool = False,
-        max_tokens: int = None,
-        stop: Optional[List[str]] = None,
+        sampling_params: Union[List[Dict], Dict, None] = None,
         return_logprob: bool = False,
         top_logprobs_num: int = 1,
         logprob_start_len: int = -1,
-        **kwargs,
     ):
         # check sglang sampling parameters: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/sampling/sampling_params.py#L21 and https://docs.sglang.ai/references/sampling_params.html.
-        if generate:
-            kwargs = self.modify_gen_kwargs(kwargs)
-            sampling_params = {
-                "max_new_tokens": max_tokens,
-                "stop": stop,
-            }
-            sampling_params.update(kwargs)
-        else:
-            sampling_params = {
-                "temperature": 0,
-                "max_new_tokens": 1,
-            }
-            sampling_params.update(kwargs)
+        if not generate:
+            sampling_params = sampling_params if sampling_params else {}
+            sampling_params.update(
+                {
+                    "temperature": 0,
+                    "max_new_tokens": 1,
+                }
+            )
+        if not isinstance(sampling_params, List):
+            sampling_params = [sampling_params] * len(requests)
         # Refer to: https://docs.sglang.ai/backend/offline_engine_api.html
         outputs = self.model.generate(
             input_ids=requests,
......
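A hedged toy sketch of the per-request sampling-parameter lists that the reworked SGLang path builds above (token ids, kwargs, limits, and stop strings are illustrative):

# Illustrative only: one sampling-params dict per request, merged from that request's gen_kwargs.
requests = [[1, 2, 3], [4, 5]]                 # token-id prompts
all_gen_kwargs = [{"temperature": 0.8}, {}]    # per-request generation kwargs
sampling_params = [
    kw | {"max_tokens": 256, "stop": ["</s>"]}  # dict union keeps per-request overrides
    for kw in all_gen_kwargs
]
assert len(sampling_params) == len(requests)
print(sampling_params)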
@@ -52,10 +52,10 @@ eval_logger = logging.getLogger(__name__)
 def _vllm_mp_worker(
     model_args: dict,
-    sampling_params: SamplingParams,
+    sampling_params: list["SamplingParams"],
     requests: list[list[int]],
-    lora_request: LoRARequest,
-    result_queue: Queue,
+    lora_request: "LoRARequest",
+    result_queue: "Queue",
     dp_size: int,
     local_dp_rank: int,
     dp_master_port: int,
@@ -197,6 +197,12 @@ class VLLM(TemplateLM):
                 self.batch_size = "auto"
                 eval_logger.info("Manual batching is not compatible with data parallelism.")
+        if "gemma" in pretrained.lower():
+            add_bos_token = True
+            eval_logger.info(
+                "Found 'gemma' in model name, a BOS token will be used as Gemma series models underperform without it."
+            )
         from transformers import AutoConfig
         self._config = AutoConfig.from_pretrained(
@@ -215,11 +221,6 @@ class VLLM(TemplateLM):
             "enable_thinking", enable_thinking
         )
         self.add_bos_token = add_bos_token
-        if "gemma" in pretrained.lower():
-            self.add_bos_token = True
-            eval_logger.info(
-                "Found 'gemma' in model name, a BOS token will be used as Gemma series models underperform without it."
-            )
         if parse_version(version("vllm")) >= parse_version("0.8.3"):
             kwargs_resolve_hf_chat_template = {
@@ -365,17 +366,14 @@ class VLLM(TemplateLM):
         self,
         requests: list[list[int]] = None,
         generate: bool = False,
-        max_tokens: int = None,
-        stop: list[str] | None = None,
-        **kwargs,
+        sampling_params: list["SamplingParams"] | "SamplingParams" | None = None,
     ):
-        if generate:
-            kwargs = self.modify_gen_kwargs(kwargs)
-            sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs)
-        else:
+        if not generate or sampling_params is None:
             sampling_params = SamplingParams(
                 temperature=0, prompt_logprobs=1, max_tokens=1, detokenize=False
             )
+        if not isinstance(sampling_params, List):
+            sampling_params = [sampling_params] * len(requests)
         if self.data_parallel_size > 1 and not self.V1:
             # vLLM hangs if resources are set in ray.remote
             # also seems to only work with decorator and not with ray.remote() fn
@@ -383,9 +381,9 @@ class VLLM(TemplateLM):
             @ray.remote
             def run_inference_one_model(
                 model_args: dict,
-                sampling_params: SamplingParams,
+                sampling_params: list["SamplingParams"],
                 requests: list[list[int]],
-                lora_request: LoRARequest,
+                lora_request: "LoRARequest",
             ):
                 llm = LLM(**model_args)
                 return llm.generate(
@@ -397,9 +395,12 @@ class VLLM(TemplateLM):
             # dispatch requests to all self.data_parallel_size workers, in interleaved fashion
             # interleaved important to balance context lengths across workers
             requests = [list(x) for x in distribute(self.data_parallel_size, requests)]
+            sampling_params = [
+                list(sp) for sp in distribute(self.data_parallel_size, sampling_params)
+            ]
             inputs = (
-                (self.model_args, sampling_params, req, self.lora_request)
-                for req in requests
+                (self.model_args, sp, req, self.lora_request)
+                for req, sp in zip(requests, sampling_params)
             )
             object_refs = [run_inference_one_model.remote(*x) for x in inputs]
             results = ray.get(object_refs)
@@ -414,16 +415,18 @@ class VLLM(TemplateLM):
             dp_master_port = os.environ.get("VLLM_DP_MASTER_PORT") or get_open_port()
             requests = (list(x) for x in distribute(self.data_parallel_size, requests))
+            sampling_params = (
+                list(sp) for sp in distribute(self.data_parallel_size, sampling_params)
+            )
             procs, resq = [], Queue()
             # We use Process as it is non-daemonic
             try:
-                for rank, req in enumerate(requests):
+                for rank, (sp, req) in enumerate(zip(requests, sampling_params)):
                     proc = Process(
                         target=_vllm_mp_worker,
                         args=(
                             self.model_args.copy(),
-                            sampling_params,
+                            sp,
                             req,
                             self.lora_request,
                             resq,
@@ -577,10 +580,11 @@ class VLLM(TemplateLM):
            # - any OOMs will happen right away rather than near the end
            return -len(_requests[0][1]), _requests[0][0]
-        # we group requests by their generation_kwargs,
-        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
-        # in the same batch.
-        re_ords = Collator(requests, _collate_gen, group_by="gen_kwargs")
+        re_ords = Collator(
+            requests,
+            _collate_gen,
+            group_by=None,
+        )
         chunks = re_ords.get_batched(
             n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
         )
@@ -595,41 +599,44 @@ class VLLM(TemplateLM):
         for chunk in chunks:
             context_and_encoding, all_gen_kwargs = zip(*chunk)
             context, context_encoding = zip(*context_and_encoding)
-            # we assume all gen kwargs in the batch are the same
-            # this is safe to assume because the `grouper` object ensures it.
-            gen_kwargs = all_gen_kwargs[0]
+            context_encoding_truncated = []
+            sampling_params = []
+            for x, gen_kwargs in zip(context_encoding, all_gen_kwargs):
                 # unpack our keyword arguments.
                 if isinstance(gen_kwargs, dict):
                     kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
                     # add EOS token to stop sequences
                     until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
                 else:
                     raise ValueError(
                         f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                     )
-                if "max_gen_toks" in kwargs:
+                if "max_gen_toks" in kwargs.keys():
                     max_gen_toks = kwargs.pop("max_gen_toks")
                 else:
                     max_gen_toks = self.max_gen_toks
                 # set the max length in tokens of inputs ("context_enc")
                 # max len for inputs = max length, minus room to generate the max new tokens
                 max_ctx_len = self.max_length - max_gen_toks
-            all_lengths = [len(x) for x in context_encoding]
-            for length in all_lengths:
-                if length > max_ctx_len:
+                if len(x) > max_ctx_len:
                     eval_logger.warning(
-                        f"Context length {length} exceeds max length (context + max gen tokens): {max_ctx_len}. Truncating context."
+                        f"Context length {len(x)} exceeds max length (context + max gen tokens): {max_ctx_len}. Truncating context."
                     )
-            context_encoding = [x[-max_ctx_len:] for x in context_encoding]
+                    context_encoding_truncated.append(x[-max_ctx_len:])
+                else:
+                    context_encoding_truncated.append(x)
+                # create sampling params
+                kwargs = self.modify_gen_kwargs(kwargs)
+                sampling_params.append(
+                    SamplingParams(max_tokens=max_gen_toks, stop=until, **kwargs)
+                )
             # perform batched generation
             cont = self._model_generate(
-                requests=context_encoding,
+                requests=context_encoding_truncated,
                 generate=True,
-                max_tokens=max_gen_toks,
-                stop=until,
-                **kwargs,
+                sampling_params=sampling_params,
             )
             # cache generations
@@ -674,7 +681,7 @@ class VLLM(TemplateLM):
         for chunk in chunks:
             inputs = []
             ctxlens = []
-            for _cache_key, context_enc, continuation_enc in chunk:
+            for cache_key, context_enc, continuation_enc in chunk:
                 if (
                     full_length := len(context_enc + continuation_enc)
                 ) > self.max_length:
......
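A hedged toy sketch of the interleaved data-parallel dispatch above: requests and their per-request SamplingParams are split with the same round-robin pattern so the pairs stay aligned on each worker (the dp size and data are illustrative, and strings stand in for SamplingParams objects):

# Illustrative only: round-robin split of requests and matching sampling params.
from more_itertools import distribute

dp_size = 2
requests = [[1], [2, 2], [3], [4, 4]]            # token-id prompts
sampling_params = ["sp0", "sp1", "sp2", "sp3"]   # stand-ins for SamplingParams

req_shards = [list(x) for x in distribute(dp_size, requests)]
sp_shards = [list(x) for x in distribute(dp_size, sampling_params)]
for rank, (req, sp) in enumerate(zip(req_shards, sp_shards)):
    # rank 0 -> [[1], [3]] with ['sp0', 'sp2']; rank 1 -> [[2, 2], [4, 4]] with ['sp1', 'sp3']
    print(rank, req, sp)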
This diff is collapsed.
@@ -2,7 +2,6 @@ tag:
   - adr_tasks
   - adr_prompt_1
 dataset_path: masakhane/diacritics-restoration
-dataset_kwargs: {trust_remote_code: True}
 doc_to_target: target
 output_type: generate_until
 fewshot_split: dev
......
@@ -2,7 +2,6 @@ tag:
   - adr_tasks
   - adr_prompt_2
 dataset_path: masakhane/diacritics-restoration
-dataset_kwargs: {trust_remote_code: True}
 doc_to_target: target
 output_type: generate_until
 fewshot_split: dev
......
@@ -2,7 +2,6 @@ tag:
   - adr_tasks
   - adr_prompt_3
 dataset_path: masakhane/diacritics-restoration
-dataset_kwargs: {trust_remote_code: True}
 doc_to_target: target
 output_type: generate_until
 fewshot_split: dev
......
@@ -2,7 +2,6 @@ tag:
   - adr_tasks
   - adr_prompt_4
 dataset_path: masakhane/diacritics-restoration
-dataset_kwargs: {trust_remote_code: True}
 doc_to_target: target
 output_type: generate_until
 fewshot_split: dev
......
@@ -2,7 +2,6 @@ tag:
   - adr_tasks
   - adr_prompt_5
 dataset_path: masakhane/diacritics-restoration
-dataset_kwargs: {trust_remote_code: True}
 doc_to_target: target
 output_type: generate_until
 fewshot_split: dev
......
@@ -4,7 +4,6 @@ tag:
 task: null
 dataset_path: masakhane/afrisenti
 dataset_name: null
-dataset_kwargs: {trust_remote_code: True}
 output_type: multiple_choice
 validation_split: validation
 test_split: test
......
@@ -3,7 +3,6 @@ tag:
   - afrisent_prompt_2
 dataset_path: masakhane/afrisenti
 dataset_name: null
-dataset_kwargs: {trust_remote_code: True}
 output_type: multiple_choice
 validation_split: validation
 test_split: test
......
@@ -3,7 +3,6 @@ tag:
   - afrisenti_prompt_3
 dataset_path: masakhane/afrisenti
 dataset_name: null
-dataset_kwargs: {trust_remote_code: True}
 output_type: multiple_choice
 validation_split: validation
 test_split: test
......
@@ -3,7 +3,6 @@ tag:
   - afrisenti_prompt_4
 dataset_path: masakhane/afrisenti
 dataset_name: null
-dataset_kwargs: {trust_remote_code: True}
 output_type: multiple_choice
 validation_split: validation
 test_split: test
......
@@ -3,7 +3,6 @@ tag:
   - afrisenti_prompt_5
 dataset_path: masakhane/afrisenti
 dataset_name: null
-dataset_kwargs: {trust_remote_code: True}
 output_type: multiple_choice
 validation_split: validation
 test_split: test
......
 def doc_to_text(doc):
     output = """Please provide the POS tags for each word in the input sentence. The input will be a list of words in
     the sentence. The output format should be a list of tuples, where each tuple consists of a word from the input text
-    and its corresponding POS tag label from the tag label set: ["ADJ", "ADP", "ADV", "AUX", "CCONJ, "DET", "INTJ",
+    and its corresponding POS tag label from the tag label set: ["ADJ", "ADP", "ADV", "AUX", "CCONJ", "DET", "INTJ",
     "NOUN", "NUM", "PART", "PRON", "PROPN", "PUNCT" "SCONJ", "SYM", "VERB", "X"]. \nYour response should include only a
     list of tuples, in the order that the words appear in the input sentence, with each tuple containing the
     corresponding POS tag label for a word.
......
@@ -2,7 +2,6 @@ tag:
   - afrobench_sentiment_tasks
   - nollysenti_prompt_1
 dataset_path: Davlan/nollysenti
-dataset_kwargs: {trust_remote_code: True}
 output_type: multiple_choice
 validation_split: validation
 test_split: test
......