Commit 79b31dad authored by Baber

Merge branch 'bos' into mrl

parents cbb8f5a4 7e5f909b
include: _longbench_common_yaml
tag:
  - longbench2_tasks
  - longbench2_single_tasks
task: longbench2_academic_single
dataset_name: academic_single

include: _longbench_common_yaml
tag:
  - longbench2_tasks
  - longbench2_history_tasks
task: longbench2_agent_history
dataset_name: agent_history_qa

include: _longbench_common_yaml
tag:
  - longbench2_tasks
  - longbench2_single_tasks
task: longbench2_detective
dataset_name: detective

include: _longbench_common_yaml
tag:
  - longbench2_tasks
  - longbench2_history_tasks
task: longbench2_dialogue_history
dataset_name: dialogue_history_qa

include: _longbench_common_yaml
tag:
  - longbench2_tasks
  - longbench2_single_tasks
task: longbench2_event_order
dataset_name: event_ordering

include: _longbench_common_yaml
tag:
  - longbench2_tasks
  - longbench2_multi_tasks
task: longbench2_fin_multi
dataset_name: financial_multi

include: _longbench_common_yaml
tag:
  - longbench2_tasks
  - longbench2_single_tasks
task: longbench2_fin_single
dataset_name: financial_single

include: _longbench_common_yaml
tag:
  - longbench2_tasks
  - longbench2_multi_tasks
task: longbench2_govt_multi
dataset_name: government_multi

include: _longbench_common_yaml
tag:
  - longbench2_tasks
  - longbench2_single_tasks
task: longbench2_govt_single
dataset_name: government_single

include: _longbench_common_yaml
tag:
  - longbench2_tasks
  - longbench2_structured_tasks
task: longbench2_graph
dataset_name: graph_reasoning

include: _longbench_common_yaml
tag:
  - longbench2_tasks
  - longbench2_multi_tasks
task: longbench2_legal_multi
dataset_name: legal_multi

include: _longbench_common_yaml
tag:
  - longbench2_tasks
  - longbench2_single_tasks
task: longbench2_legal_single
dataset_name: legal_single

include: _longbench_common_yaml
tag:
  - longbench2_tasks
  - longbench2_single_tasks
task: longbench2_lit_single
dataset_name: literary

include: _longbench_common_yaml
tag:
  - longbench2_tasks
task: longbench2_code
dataset_name: code_repo_qa

include: _longbench_common_yaml
tag:
  - longbench2_tasks
  - longbench2_incontext_tasks
task: longbench2_many_shot
dataset_name: manyshot_learning

include: _longbench_common_yaml
tag:
  - longbench2_tasks
  - longbench2_multi_tasks
task: longbench2_news_multi
dataset_name: multinews

include: _longbench_common_yaml
tag:
  - longbench2_tasks
  - longbench2_structured_tasks
task: longbench2_table
dataset_name: table_qa

include: _longbench_common_yaml
tag:
  - longbench2_tasks
  - longbench2_incontext_tasks
task: longbench2_translate
dataset_name: new_language_translation

include: _longbench_common_yaml
tag:
  - longbench2_tasks
  - longbench2_incontext_tasks
task: longbench2_user_guide
dataset_name: user_guide_qa
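These task configs appear to follow the lm-evaluation-harness YAML format (each include/tag/task/dataset_name block is one task file). Assuming that, a task or tag from this set could be evaluated with a sketch like the one below; the model backend and checkpoint are placeholders, not part of this commit.

import lm_eval

# Hypothetical invocation: "hf" and the pretrained path are placeholders.
results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=meta-llama/Llama-3.1-8B-Instruct",
    tasks=["longbench2_detective"],  # or a tag, e.g. "longbench2_single_tasks"
)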
@@ -8,12 +8,14 @@ import json
import logging
import os
import re
import threading
from dataclasses import asdict, is_dataclass
from itertools import islice
from pathlib import Path
from typing import Any, Callable, Generator, List, Optional, Tuple
import numpy as np
import requests
import yaml
from jinja2 import BaseLoader, Environment, StrictUndefined
@@ -623,3 +625,218 @@ def hash_dict_images(data_dict):
if importlib.util.find_spec("PIL")
else data_dict
)
class RemoteTokenizer:
"""
Minimal robust tokenizer that uses vLLM server's tokenizer endpoints.
"""
def __init__(
self,
base_url: str,
timeout: int = 30,
verify_certificate: bool = True,
ca_cert_path: Optional[str] = None,
auth_token: Optional[str] = None,
max_retries: int = 3,
):
self.timeout = timeout
self.max_retries = max_retries
self._lock = threading.RLock()
self._tokenizer_info = None
self._chat_template_obj = None
# Certificate logic
self.cert_config = (
ca_cert_path if verify_certificate and ca_cert_path else verify_certificate
)
# Auth header logic
self.headers = {"Content-Type": "application/json"}
if auth_token:
self.headers["Authorization"] = f"Bearer {auth_token}"
# Normalize base URL - remove API endpoints to get server base
self.base_url = (
base_url.replace("/v1/completions", "")
.replace("/v1/chat/completions", "")
.rstrip("/")
)
# Use a session for connection pooling
self.session = requests.Session()
self.session.headers.update(self.headers)
# Validate server supports tokenizer_info endpoint
self._validate_server()
    def _request_with_retries(self, method, url, **kwargs):
        # Pop the timeout once, before the loop: popping it inside the loop
        # would silently fall back to self.timeout on every retry.
        timeout = kwargs.pop("timeout", self.timeout)
        last_exc = None
        for _ in range(self.max_retries):
            try:
                resp = self.session.request(
                    method,
                    url,
                    timeout=timeout,
                    verify=self.cert_config,
                    **kwargs,
                )
                resp.raise_for_status()
                return resp
            except requests.RequestException as e:
                last_exc = e
        raise RuntimeError(
            f"RemoteTokenizer: {method} {url} failed after {self.max_retries} attempts: {last_exc}"
        )
    def _validate_server(self):
        # _request_with_retries raises on connection errors and non-2xx
        # responses, so simply returning means /tokenizer_info is reachable.
        url = f"{self.base_url}/tokenizer_info"
        self._request_with_retries("GET", url)
@property
def tokenizer_info(self) -> dict:
with self._lock:
if self._tokenizer_info is None:
url = f"{self.base_url}/tokenizer_info"
resp = self._request_with_retries("GET", url)
self._tokenizer_info = resp.json()
return self._tokenizer_info
@property
def eos_token(self) -> Optional[str]:
return self.tokenizer_info.get("eos_token")
@property
def bos_token(self) -> Optional[str]:
return self.tokenizer_info.get("bos_token")
@property
def pad_token(self) -> Optional[str]:
return self.tokenizer_info.get("pad_token")
    @property
    def eos_token_id(self) -> Optional[int]:
        if self.eos_token is None:
            return None
        ids = self.encode(self.eos_token)
        return ids[0] if ids else None

    @property
    def bos_token_id(self) -> Optional[int]:
        if self.bos_token is None:
            return None
        ids = self.encode(self.bos_token)
        return ids[0] if ids else None
@property
def eot_token(self) -> Optional[int]:
return self.eos_token_id
    def encode(self, text: str, add_special_tokens: bool = False) -> List[int]:
        url = f"{self.base_url}/tokenize"
        payload = {"prompt": text, "add_special_tokens": add_special_tokens}
        resp = self._request_with_retries("POST", url, json=payload)
        tokens = resp.json().get("tokens")
        if not isinstance(tokens, list):
            raise RuntimeError("Malformed response from /tokenize endpoint.")
        return tokens
def decode(self, tokens: List[int]) -> str:
url = f"{self.base_url}/detokenize"
payload = {"tokens": tokens}
resp = self._request_with_retries("POST", url, json=payload)
prompt = resp.json().get("prompt")
if not isinstance(prompt, str):
raise RuntimeError("Malformed response from /detokenize endpoint.")
return prompt
def batch_decode(self, tokens_list: List[List[int]]) -> List[str]:
return [self.decode(tokens) for tokens in tokens_list]
def apply_chat_template(
self, chat_history: list, add_generation_prompt: bool = True, **kwargs
) -> str:
        with self._lock:
            if self._chat_template_obj is None:
                template_str = self.tokenizer_info.get("chat_template")
                if not template_str:
                    raise ValueError("No chat template available from server")
                # `env` is the module-level Jinja2 Environment defined earlier
                # in this file (see the jinja2 imports above).
                self._chat_template_obj = env.from_string(template_str)
        return self._chat_template_obj.render(
            messages=chat_history, add_generation_prompt=add_generation_prompt, **kwargs
        )
    def __call__(self, text: str, add_special_tokens: bool = False, **kwargs) -> dict:
        # Mirror the HF tokenizer call convention: return a dict of input_ids.
        tokens = self.encode(text, add_special_tokens=add_special_tokens)
        return {"input_ids": tokens}
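# Usage sketch (illustrative, not part of this diff): assumes a running vLLM
# OpenAI-compatible server with its tokenizer endpoints enabled; the URL and
# the helper's name are placeholders.
def _remote_tokenizer_example():
    tok = RemoteTokenizer("http://localhost:8000/v1/completions")
    ids = tok.encode("Hello, world!")
    text = tok.decode(ids)
    prompt = tok.apply_chat_template(
        [{"role": "user", "content": "Hi"}], add_generation_prompt=True
    )
    return ids, text, prompt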
def check_remote_tokenizer_support(
base_url: str,
timeout: int = 5,
verify_certificate: bool = True,
ca_cert_path: Optional[str] = None,
auth_token: Optional[str] = None,
max_retries: int = 3,
) -> bool:
"""
Check if server supports remote tokenizer endpoints.
Returns True if both /tokenizer_info and /tokenize endpoints are available and functional, False otherwise.
"""
if not base_url:
return False
server_base = (
base_url.replace("/v1/completions", "")
.replace("/v1/chat/completions", "")
.rstrip("/")
)
cert_config = (
ca_cert_path if verify_certificate and ca_cert_path else verify_certificate
)
headers = {"Content-Type": "application/json"}
if auth_token:
headers["Authorization"] = f"Bearer {auth_token}"
session = requests.Session()
session.headers.update(headers)
    def _request_with_retries(method, url, **kwargs):
        # Pop the timeout once so every retry reuses the caller's value.
        request_timeout = kwargs.pop("timeout", timeout)
        for _ in range(max_retries):
            try:
                resp = session.request(
                    method,
                    url,
                    timeout=request_timeout,
                    verify=cert_config,
                    **kwargs,
                )
                resp.raise_for_status()
                return resp
            except requests.RequestException:
                pass
        return None
    # Check /tokenizer_info
    info_url = f"{server_base}/tokenizer_info"
    resp = _request_with_retries("GET", info_url)
    if not resp:
        return False
    try:
        info = resp.json()
    except ValueError:
        # Non-JSON body: treat as unsupported rather than raising.
        return False
    if not isinstance(info, dict) or "eos_token" not in info:
        return False
# Check /tokenize
tokenize_url = f"{server_base}/tokenize"
test_payload = {"prompt": "test", "add_special_tokens": False}
resp = _request_with_retries("POST", tokenize_url, json=test_payload)
if not resp:
return False
tokens = resp.json().get("tokens")
if not isinstance(tokens, list):
return False
return True
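An illustrative gating pattern for these two helpers (base_url, model_name, and the helper itself are placeholders, not part of this diff): probe the server once, then fall back to a local tokenizer when the endpoints are missing.

def _pick_tokenizer(base_url: str, model_name: str):
    # Hypothetical helper: prefer the server's tokenizer when available.
    if check_remote_tokenizer_support(base_url):
        return RemoteTokenizer(base_url)
    from transformers import AutoTokenizer  # local fallback
    return AutoTokenizer.from_pretrained(model_name)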