Commit 3e8135ce authored by Baber

Merge branch 'main' into comma

parents 8e560c96 0c134ee9
@@ -2,7 +2,7 @@ import logging
 import os
-__version__ = "0.4.9"
+__version__ = "0.4.9.1"
 # Lazy-load .evaluator module to improve CLI startup
...
@@ -5,8 +5,9 @@ import traceback
 from typing import Iterator, List, Sequence, Tuple, TypeVar
-# This is a cpp module. Compile janitor_util.cpp with:
-# c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup
+# This is a cpp module.
+# See scripts/clean_training_data/README.md for instructions to compile janitor_util.cpp
 try:
     import janitor_util
...
@@ -71,13 +71,6 @@ class SteeredModel(HFLM):
         """
         HFLM with a steered forward pass.
-        To derive steering vectors from a sparse model loadable with sparsify or sae_lens,
-        provide the path to a CSV file with the following columns (example rows are provided below):
-        loader,action,sparse_model,hookpoint,feature_index,steering_coefficient,sae_id,description,
-        sparsify,add,EleutherAI/sae-pythia-70m-32k,layers.3,30,10.0,,,
-        sae_lens,add,gemma-scope-2b-pt-res-canonical,layers.20,12082,240.0,layer_20/width_16k/canonical,increase dogs,
         To load steering vectors directly, provide the path to a pytorch (.pt) file with content in the following format:
         {
@@ -86,9 +79,17 @@ class SteeredModel(HFLM):
                 "steering_coefficient": <float>,
                 "action": <Literal["add", "clamp"]>,
                 "bias": <torch.Tensor | None>,
+                "head_index": <int | None>,
             },
             ...
         }
+        To derive steering vectors from a sparse model loadable with sparsify or sae_lens,
+        provide the path to a CSV file with the following columns (example rows are provided below):
+        loader,action,sparse_model,hookpoint,feature_index,steering_coefficient,head_index,sae_id,description,
+        sparsify,add,EleutherAI/sae-pythia-70m-32k,layers.3,30,10.0,,,,
+        sae_lens,add,gemma-scope-2b-pt-res-canonical,layers.20,12082,240.0,,layer_20/width_16k/canonical,increase dogs,
         """
         super().__init__(pretrained=pretrained, device=device, **kwargs)
@@ -105,27 +106,31 @@ class SteeredModel(HFLM):
         hook_to_steer = {}
         for hookpoint, steer_info in steer_config.items():
             action = steer_info["action"]
-            steering_coefficient = steer_info["steering_coefficient"]
             steering_vector = (
                 steer_info["steering_vector"].to(self.device).to(self.model.dtype)
             )
-            bias = (
-                steer_info["bias"].to(self.device).to(self.model.dtype)
-                if steer_info["bias"] is not None
-                else None
-            )
+            steering_coefficient = float(steer_info.get("steering_coefficient", 1.0))
+            head_index = steer_info.get("head_index", None)
+            bias = steer_info.get("bias", None)
+            if bias is not None:
+                bias = bias.to(self.device).to(self.model.dtype)
             if action == "add":
-                # Steers the model by adding some multiple of a steering vector to all sequence positions.
-                hook_to_steer[hookpoint] = (
-                    lambda acts: acts + steering_coefficient * steering_vector
+                # Steer the model by adding a multiple of a steering vector to all sequence positions.
+                assert bias is None, "Bias is not supported for the `add` action."
+                hook_to_steer[hookpoint] = partial(
+                    self.add,
+                    vector=steering_vector * steering_coefficient,
+                    head_index=head_index,
                 )
             elif action == "clamp":
+                # Steer the model by clamping the activations to a value in the direction of the steering vector.
                 hook_to_steer[hookpoint] = partial(
                     self.clamp,
-                    steering_vector=steering_vector,
+                    direction=steering_vector / torch.norm(steering_vector),
                     value=steering_coefficient,
                     bias=bias,
+                    head_index=head_index,
                 )
             else:
                 raise ValueError(f"Unknown hook type: {action}")
@@ -195,34 +200,62 @@ class SteeredModel(HFLM):
         return steer_data
+    @classmethod
+    def add(
+        cls,
+        acts: Tensor,
+        vector: Tensor,
+        head_index: Optional[int],
+    ):
+        """Adds the given vector to the activations.
+        Args:
+            acts (Tensor): The activations tensor to edit of shape [batch, pos, ..., features]
+            vector (Tensor): A vector to add of shape [features]
+            head_index (int | None): Optional attention head index to add to
+        """
+        if head_index is not None:
+            acts[:, :, head_index, :] = acts[:, :, head_index, :] + vector
+        else:
+            acts = acts + vector
+        return acts
     @classmethod
     def clamp(
         cls,
         acts: Tensor,
-        steering_vector: Tensor,
+        direction: Tensor,
         value: float,
+        head_index: Optional[int],
         bias: Optional[Tensor] = None,
     ):
-        """Clamps a direction of the activations to be the steering vector * the value.
+        """Clamps the activations to a given value in a specified direction. The direction
+        must be a unit vector.
         Args:
-            acts (Tensor): The activations tensor to edit of shape [batch, pos, features]
-            steering_vector (Tensor): A direction to clamp of shape [features]
+            acts (Tensor): The activations tensor to edit of shape [batch, pos, ..., features]
+            direction (Tensor): A direction to clamp of shape [features]
             value (float): Value to clamp the direction to
+            head_index (int | None): Optional attention head index to clamp
             bias (Tensor | None): Optional bias to add to the activations
         Returns:
             Tensor: The modified activations with the specified direction clamped
         """
         if bias is not None:
             acts = acts - bias
-        direction = steering_vector / torch.norm(steering_vector)
-        proj_magnitude = torch.sum(acts * direction, dim=-1, keepdim=True)
-        orthogonal_component = acts - proj_magnitude * direction
-        clamped = orthogonal_component + direction * value
+        if head_index is not None:
+            x = acts[:, :, head_index, :]
+            proj = (x * direction).sum(dim=-1, keepdim=True)
+            assert proj == acts @ direction
+            clamped = acts.clone()
+            clamped[:, :, head_index, :] = x + direction * (value - proj)
+        else:
+            proj = torch.sum(acts * direction, dim=-1, keepdim=True)
+            clamped = acts + direction * (value - proj)
         if bias is not None:
             return clamped + bias
...
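A minimal, self-contained sketch (not the repository's code) of what the "clamp" steering action above computes: project the activations onto a unit direction and replace that component with a fixed value, optionally for a single attention head. The function and variable names here are illustrative.

import torch


def clamp_direction(
    acts: torch.Tensor,
    direction: torch.Tensor,
    value: float,
    head_index: int | None = None,
) -> torch.Tensor:
    # Normalize so `direction` is a unit vector, as the docstring requires.
    direction = direction / torch.norm(direction)
    if head_index is not None:
        # Edit a single attention head: acts is [batch, pos, heads, head_features].
        x = acts[:, :, head_index, :]
        proj = (x * direction).sum(dim=-1, keepdim=True)
        out = acts.clone()
        out[:, :, head_index, :] = x + direction * (value - proj)
        return out
    # Full residual stream: acts is [batch, pos, features].
    proj = (acts * direction).sum(dim=-1, keepdim=True)
    return acts + direction * (value - proj)


acts = torch.randn(2, 5, 16)
direction = torch.randn(16)
clamped = clamp_direction(acts, direction, value=3.0)
# The component along `direction` now equals `value` at every position.
unit = direction / direction.norm()
print(torch.allclose((clamped * unit).sum(-1), torch.full((2, 5), 3.0), atol=1e-4))

Since clamped = acts + d * (value - acts·d) for a unit vector d, the new projection is acts·d + (value - acts·d) = value, while the component orthogonal to d is untouched.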
@@ -680,10 +680,19 @@ class HFLM(TemplateLM):
                     "0.4.0"
                 ):
                     raise AssertionError("load_in_4bit requires peft >= 0.4.0")
-            if self._model.config.vocab_size != len(self.tokenizer):
+            # Compatible with Gemma3 (multimodal) and old models
+            if hasattr(self._model.config, "text_config") and hasattr(
+                self._model.config.text_config, "vocab_size"
+            ):
+                vocab_size = self._model.config.text_config.vocab_size
+            else:
+                vocab_size = self._model.config.vocab_size
+            if vocab_size != len(self.tokenizer):
                 # resize model for LoRAs with added tokens
                 eval_logger.info(
-                    f"Model config indicates vocab_size='{self._model.config.vocab_size}', but found tokenizer with vocab size '{len(self.tokenizer)}'. Resizing model embedding layer..."
+                    f"Model config indicates vocab_size='{vocab_size}', but found tokenizer with vocab size '{len(self.tokenizer)}'. Resizing model embedding layer..."
                )
                self._model.resize_token_embeddings(len(self.tokenizer))
            self._model = PeftModel.from_pretrained(
...
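A small sketch of the vocab-size lookup introduced above: multimodal configs such as Gemma3 nest the text vocabulary under config.text_config, so that is checked before falling back to the top-level attribute. The helper name and the commented usage are illustrative stand-ins for the loaded HF objects.

def resolve_vocab_size(config) -> int:
    # Prefer the nested text config when it carries its own vocab_size (e.g. Gemma3).
    text_config = getattr(config, "text_config", None)
    if text_config is not None and hasattr(text_config, "vocab_size"):
        return text_config.vocab_size
    return config.vocab_size

# if resolve_vocab_size(model.config) != len(tokenizer):
#     model.resize_token_embeddings(len(tokenizer))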
@@ -289,7 +289,7 @@ class OpenAIChatCompletion(LocalChatCompletion):
             "seed": seed,
             **gen_kwargs,
         }
-        if "o1" in self.model:
+        if "o1" in self.model or "5" in self.model:
             output.pop("stop")
             output["temperature"] = 1
         elif "o3" in self.model:
...
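A rough sketch of the payload adjustment above: reasoning-style model names (e.g. "o1" or the GPT-5 family matched via "5") have their `stop` sequences dropped and temperature pinned to 1 before the request is sent. The function name and the substring matching are illustrative, mirroring the diff rather than the OpenAI API contract.

def adjust_payload(model: str, payload: dict) -> dict:
    # Copy so the caller's gen_kwargs are not mutated.
    payload = dict(payload)
    if "o1" in model or "5" in model:
        payload.pop("stop", None)
        payload["temperature"] = 1
    return payload

print(adjust_payload("gpt-5-mini", {"stop": ["\n"], "temperature": 0}))
# {'temperature': 1}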
@@ -28,9 +28,8 @@ class OptimumLM(HFLM):
         **kwargs,
     ) -> None:
         if "backend" in kwargs:
-            # optimum currently only supports causal models
-            assert kwargs["backend"] == "causal", (
-                "Currently, only OVModelForCausalLM is supported."
+            assert kwargs["backend"] in ["causal", "seq2seq"], (
+                "Currently, only OVModelForCausalLM or OVModelForSeq2SeqLM are supported."
             )
         self.openvino_device = device
@@ -54,7 +53,7 @@ class OptimumLM(HFLM):
                 "package `optimum` is not installed. Please install it via `pip install optimum[openvino]`"
             )
         else:
-            from optimum.intel.openvino import OVModelForCausalLM
+            from optimum.intel.openvino import OVModelForCausalLM, OVModelForSeq2SeqLM
         model_kwargs = kwargs if kwargs else {}
         if "ov_config" in model_kwargs:
@@ -76,17 +75,14 @@ class OptimumLM(HFLM):
             model_kwargs["ov_config"]["MODEL_DISTRIBUTION_POLICY"] = (
                 "PIPELINE_PARALLEL"
             )
-        model_file = Path(pretrained) / "openvino_model.xml"
-        if model_file.exists():
-            export = False
-        else:
-            export = True
-        self._model = OVModelForCausalLM.from_pretrained(
+        model_cls = (
+            OVModelForCausalLM if self.backend == "causal" else OVModelForSeq2SeqLM
+        )
+        self._model = model_cls.from_pretrained(
             pretrained,
             revision=revision,
             trust_remote_code=trust_remote_code,
-            export=export,
             device=self.openvino_device.upper(),
             **model_kwargs,
         )
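A minimal sketch of the backend dispatch introduced above, assuming `optimum[openvino]` is installed: the OpenVINO model class is chosen from the `backend` kwarg instead of always loading a causal LM, and the manual `export=` probing is no longer needed. The helper name is illustrative.

from optimum.intel.openvino import OVModelForCausalLM, OVModelForSeq2SeqLM


def load_ov_model(pretrained: str, backend: str = "causal", **model_kwargs):
    # Pick the class by declared backend; everything else is passed through.
    model_cls = OVModelForCausalLM if backend == "causal" else OVModelForSeq2SeqLM
    return model_cls.from_pretrained(pretrained, **model_kwargs)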
@@ -216,7 +216,7 @@ class SGLangLM(TemplateLM):
         # we group requests by their generation_kwargs,
         # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
         # in the same batch.
-        re_ords = Collator(requests, _collate_gen, group_by="gen_kwargs")
+        re_ords = Collator(requests, _collate_gen, group_by=None)
         chunks = re_ords.get_batched(
             n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
         )
@@ -232,9 +232,9 @@ class SGLangLM(TemplateLM):
             context_and_encoding, all_gen_kwargs = zip(*chunk)
             context, context_encoding = zip(*context_and_encoding)
-            # we assume all gen kwargs in the batch are the same
-            # this is safe to assume because the `grouper` object ensures it.
-            gen_kwargs = all_gen_kwargs[0]
+            context_encoding_truncated = []
+            sampling_params = []
+            for x, gen_kwargs in zip(context_encoding, all_gen_kwargs):
                 # unpack our keyword arguments.
                 if isinstance(gen_kwargs, dict):
                     kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
@@ -252,16 +252,21 @@ class SGLangLM(TemplateLM):
                 # set the max length in tokens of inputs ("context_enc")
                 # max len for inputs = max length, minus room to generate the max new tokens
                 max_ctx_len = self.max_length - max_gen_toks
-            context_encoding = [x[-max_ctx_len:] for x in context_encoding]
+                if len(x) > max_ctx_len:
+                    context_encoding_truncated.append(x[-max_ctx_len:])
+                else:
+                    context_encoding_truncated.append(x)
+                # create sampling params
+                kwargs = self.modify_gen_kwargs(kwargs)
+                sampling_params.append(
+                    kwargs | {"max_tokens": max_gen_toks, "stop": until}
+                )
             # perform batched generation
             # cont is a list of dic. See here https://github.com/sgl-project/sglang/blob/0a6f18f068e4095fc228e798454e8496c9749214/python/sglang/srt/entrypoints/engine.py#L111 .
             cont = self._model_generate(
-                requests=context_encoding,
+                requests=context_encoding_truncated,
                 generate=True,
-                max_tokens=max_gen_toks,
-                stop=until,
-                **kwargs,
+                sampling_params=sampling_params,
             )
             # cache generations
@@ -284,28 +289,22 @@ class SGLangLM(TemplateLM):
         self,
         requests: List[List[int]] = None,
         generate: bool = False,
-        max_tokens: int = None,
-        stop: Optional[List[str]] = None,
+        sampling_params: Union[List[Dict], Dict, None] = None,
         return_logprob: bool = False,
         top_logprobs_num: int = 1,
         logprob_start_len: int = -1,
-        **kwargs,
     ):
         # check sglang sampling parameters: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/sampling/sampling_params.py#L21 and https://docs.sglang.ai/references/sampling_params.html.
-        if generate:
-            kwargs = self.modify_gen_kwargs(kwargs)
-            sampling_params = {
-                "max_new_tokens": max_tokens,
-                "stop": stop,
-            }
-            sampling_params.update(kwargs)
-        else:
-            sampling_params = {
-                "temperature": 0,
-                "max_new_tokens": 1,
-            }
-            sampling_params.update(kwargs)
+        if not generate:
+            sampling_params = sampling_params if sampling_params else {}
+            sampling_params.update(
+                {
+                    "temperature": 0,
+                    "max_new_tokens": 1,
+                }
+            )
+        if not isinstance(sampling_params, List):
+            sampling_params = [sampling_params] * len(requests)
         # Refer to: https://docs.sglang.ai/backend/offline_engine_api.html
         outputs = self.model.generate(
             input_ids=requests,
...
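A simplified sketch of the new per-request flow above: rather than assuming every request in a batch shares one gen_kwargs, each request now carries its own sampling-params dict and its context is truncated individually. The function and argument names below are illustrative, not the class's API.

def build_batch(context_encoding, all_gen_kwargs, max_length, max_gen_toks, until):
    # Leave room for generation, then truncate each context from the left.
    max_ctx_len = max_length - max_gen_toks
    truncated, sampling_params = [], []
    for toks, gen_kwargs in zip(context_encoding, all_gen_kwargs):
        truncated.append(toks[-max_ctx_len:] if len(toks) > max_ctx_len else toks)
        # One sampling-params dict per request, so mixed gen_kwargs can share a batch.
        sampling_params.append(
            dict(gen_kwargs) | {"max_tokens": max_gen_toks, "stop": list(until)}
        )
    return truncated, sampling_params


enc, sp = build_batch(
    [[1] * 10, [2] * 3],
    [{"temperature": 0.8}, {}],
    max_length=260,
    max_gen_toks=256,
    until=["</s>"],
)
print([len(x) for x in enc])                       # [4, 3]
print(sp[0]["temperature"], sp[1].get("temperature"))  # 0.8 None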
 import copy
 import gc
-import inspect
 import logging
 import os
 from importlib.metadata import version
@@ -33,7 +32,7 @@ from lm_eval.utils import (
 try:
     import ray
-    from vllm import LLM, SamplingParams
+    from vllm import LLM, SamplingParams, TokensPrompt
     from vllm.lora.request import LoRARequest
     from vllm.transformers_utils.tokenizer import get_tokenizer
     from vllm.utils import get_open_port
@@ -51,7 +50,7 @@ eval_logger = logging.getLogger(__name__)
 def _vllm_mp_worker(
     model_args: dict,
-    sampling_params: "SamplingParams",
+    sampling_params: list["SamplingParams"],
     requests: list[list[int]],
     lora_request: "LoRARequest",
     result_queue: "Queue",
@@ -79,7 +78,7 @@ def _vllm_mp_worker(
     try:
         llm = LLM(**model_args)
         res = llm.generate(
-            prompt_token_ids=requests,
+            [TokensPrompt(prompt_token_ids=request) for request in requests],
            sampling_params=sampling_params,
            lora_request=lora_request,
        )
@@ -196,6 +195,12 @@ class VLLM(TemplateLM):
                 self.batch_size = "auto"
                 eval_logger.info("Manual batching is not compatible with data parallelism.")
+        if "gemma" in pretrained.lower():
+            add_bos_token = True
+            eval_logger.info(
+                "Found 'gemma' in model name, a BOS token will be used as Gemma series models underperform without it."
+            )
         from transformers import AutoConfig
         self._config = AutoConfig.from_pretrained(
@@ -214,11 +219,6 @@ class VLLM(TemplateLM):
             "enable_thinking", enable_thinking
         )
         self.add_bos_token = add_bos_token
-        if "gemma" in pretrained.lower():
-            self.add_bos_token = True
-            eval_logger.info(
-                "Found 'gemma' in model name, a BOS token will be used as Gemma series models underperform without it."
-            )
         if parse_version(version("vllm")) >= parse_version("0.8.3"):
             kwargs_resolve_hf_chat_template = {
@@ -239,13 +239,6 @@ class VLLM(TemplateLM):
             model_config = engine_args.create_model_config()
             kwargs_resolve_hf_chat_template["model_config"] = model_config
-            # https://github.com/vllm-project/vllm/pull/18259
-            if (
-                "trsut_remote_code"
-                in inspect.signature(resolve_hf_chat_template).parameters
-            ):
-                kwargs_resolve_hf_chat_template["trsut_remote_code"] = trust_remote_code
         else:
             kwargs_resolve_hf_chat_template["trust_remote_code"] = trust_remote_code
@@ -371,17 +364,14 @@ class VLLM(TemplateLM):
         self,
         requests: List[List[int]] = None,
         generate: bool = False,
-        max_tokens: int = None,
-        stop: Optional[List[str]] = None,
-        **kwargs,
+        sampling_params: Union[List["SamplingParams"], "SamplingParams", None] = None,
     ):
-        if generate:
-            kwargs = self.modify_gen_kwargs(kwargs)
-            sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs)
-        else:
+        if not generate or sampling_params is None:
             sampling_params = SamplingParams(
                 temperature=0, prompt_logprobs=1, max_tokens=1, detokenize=False
             )
+        if not isinstance(sampling_params, List):
+            sampling_params = [sampling_params] * len(requests)
         if self.data_parallel_size > 1 and not self.V1:
             # vLLM hangs if resources are set in ray.remote
             # also seems to only work with decorator and not with ray.remote() fn
@@ -389,13 +379,13 @@ class VLLM(TemplateLM):
             @ray.remote
             def run_inference_one_model(
                 model_args: dict,
-                sampling_params: SamplingParams,
+                sampling_params: List["SamplingParams"],
                 requests: List[List[int]],
-                lora_request: LoRARequest,
+                lora_request: "LoRARequest",
             ):
                 llm = LLM(**model_args)
                 return llm.generate(
-                    prompt_token_ids=requests,
+                    [TokensPrompt(prompt_token_ids=request) for request in requests],
                     sampling_params=sampling_params,
                     lora_request=lora_request,
                 )
@@ -403,9 +393,12 @@ class VLLM(TemplateLM):
             # dispatch requests to all self.data_parallel_size workers, in interleaved fashion
             # interleaved important to balance context lengths across workers
             requests = [list(x) for x in distribute(self.data_parallel_size, requests)]
+            sampling_params = [
+                list(sp) for sp in distribute(self.data_parallel_size, sampling_params)
+            ]
             inputs = (
-                (self.model_args, sampling_params, req, self.lora_request)
-                for req in requests
+                (self.model_args, sp, req, self.lora_request)
+                for req, sp in zip(requests, sampling_params)
             )
             object_refs = [run_inference_one_model.remote(*x) for x in inputs]
             results = ray.get(object_refs)
@@ -420,16 +413,18 @@ class VLLM(TemplateLM):
             dp_master_port = os.environ.get("VLLM_DP_MASTER_PORT") or get_open_port()
             requests = (list(x) for x in distribute(self.data_parallel_size, requests))
+            sampling_params = (
+                list(sp) for sp in distribute(self.data_parallel_size, sampling_params)
+            )
             procs, resq = [], Queue()
             # We use Process as it is non-daemonic
             try:
-                for rank, req in enumerate(requests):
+                for rank, (sp, req) in enumerate(zip(requests, sampling_params)):
                     proc = Process(
                         target=_vllm_mp_worker,
                         args=(
                             self.model_args.copy(),
-                            sampling_params,
+                            sp,
                             req,
                             self.lora_request,
                             resq,
@@ -484,7 +479,7 @@ class VLLM(TemplateLM):
         else:
             outputs = self.model.generate(
-                prompt_token_ids=requests,
+                [TokensPrompt(prompt_token_ids=request) for request in requests],
                 sampling_params=sampling_params,
                 use_tqdm=True if self.batch_size == "auto" else False,
                 lora_request=self.lora_request,
@@ -583,10 +578,11 @@ class VLLM(TemplateLM):
             # - any OOMs will happen right away rather than near the end
             return -len(_requests[0][1]), _requests[0][0]
-        # we group requests by their generation_kwargs,
-        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
-        # in the same batch.
-        re_ords = Collator(requests, _collate_gen, group_by="gen_kwargs")
+        re_ords = Collator(
+            requests,
+            _collate_gen,
+            group_by=None,
+        )
         chunks = re_ords.get_batched(
             n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
         )
@@ -601,9 +597,9 @@ class VLLM(TemplateLM):
         for chunk in chunks:
             context_and_encoding, all_gen_kwargs = zip(*chunk)
             context, context_encoding = zip(*context_and_encoding)
-            # we assume all gen kwargs in the batch are the same
-            # this is safe to assume because the `grouper` object ensures it.
-            gen_kwargs = all_gen_kwargs[0]
+            context_encoding_truncated = []
+            sampling_params = []
+            for x, gen_kwargs in zip(context_encoding, all_gen_kwargs):
                 # unpack our keyword arguments.
                 if isinstance(gen_kwargs, dict):
                     kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
@@ -621,21 +617,24 @@ class VLLM(TemplateLM):
                 # set the max length in tokens of inputs ("context_enc")
                 # max len for inputs = max length, minus room to generate the max new tokens
                 max_ctx_len = self.max_length - max_gen_toks
-            all_lengths = [len(x) for x in context_encoding]
-            for length in all_lengths:
-                if length > max_ctx_len:
+                if len(x) > max_ctx_len:
                     eval_logger.warning(
-                        f"Context length {length} exceeds max length (context + max gen tokens): {max_ctx_len}. Truncating context."
+                        f"Context length {len(x)} exceeds max length (context + max gen tokens): {max_ctx_len}. Truncating context."
                     )
-            context_encoding = [x[-max_ctx_len:] for x in context_encoding]
+                    context_encoding_truncated.append(x[-max_ctx_len:])
+                else:
+                    context_encoding_truncated.append(x)
+                # create sampling params
+                kwargs = self.modify_gen_kwargs(kwargs)
+                sampling_params.append(
+                    SamplingParams(max_tokens=max_gen_toks, stop=until, **kwargs)
+                )
             # perform batched generation
             cont = self._model_generate(
-                requests=context_encoding,
+                requests=context_encoding_truncated,
                 generate=True,
-                max_tokens=max_gen_toks,
-                stop=until,
-                **kwargs,
+                sampling_params=sampling_params,
             )
             # cache generations
...
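A hedged sketch of the new vLLM call shape above: pre-tokenized prompts are wrapped in TokensPrompt objects and each request gets its own SamplingParams, so requests with different gen_kwargs can share one batch. Assumes a recent vLLM; the model name and token ids are arbitrary placeholders.

from vllm import LLM, SamplingParams, TokensPrompt

llm = LLM(model="facebook/opt-125m")  # illustrative model
token_ids = [[1, 2, 3, 4], [5, 6, 7]]
sampling_params = [
    SamplingParams(max_tokens=16, temperature=0.0),
    SamplingParams(max_tokens=16, temperature=0.8),
]
outputs = llm.generate(
    [TokensPrompt(prompt_token_ids=ids) for ids in token_ids],
    sampling_params=sampling_params,  # a list is paired element-wise with the prompts
)

Passing a list of SamplingParams (one per prompt) is what removes the need to group requests by gen_kwargs before batching.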
This diff is collapsed.
@@ -81,7 +81,7 @@ class TaskManager:
         task_index = {}
         for task_dir in all_paths:
             tasks = self._get_task_and_group(task_dir)
-            task_index = {**tasks, **task_index}
+            task_index = {**task_index, **tasks}
         return task_index
...
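A tiny illustration of the merge-order change above: with dict unpacking, keys from the right-hand mapping win, so `{**task_index, **tasks}` lets entries discovered in later-scanned directories override earlier ones instead of being silently dropped. The values are placeholders.

task_index = {"my_task": "from_earlier_dir"}              # accumulated so far
tasks = {"my_task": "from_later_dir", "new_task": "x"}    # just discovered
print({**tasks, **task_index})  # old order: {'my_task': 'from_earlier_dir', 'new_task': 'x'}
print({**task_index, **tasks})  # new order: {'my_task': 'from_later_dir', 'new_task': 'x'}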
@@ -2,7 +2,6 @@ tag:
   - adr_tasks
   - adr_prompt_1
 dataset_path: masakhane/diacritics-restoration
-dataset_kwargs: {trust_remote_code: True}
 doc_to_target: target
 output_type: generate_until
 fewshot_split: dev
...
@@ -2,7 +2,6 @@ tag:
   - adr_tasks
   - adr_prompt_2
 dataset_path: masakhane/diacritics-restoration
-dataset_kwargs: {trust_remote_code: True}
 doc_to_target: target
 output_type: generate_until
 fewshot_split: dev
...
@@ -2,7 +2,6 @@ tag:
   - adr_tasks
   - adr_prompt_3
 dataset_path: masakhane/diacritics-restoration
-dataset_kwargs: {trust_remote_code: True}
 doc_to_target: target
 output_type: generate_until
 fewshot_split: dev
...
@@ -2,7 +2,6 @@ tag:
   - adr_tasks
   - adr_prompt_4
 dataset_path: masakhane/diacritics-restoration
-dataset_kwargs: {trust_remote_code: True}
 doc_to_target: target
 output_type: generate_until
 fewshot_split: dev
...
@@ -2,7 +2,6 @@ tag:
   - adr_tasks
   - adr_prompt_5
 dataset_path: masakhane/diacritics-restoration
-dataset_kwargs: {trust_remote_code: True}
 doc_to_target: target
 output_type: generate_until
 fewshot_split: dev
...
@@ -4,7 +4,6 @@ tag:
 task: null
 dataset_path: masakhane/afrisenti
 dataset_name: null
-dataset_kwargs: {trust_remote_code: True}
 output_type: multiple_choice
 validation_split: validation
 test_split: test
...
@@ -3,7 +3,6 @@ tag:
   - afrisent_prompt_2
 dataset_path: masakhane/afrisenti
 dataset_name: null
-dataset_kwargs: {trust_remote_code: True}
 output_type: multiple_choice
 validation_split: validation
 test_split: test
...
@@ -3,7 +3,6 @@ tag:
   - afrisenti_prompt_3
 dataset_path: masakhane/afrisenti
 dataset_name: null
-dataset_kwargs: {trust_remote_code: True}
 output_type: multiple_choice
 validation_split: validation
 test_split: test
...
@@ -3,7 +3,6 @@ tag:
   - afrisenti_prompt_4
 dataset_path: masakhane/afrisenti
 dataset_name: null
-dataset_kwargs: {trust_remote_code: True}
 output_type: multiple_choice
 validation_split: validation
 test_split: test
...
@@ -3,7 +3,6 @@ tag:
   - afrisenti_prompt_5
 dataset_path: masakhane/afrisenti
 dataset_name: null
-dataset_kwargs: {trust_remote_code: True}
 output_type: multiple_choice
 validation_split: validation
 test_split: test
...