Commit 03f0e80e authored by Baber

fix bos token handling

parent d701d50f
@@ -324,6 +324,7 @@ class TemplateLM(LM):
     """

     tokenizer = None
+    backend = "causal"

     @property
     @abc.abstractmethod
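For orientation (an editor's sketch, not part of the commit): the new `backend` class attribute replaces the old `AUTO_MODEL_CLASS` check in `_encode_pair` below, so a subclass selects its encoding strategy by overriding a single string. A self-contained illustration of the dispatch pattern, with made-up class names:

```python
# Illustrative only: stand-in classes, not the harness's TemplateLM/HFLM.
class BaseLM:
    backend = "causal"  # default mirrors the attribute added in this commit

    def encoding_strategy(self) -> str:
        # _encode_pair (next hunk) branches on this attribute the same way.
        if self.backend == "causal":
            return "tokenize context + continuation jointly, then slice"
        return "tokenize context and continuation separately"


class Seq2SeqLM(BaseLM):
    backend = "seq2seq"  # any non-"causal" value selects the seq2seq branch


print(BaseLM().encoding_strategy())
print(Seq2SeqLM().encoding_strategy())
```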
@@ -378,24 +379,22 @@ class TemplateLM(LM):
         handle empty context (see loglikelihood method).
         """
         assert context, "Context cannot be empty!"
-        import transformers

         n_spaces = len(context) - len(context.rstrip())
         if n_spaces > 0:
             continuation = context[-n_spaces:] + continuation
             context = context[:-n_spaces]

-        model_class = getattr(self, "AUTO_MODEL_CLASS", None)
-        if model_class == transformers.AutoModelForSeq2SeqLM:
-            context_enc = self.tok_encode(context)
-            continuation_enc = self.tok_encode(continuation, add_special_tokens=False)
-        else:
+        if self.backend == "causal":
             whole_enc = self.tok_encode(context + continuation)
             context_enc = self.tok_encode(context)
             context_enc_len = len(context_enc)
             continuation_enc = whole_enc[context_enc_len:]
+        else:
+            # for SEQ2SEQ case we need to encode separately
+            context_enc = self.tok_encode(context)
+            continuation_enc = self.tok_encode(continuation, add_special_tokens=False)

         return context_enc, continuation_enc
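Why the causal branch encodes the pair jointly and then slices: a causal model scores the concatenated string, so the continuation's token ids must match how the full string tokenizes, and BPE-style tokenizers can merge characters across the context/continuation boundary. A small sketch of the difference, using GPT-2's tokenizer purely as an example (my choice, not taken from the commit):

```python
# Sketch only: GPT-2 is an arbitrary example tokenizer for illustration.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
context, continuation = "Hello wor", "ld"

# Causal branch: encode the pair jointly, then slice off the context prefix,
# so tokens that merge across the boundary are attributed to the continuation.
whole_enc = tok(context + continuation, add_special_tokens=False).input_ids
context_enc = tok(context, add_special_tokens=False).input_ids
continuation_enc = whole_enc[len(context_enc):]

# Seq2seq branch: encode the continuation on its own.
separate_enc = tok(continuation, add_special_tokens=False).input_ids

# The two generally differ whenever tokens merge across the boundary.
print(continuation_enc, separate_enc)
```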
@@ -433,7 +432,7 @@ class TemplateLM(LM):
             continuation_enc = self.tok_encode(
                 continuation, add_special_tokens=False
             )
-            # BOS or EOS as context
+            # BOS or EOS as context: handle when context is empty -> (context + continuation) -> (BOS + continuation)
             context_enc, continuation_enc = (
                 ([self.prefix_token_id], continuation_enc)
                 if self.prefix_token_id != continuation_enc[0]
...
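A note on the branch visible above (the rest of the expression is collapsed in this view): when a request arrives with an empty context, the harness substitutes its prefix token, typically BOS or else EOS, as the context, unless the continuation already begins with that token. A rough sketch of just the visible branch:

```python
# Sketch of the visible branch only; the collapsed else-branch is omitted.
def empty_context_fallback(prefix_token_id: int, continuation_enc: list[int]):
    if prefix_token_id != continuation_enc[0]:
        # Use a single BOS/EOS prefix token as the stand-in context.
        return [prefix_token_id], continuation_enc
    raise NotImplementedError("remaining case is collapsed in the diff above")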
@@ -258,7 +258,7 @@ class HFLM(TemplateLM):
             else {}
         )
-        self.add_bos_token = add_bos_token if add_bos_token is not None else None
+        self.add_bos_token = add_bos_token
         self._max_length = max_length
         self.pretrained = pretrained
...
+from __future__ import annotations
+
 import copy
 import gc
 import logging
@@ -7,7 +9,7 @@ from importlib.util import find_spec
 from multiprocessing import Process, Queue
 from queue import Empty
 from time import sleep
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Literal

 import jinja2
 from more_itertools import distribute
@@ -50,10 +52,10 @@ eval_logger = logging.getLogger(__name__)
 def _vllm_mp_worker(
     model_args: dict,
-    sampling_params: list["SamplingParams"],
+    sampling_params: list[SamplingParams],
     requests: list[list[int]],
-    lora_request: "LoRARequest",
-    result_queue: "Queue",
+    lora_request: LoRARequest,
+    result_queue: Queue,
     dp_size: int,
     local_dp_rank: int,
     dp_master_port: int,
@@ -113,18 +115,18 @@ class VLLM(TemplateLM):
         self,
         pretrained: str,
         dtype: Literal["float16", "bfloat16", "float32", "auto"] = "auto",
-        revision: Optional[str] = None,
-        trust_remote_code: Optional[bool] = False,
-        tokenizer: Optional[str] = None,
+        revision: str | None = None,
+        trust_remote_code: bool | None = False,
+        tokenizer: str | None = None,
         tokenizer_mode: Literal["auto", "slow"] = "auto",
-        tokenizer_revision: Optional[str] = None,
-        add_bos_token: Optional[bool] = False,
-        prefix_token_id: Optional[int] = None,
+        tokenizer_revision: str | None = None,
+        add_bos_token: bool | None = False,
+        prefix_token_id: int | None = None,
         tensor_parallel_size: int = 1,
-        quantization: Optional[str] = None,
+        quantization: str | None = None,
         max_gen_toks: int = 256,
         swap_space: int = 4,
-        batch_size: Union[str, int] = 1,
+        batch_size: str | int = 1,
         max_batch_size=None,
         max_length: int = None,
         max_model_len: int = None,
@@ -134,9 +136,9 @@ class VLLM(TemplateLM):
         lora_local_path: str = None,
         # VLLM: enable thinking tags in the prompt.
         enable_thinking: bool = True,
-        chat_template_args: Optional[dict] = None,
+        chat_template_args: dict | None = None,
         # End marker for thinking tags - splits to get response after this token (if provided).
-        think_end_token: Optional[str] = None,
+        think_end_token: str | None = None,
         max_lora_rank: int = 16,
         **kwargs,
     ):
@@ -195,11 +197,7 @@ class VLLM(TemplateLM):
             self.batch_size = "auto"
             eval_logger.info("Manual batching is not compatible with data parallelism.")

-        if "gemma" in pretrained.lower():
-            add_bos_token = True
-            eval_logger.info(
-                "Found 'gemma' in model name, a BOS token will be used as Gemma series models underperform without it."
-            )
+        self.add_bos_token = add_bos_token

         from transformers import AutoConfig
@@ -211,14 +209,17 @@ class VLLM(TemplateLM):
             tokenizer_mode=tokenizer_mode,
             trust_remote_code=trust_remote_code,
             revision=tokenizer_revision,
-            add_bos_token=add_bos_token,
+            **(
+                {"add_bos_token": self.add_bos_token}
+                if self.add_bos_token is not None
+                else {}
+            ),
         )
         self.tokenizer = configure_pad_token(self.tokenizer, model_config=self._config)
         self.chat_template_args = chat_template_args or {}
         self.enable_thinking = self.chat_template_args.pop(
             "enable_thinking", enable_thinking
         )
-        self.add_bos_token = add_bos_token

         if parse_version(version("vllm")) >= parse_version("0.8.3"):
             kwargs_resolve_hf_chat_template = {
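An aside on the kwarg-splat pattern introduced above (editor's illustration, not part of the commit): `add_bos_token` is only forwarded to the tokenizer when the user set it explicitly, so the tokenizer's own default applies otherwise. A standalone sketch of the pattern with made-up names:

```python
# Made-up helper demonstrating the conditional-override pattern.
def tokenizer_kwargs(add_bos_token: bool | None = None) -> dict:
    return {
        "use_fast": True,
        # Only include the key when a value was given explicitly; an absent
        # key lets the downstream default (the tokenizer's config) win.
        **({"add_bos_token": add_bos_token} if add_bos_token is not None else {}),
    }

print(tokenizer_kwargs())                     # {'use_fast': True}
print(tokenizer_kwargs(add_bos_token=False))  # {'use_fast': True, 'add_bos_token': False}
```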
@@ -265,7 +266,7 @@ class VLLM(TemplateLM):
             self.lora_request = None

     @property
-    def eot_token_id(self):
+    def eot_token_id(self) -> int | None:
         # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
         return self.tokenizer.eos_token_id
@@ -300,7 +301,7 @@ class VLLM(TemplateLM):
         return self._max_gen_toks

     def apply_chat_template(
-        self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
+        self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
     ) -> str:
         """
         Method to apply a chat template to a list of chat history between user and model.
@@ -337,18 +338,26 @@ class VLLM(TemplateLM):
     def tok_encode(
         self,
-        string: Union[str, List[str]],
-        left_truncate_len: int = None,
-        add_special_tokens: bool = False,
+        string: str | list[str],
+        left_truncate_len: int | None = None,
+        add_special_tokens: bool | None = None,
         truncation: bool = False,
-    ) -> Union[List[int], List[List[int]]]:
-        if not add_special_tokens:
-            add_special_tokens = False or self.add_bos_token
-        encoding: Union[List[List[int]], List[int]] = self.tokenizer(
+    ) -> list[int] | list[list[int]]:
+        add_special_kwargs = (
+            {"add_special_tokens": add_special_tokens or self.add_bos_token}
+            if (add_special_tokens is not None or self.add_bos_token is not None)
+            else {}
+        )
+        # handle chat template
+        if self.tokenizer.bos_token and (
+            string[0] if isinstance(string, list) else string
+        ).startswith(self.tokenizer.bos_token):
+            add_special_kwargs = {"add_special_tokens": False}
+        encoding: list[list[int]] | list[int] = self.tokenizer(
             string,
-            add_special_tokens=add_special_tokens,
             truncation=truncation,
             return_attention_mask=False,
+            **add_special_kwargs,
         ).input_ids

         # left-truncate the encoded context to be at most `left_truncate_len` tokens long
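A note on the guard commented "handle chat template" above: chat templates usually emit the BOS token as literal text, so re-adding special tokens during encoding would produce a duplicated BOS. A small sketch of the failure mode the guard prevents; the TinyLlama tokenizer is an arbitrary, ungated example of a BOS-adding tokenizer (my choice, not from the commit):

```python
# Sketch only: TinyLlama is an illustrative BOS-adding tokenizer.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

prompt = tok.bos_token + "Hello"  # text as a chat template might emit it

naive = tok(prompt, add_special_tokens=True).input_ids
guarded = tok(
    prompt,
    # The guard: skip special tokens when BOS is already present in the text.
    add_special_tokens=not prompt.startswith(tok.bos_token),
).input_ids

print(naive[:2])    # two copies of the BOS id
print(guarded[:2])  # a single BOS id followed by the first text token
```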
@@ -362,15 +371,15 @@ class VLLM(TemplateLM):
     def _model_generate(
         self,
-        requests: List[List[int]] = None,
+        requests: list[list[int]],
         generate: bool = False,
-        sampling_params: Union[List["SamplingParams"], "SamplingParams", None] = None,
+        sampling_params: list[SamplingParams] | SamplingParams | None = None,
     ):
         if not generate or sampling_params is None:
             sampling_params = SamplingParams(
                 temperature=0, prompt_logprobs=1, max_tokens=1, detokenize=False
             )
-        if not isinstance(sampling_params, List):
+        if not isinstance(sampling_params, list):
             sampling_params = [sampling_params] * len(requests)
         if self.data_parallel_size > 1 and not self.V1:
             # vLLM hangs if resources are set in ray.remote
@@ -379,9 +388,9 @@ class VLLM(TemplateLM):
             @ray.remote
             def run_inference_one_model(
                 model_args: dict,
-                sampling_params: List["SamplingParams"],
-                requests: List[List[int]],
-                lora_request: "LoRARequest",
+                sampling_params: list[SamplingParams],
+                requests: list[list[int]],
+                lora_request: LoRARequest,
             ):
                 llm = LLM(**model_args)
                 return llm.generate(
@@ -487,8 +496,8 @@ class VLLM(TemplateLM):
         return outputs

     def loglikelihood_rolling(
-        self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> List[float]:
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[float]:
         adaptive_batch_size = None
         if self.batch_size == "auto":
             adaptive_batch_size = len(requests)
@@ -503,7 +512,7 @@ class VLLM(TemplateLM):
                 disable=(disable_tqdm or (self.rank != 0)),
             )
         ):
-            rolling_token_windows: List[Tuple[List[int], List[int]]] = list(
+            rolling_token_windows: list[tuple[list[int], list[int]]] = list(
                 map(
                     make_disjoint_window,
                     get_rolling_token_windows(
@@ -556,16 +565,14 @@ class VLLM(TemplateLM):
         return loglikelihoods

     def generate_until(
-        self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> List[str]:
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[str]:
         res = []

         # batch tokenize contexts
         context, all_gen_kwargs = zip(*(req.args for req in requests))
-        context_encoding: List[List[int]] = self.tok_encode(
-            context, add_special_tokens=self.add_bos_token
-        )
-        requests = [
+        context_encoding = self.tok_encode(context)
+        reqs = [
             ((a, b), c) for a, b, c in zip(context, context_encoding, all_gen_kwargs)
         ]
@@ -579,7 +586,7 @@ class VLLM(TemplateLM):
             return -len(_requests[0][1]), _requests[0][0]

         re_ords = Collator(
-            requests,
+            reqs,
             _collate_gen,
             group_by=None,
         )
@@ -588,7 +595,7 @@ class VLLM(TemplateLM):
         )

         pbar = tqdm(
-            total=len(requests),
+            total=len(reqs),
             disable=(disable_tqdm or (self.rank != 0)),
             desc="Running generate_until requests",
         )
@@ -656,9 +663,9 @@ class VLLM(TemplateLM):
     def _loglikelihood_tokens(
         self,
-        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
+        requests: list[tuple[tuple[str, str], list[int], list[int]]],
         disable_tqdm: bool = False,
-    ) -> List[Tuple[float, bool]]:
+    ) -> list[tuple[float, bool]]:
         res = []

         def _collate(x):
@@ -717,7 +724,7 @@ class VLLM(TemplateLM):
         return re_ord.get_original(res)

     @staticmethod
-    def _parse_logprobs(tokens: List, outputs, ctxlen: int) -> Tuple[float, bool]:
+    def _parse_logprobs(tokens: list, outputs, ctxlen: int) -> tuple[float, bool]:
         """Process logprobs and tokens.

         :param tokens: list
...