Commit 03f0e80e authored by Baber

fix bos token handling

parent d701d50f
@@ -324,6 +324,7 @@ class TemplateLM(LM):
"""
tokenizer = None
backend = "causal"
@property
@abc.abstractmethod
@@ -378,24 +379,22 @@ class TemplateLM(LM):
handle empty context (see loglikelihood method).
"""
assert context, "Context cannot be empty!"
import transformers
n_spaces = len(context) - len(context.rstrip())
if n_spaces > 0:
continuation = context[-n_spaces:] + continuation
context = context[:-n_spaces]
model_class = getattr(self, "AUTO_MODEL_CLASS", None)
if model_class == transformers.AutoModelForSeq2SeqLM:
context_enc = self.tok_encode(context)
continuation_enc = self.tok_encode(continuation, add_special_tokens=False)
else:
if self.backend == "causal":
whole_enc = self.tok_encode(context + continuation)
context_enc = self.tok_encode(context)
context_enc_len = len(context_enc)
continuation_enc = whole_enc[context_enc_len:]
else:
# for SEQ2SEQ case we need to encode separately
context_enc = self.tok_encode(context)
continuation_enc = self.tok_encode(continuation, add_special_tokens=False)
return context_enc, continuation_enc
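
A note on the causal branch above: for decoder-only models the continuation ids are obtained by encoding the concatenated string once and slicing at the context length, because encoding context and continuation separately can tokenize the characters at the boundary differently. A minimal sketch of the effect, assuming a Hugging Face tokenizer; the "gpt2" checkpoint is only an illustration:

# Illustration only; not part of this commit.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
context, continuation = "Hello", " world"

whole_enc = tok(context + continuation, add_special_tokens=False).input_ids
context_enc = tok(context, add_special_tokens=False).input_ids
continuation_enc = whole_enc[len(context_enc):]  # tokens actually scored

# Encoding the continuation on its own can differ whenever the tokenizer
# would have merged characters across the context/continuation boundary.
separate_enc = tok(continuation, add_special_tokens=False).input_ids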
@@ -433,7 +432,7 @@ class TemplateLM(LM):
continuation_enc = self.tok_encode(
continuation, add_special_tokens=False
)
# BOS or EOS as context
# BOS or EOS as context: handle when context is empty -> (context + continuation) -> (BOS + continuation)
context_enc, continuation_enc = (
([self.prefix_token_id], continuation_enc)
if self.prefix_token_id != continuation_enc[0]
......
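
The conditional above only injects the prefix token when the encoded continuation does not already start with it, so an empty context never turns into a doubled BOS. A small illustration with hypothetical token ids (the ids and helper function below are not from the codebase, and the fallback branch is an assumption for the sketch):

# Illustration only; hypothetical ids.
PREFIX_TOKEN_ID = 1  # e.g. the tokenizer's BOS id

def context_for_empty_prompt(continuation_enc: list[int]) -> list[int]:
    # Use BOS/EOS as a stand-in context only if the continuation does not
    # already begin with it.
    if PREFIX_TOKEN_ID != continuation_enc[0]:
        return [PREFIX_TOKEN_ID]
    return []  # assumed fallback for the sketch; the real branch is elided above

print(context_for_empty_prompt([42, 7]))     # [1]  -> BOS acts as the context
print(context_for_empty_prompt([1, 42, 7]))  # []   -> BOS already present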
@@ -258,7 +258,7 @@ class HFLM(TemplateLM):
else {}
)
self.add_bos_token = add_bos_token if add_bos_token is not None else None
self.add_bos_token = add_bos_token
self._max_length = max_length
self.pretrained = pretrained
......
from __future__ import annotations
import copy
import gc
import logging
@@ -7,7 +9,7 @@ from importlib.util import find_spec
from multiprocessing import Process, Queue
from queue import Empty
from time import sleep
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
from typing import TYPE_CHECKING, Literal
import jinja2
from more_itertools import distribute
@@ -50,10 +52,10 @@ eval_logger = logging.getLogger(__name__)
def _vllm_mp_worker(
model_args: dict,
sampling_params: list["SamplingParams"],
sampling_params: list[SamplingParams],
requests: list[list[int]],
lora_request: "LoRARequest",
result_queue: "Queue",
lora_request: LoRARequest,
result_queue: Queue,
dp_size: int,
local_dp_rank: int,
dp_master_port: int,
@@ -113,18 +115,18 @@ class VLLM(TemplateLM):
self,
pretrained: str,
dtype: Literal["float16", "bfloat16", "float32", "auto"] = "auto",
revision: Optional[str] = None,
trust_remote_code: Optional[bool] = False,
tokenizer: Optional[str] = None,
revision: str | None = None,
trust_remote_code: bool | None = False,
tokenizer: str | None = None,
tokenizer_mode: Literal["auto", "slow"] = "auto",
tokenizer_revision: Optional[str] = None,
add_bos_token: Optional[bool] = False,
prefix_token_id: Optional[int] = None,
tokenizer_revision: str | None = None,
add_bos_token: bool | None = False,
prefix_token_id: int | None = None,
tensor_parallel_size: int = 1,
quantization: Optional[str] = None,
quantization: str | None = None,
max_gen_toks: int = 256,
swap_space: int = 4,
batch_size: Union[str, int] = 1,
batch_size: str | int = 1,
max_batch_size=None,
max_length: int = None,
max_model_len: int = None,
@@ -134,9 +136,9 @@ class VLLM(TemplateLM):
lora_local_path: str = None,
# VLLM: enable thinking tags in the prompt.
enable_thinking: bool = True,
chat_template_args: Optional[dict] = None,
chat_template_args: dict | None = None,
# End marker for thinking tags - splits to get response after this token (if provided).
think_end_token: Optional[str] = None,
think_end_token: str | None = None,
max_lora_rank: int = 16,
**kwargs,
):
@@ -195,11 +197,7 @@ class VLLM(TemplateLM):
self.batch_size = "auto"
eval_logger.info("Manual batching is not compatible with data parallelism.")
if "gemma" in pretrained.lower():
add_bos_token = True
eval_logger.info(
"Found 'gemma' in model name, a BOS token will be used as Gemma series models underperform without it."
)
self.add_bos_token = add_bos_token
from transformers import AutoConfig
@@ -211,14 +209,17 @@ class VLLM(TemplateLM):
tokenizer_mode=tokenizer_mode,
trust_remote_code=trust_remote_code,
revision=tokenizer_revision,
add_bos_token=add_bos_token,
**(
{"add_bos_token": self.add_bos_token}
if self.add_bos_token is not None
else {}
),
)
self.tokenizer = configure_pad_token(self.tokenizer, model_config=self._config)
self.chat_template_args = chat_template_args or {}
self.enable_thinking = self.chat_template_args.pop(
"enable_thinking", enable_thinking
)
self.add_bos_token = add_bos_token
if parse_version(version("vllm")) >= parse_version("0.8.3"):
kwargs_resolve_hf_chat_template = {
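
The tokenizer construction above now forwards add_bos_token only when the caller set it explicitly, so an unset flag leaves the tokenizer's own default (for example, Llama- and Gemma-style tokenizers that already add BOS) untouched. A minimal sketch of the conditional-kwargs pattern; load_tokenizer and the model id are stand-in names, not the real vLLM helper:

# Illustration only.
def load_tokenizer(name: str, **kwargs):
    print(f"loading {name} with {kwargs}")

for add_bos_token in (None, True, False):
    load_tokenizer(
        "org/model",  # placeholder model id
        **({"add_bos_token": add_bos_token} if add_bos_token is not None else {}),
    )
# None  -> loading org/model with {}                        (tokenizer default wins)
# True  -> loading org/model with {'add_bos_token': True}   (explicit opt-in)
# False -> loading org/model with {'add_bos_token': False}  (explicit opt-out)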
@@ -265,7 +266,7 @@ class VLLM(TemplateLM):
self.lora_request = None
@property
def eot_token_id(self):
def eot_token_id(self) -> int | None:
# we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
return self.tokenizer.eos_token_id
@@ -300,7 +301,7 @@ class VLLM(TemplateLM):
return self._max_gen_toks
def apply_chat_template(
self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
) -> str:
"""
Method to apply a chat template to a list of chat history between user and model.
@@ -337,18 +338,26 @@ class VLLM(TemplateLM):
def tok_encode(
self,
string: Union[str, List[str]],
left_truncate_len: int = None,
add_special_tokens: bool = False,
string: str | list[str],
left_truncate_len: int | None = None,
add_special_tokens: bool | None = None,
truncation: bool = False,
) -> Union[List[int], List[List[int]]]:
if not add_special_tokens:
add_special_tokens = False or self.add_bos_token
encoding: Union[List[List[int]], List[int]] = self.tokenizer(
) -> list[int] | list[list[int]]:
add_special_kwargs = (
{"add_special_tokens": add_special_tokens or self.add_bos_token}
if (add_special_tokens is not None or self.add_bos_token is not None)
else {}
)
# handle chat template
if self.tokenizer.bos_token and (
string[0] if isinstance(string, list) else string
).startswith(self.tokenizer.bos_token):
add_special_kwargs = {"add_special_tokens": False}
encoding: list[list[int]] | list[int] = self.tokenizer(
string,
add_special_tokens=add_special_tokens,
truncation=truncation,
return_attention_mask=False,
**add_special_kwargs,
).input_ids
# left-truncate the encoded context to be at most `left_truncate_len` tokens long
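
The startswith check above covers prompts rendered through a chat template: the template already contains the literal BOS token, so letting the tokenizer add special tokens again would duplicate it. A sketch of the failure mode being guarded against, assuming a tokenizer that inserts BOS by default; the checkpoint name is a placeholder:

# Illustration only; substitute any tokenizer that adds BOS by default.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
prompt = tok.bos_token + "Hello"  # e.g. output of a chat template

naive = tok(prompt).input_ids                              # BOS appears twice
guarded = tok(prompt, add_special_tokens=False).input_ids  # single BOS
print(naive[:2], guarded[:1])  # roughly [1, 1] vs [1]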
@@ -362,15 +371,15 @@ class VLLM(TemplateLM):
def _model_generate(
self,
requests: List[List[int]] = None,
requests: list[list[int]],
generate: bool = False,
sampling_params: Union[List["SamplingParams"], "SamplingParams", None] = None,
sampling_params: list[SamplingParams] | SamplingParams | None = None,
):
if not generate or sampling_params is None:
sampling_params = SamplingParams(
temperature=0, prompt_logprobs=1, max_tokens=1, detokenize=False
)
if not isinstance(sampling_params, List):
if not isinstance(sampling_params, list):
sampling_params = [sampling_params] * len(requests)
if self.data_parallel_size > 1 and not self.V1:
# vLLM hangs if resources are set in ray.remote
@@ -379,9 +388,9 @@ class VLLM(TemplateLM):
@ray.remote
def run_inference_one_model(
model_args: dict,
sampling_params: List["SamplingParams"],
requests: List[List[int]],
lora_request: "LoRARequest",
sampling_params: list[SamplingParams],
requests: list[list[int]],
lora_request: LoRARequest,
):
llm = LLM(**model_args)
return llm.generate(
@@ -487,8 +496,8 @@ class VLLM(TemplateLM):
return outputs
def loglikelihood_rolling(
self, requests: List[Instance], disable_tqdm: bool = False
) -> List[float]:
self, requests: list[Instance], disable_tqdm: bool = False
) -> list[float]:
adaptive_batch_size = None
if self.batch_size == "auto":
adaptive_batch_size = len(requests)
@@ -503,7 +512,7 @@ class VLLM(TemplateLM):
disable=(disable_tqdm or (self.rank != 0)),
)
):
rolling_token_windows: List[Tuple[List[int], List[int]]] = list(
rolling_token_windows: list[tuple[list[int], list[int]]] = list(
map(
make_disjoint_window,
get_rolling_token_windows(
@@ -556,16 +565,14 @@ class VLLM(TemplateLM):
return loglikelihoods
def generate_until(
self, requests: List[Instance], disable_tqdm: bool = False
) -> List[str]:
self, requests: list[Instance], disable_tqdm: bool = False
) -> list[str]:
res = []
# batch tokenize contexts
context, all_gen_kwargs = zip(*(req.args for req in requests))
context_encoding: List[List[int]] = self.tok_encode(
context, add_special_tokens=self.add_bos_token
)
requests = [
context_encoding = self.tok_encode(context)
reqs = [
((a, b), c) for a, b, c in zip(context, context_encoding, all_gen_kwargs)
]
@@ -579,7 +586,7 @@ class VLLM(TemplateLM):
return -len(_requests[0][1]), _requests[0][0]
re_ords = Collator(
requests,
reqs,
_collate_gen,
group_by=None,
)
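
The _collate_gen key above sorts requests by descending tokenized length, with ties broken by the context string, so the longest prompts are scheduled first. A generic illustration of the same key, independent of the Collator class:

# Illustration only.
reqs = [("short prompt", [11, 12]), ("a much longer prompt", [1, 2, 3, 4, 5])]
reqs_sorted = sorted(reqs, key=lambda r: (-len(r[1]), r[0]))
# -> the 5-token request comes first, so any memory pressure from the longest
#    contexts surfaces at the start of the run rather than near the end.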
@@ -588,7 +595,7 @@ class VLLM(TemplateLM):
)
pbar = tqdm(
total=len(requests),
total=len(reqs),
disable=(disable_tqdm or (self.rank != 0)),
desc="Running generate_until requests",
)
@@ -656,9 +663,9 @@ class VLLM(TemplateLM):
def _loglikelihood_tokens(
self,
requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
requests: list[tuple[tuple[str, str], list[int], list[int]]],
disable_tqdm: bool = False,
) -> List[Tuple[float, bool]]:
) -> list[tuple[float, bool]]:
res = []
def _collate(x):
@@ -717,7 +724,7 @@ class VLLM(TemplateLM):
return re_ord.get_original(res)
@staticmethod
def _parse_logprobs(tokens: List, outputs, ctxlen: int) -> Tuple[float, bool]:
def _parse_logprobs(tokens: list, outputs, ctxlen: int) -> tuple[float, bool]:
"""Process logprobs and tokens.
:param tokens: list
......