Commit 3bc01ac1 authored by Lianmin Zheng

[Minor] improve code style

parent 9f009261
......
@@ -149,12 +149,12 @@ async def send_request(
             "inputs": prompt,
             "parameters": params,
         }
-    elif backend == "xinfer":
+    elif backend == "ginfer":
         pass
     else:
         raise ValueError(f"Unknown backend: {backend}")
 
-    if backend != "xinfer":
+    if backend != "ginfer":
         timeout = aiohttp.ClientTimeout(total=3 * 3600)
         async with aiohttp.ClientSession(timeout=timeout) as session:
             while True:
......
@@ -172,7 +172,7 @@ async def send_request(
                 print(output)
     else:
         import grpc
-        from xlm.proto import sampler_pb2, sampler_pb2_grpc
+        from ginfer import sampler_pb2, sampler_pb2_grpc
 
         api_url = api_url.replace("http://", "").replace("/generate", "")
         sampler_channel = grpc.aio.insecure_channel(api_url)
......
@@ -283,7 +283,7 @@ if __name__ == "__main__":
         "--backend",
         type=str,
         default="srt",
-        choices=["vllm", "tgi", "srt", "lightllm", "xinfer"],
+        choices=["vllm", "tgi", "srt", "lightllm", "ginfer"],
     )
     parser.add_argument("--host", type=str, default="localhost")
     parser.add_argument("--port", type=int, default=30000)
......
......
@@ -18,7 +18,7 @@ if __name__ == "__main__":
         args.port = 21000
     elif args.backend == "lightllm":
         args.port = 22000
-    elif args.backend == "xinfer":
+    elif args.backend == "ginfer":
         args.port = 9988
     else:
         raise ValueError(f"Invalid backend: {args.backend}")
......
@@ -60,9 +60,9 @@ if __name__ == "__main__":
                 "max_tokens": max_new_tokens,
             },
         )
-    elif args.backend == "xinfer":
+    elif args.backend == "ginfer":
         import grpc
-        from xlm.proto import sampler_pb2, sampler_pb2_grpc
+        from ginfer import sampler_pb2, sampler_pb2_grpc
 
         sampler_channel = grpc.insecure_channel(url.replace("http://", ""))
         sampler = sampler_pb2_grpc.SamplerStub(sampler_channel)
......
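Both ginfer hunks above use the same gRPC client pattern: strip the HTTP scheme from the endpoint (gRPC dials a bare host:port target), open an insecure channel, and build a SamplerStub from ginfer's generated bindings. A minimal sketch of that setup, assuming a ginfer install and a sampler listening on the default port 9988; the actual RPC method and request message live in ginfer's sampler.proto, which this commit does not show, so the call is left as a commented placeholder:

import grpc
from ginfer import sampler_pb2, sampler_pb2_grpc

url = "http://localhost:9988"
# gRPC targets are bare host:port strings, hence the scheme strip.
channel = grpc.insecure_channel(url.replace("http://", ""))
sampler = sampler_pb2_grpc.SamplerStub(channel)
# response = sampler.<Method>(sampler_pb2.<Request>(...))  # names defined in sampler.proto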
......
@@ -3,7 +3,8 @@
 import json
 import os
 import warnings
-from typing import List, Optional, Tuple, Union
+import functools
+from typing import Optional, Union, AbstractSet, Collection, Literal
 
 from huggingface_hub import snapshot_download
 from transformers import (
......
@@ -177,10 +178,57 @@ def get_processor(
 
 class TiktokenTokenizer:
     def __init__(self, tokenizer_path):
-        import xlm.tokenizers.tiktoken_wrapper as tiktoken_wrapper
-        tokenizer = tiktoken_wrapper.Encoding.from_xtok_json("xtok-json", tokenizer_path)
+        import tiktoken
+
+        PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
+
+        # Read JSON
+        name = "tmp-json"
+        with open(tokenizer_path, "rb") as fin:
+            tok_dict = json.load(fin)
+
+        mergeable_ranks = {
+            bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"]
+        }
+        special_tokens = {
+            bytes(item["bytes"]).decode(): item["token"] for item in tok_dict["special_tokens"]
+        }
+        assert tok_dict["word_split"] == "V1"
+
+        kwargs = {
+            "name": name,
+            "pat_str": tok_dict.get("pat_str", PAT_STR_B),
+            "mergeable_ranks": mergeable_ranks,
+            "special_tokens": special_tokens,
+        }
+        if "default_allowed_special" in tok_dict:
+            default_allowed_special = set(
+                [bytes(bytes_list).decode() for bytes_list in tok_dict["default_allowed_special"]]
+            )
+        else:
+            default_allowed_special = None
+        if "vocab_size" in tok_dict:
+            kwargs["explicit_n_vocab"] = tok_dict["vocab_size"]
+
+        tokenizer = tiktoken.Encoding(**kwargs)
+        tokenizer._default_allowed_special = default_allowed_special or set()
+
+        def encode_patched(
+            self,
+            text: str,
+            *,
+            allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),  # noqa: B006
+            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
+        ) -> list[int]:
+            if isinstance(allowed_special, set):
+                allowed_special |= self._default_allowed_special
+            return tiktoken.Encoding.encode(
+                self, text, allowed_special=allowed_special, disallowed_special=disallowed_special
+            )
+
+        tokenizer.encode = functools.partial(encode_patched, tokenizer)
 
         # Convert to HF interface
         self.tokenizer = tokenizer
-        self.eos_token_id = tokenizer.eos_token
+        self.eos_token_id = tokenizer._special_tokens["<|eos|>"]
         self.vocab_size = tokenizer.n_vocab
......
@@ -190,6 +238,8 @@ class TiktokenTokenizer:
         return self.tokenizer.decode(x)
 
     def batch_decode(self, batch, skip_special_tokens=True, spaces_between_special_tokens=False):
+        if isinstance(batch[0], int):
+            batch = [[x] for x in batch]
         return self.tokenizer.decode_batch(batch)
 
     def convert_ids_to_tokens(self, index):
......
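The rewritten __init__ above builds a tiktoken.Encoding straight from a JSON vocab dump, then monkey-patches encode on that single instance so the tokenizer's default special tokens are always allowed: functools.partial binds the instance as self, and the instance attribute shadows the class method. A self-contained sketch of the same trick, using the stock cl100k_base encoding in place of the JSON-built vocab:

import functools
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
enc._default_allowed_special = {"<|endoftext|>"}

def encode_patched(self, text, *, allowed_special=set(), disallowed_special="all"):
    if isinstance(allowed_special, set):
        # Plain union rather than |=, so the shared default set is not
        # mutated across calls (the risk flagged by noqa: B006 above).
        allowed_special = allowed_special | self._default_allowed_special
    return tiktoken.Encoding.encode(
        self, text, allowed_special=allowed_special, disallowed_special=disallowed_special
    )

enc.encode = functools.partial(encode_patched, enc)

# Stock encode() raises ValueError on special tokens by default; the
# patched instance now accepts its default specials silently.
print(enc.encode("hello <|endoftext|>"))

# decode_batch expects a list of token lists, which is why the new
# batch_decode wraps bare ints as [[x] for x in batch].
print(enc.decode_batch([[15339], [1917]]))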
......
@@ -88,9 +88,9 @@ def call_generate_srt_raw(prompt, temperature, max_tokens, stop=None, url=None):
     return pred
 
 
-def call_generate_xinfer(prompt, temperature, max_tokens, stop=None, url=None):
+def call_generate_ginfer(prompt, temperature, max_tokens, stop=None, url=None):
     import grpc
-    from xlm.proto import sampler_pb2, sampler_pb2_grpc
+    from ginfer import sampler_pb2, sampler_pb2_grpc
 
     sampler_channel = grpc.insecure_channel(url.replace("http://", ""))
     sampler = sampler_pb2_grpc.SamplerStub(sampler_channel)
......
@@ -255,7 +255,7 @@ def add_common_other_args_and_parse(parser):
             "vllm",
             "outlines",
             "lightllm",
-            "xinfer",
+            "ginfer",
             "guidance",
             "lmql",
             "srt-raw",
......
@@ -276,7 +276,7 @@ def add_common_other_args_and_parse(parser):
         "lightllm": 22000,
         "lmql": 23000,
         "srt-raw": 30000,
-        "xinfer": 9988,
+        "ginfer": 9988,
     }
     args.port = default_port.get(args.backend, None)
     return args
......
@@ -312,8 +312,8 @@ def _get_call_generate(args):
         return partial(call_generate_vllm, url=f"{args.host}:{args.port}/generate")
     elif args.backend == "srt-raw":
         return partial(call_generate_srt_raw, url=f"{args.host}:{args.port}/generate")
-    elif args.backend == "xinfer":
-        return partial(call_generate_xinfer, url=f"{args.host}:{args.port}")
+    elif args.backend == "ginfer":
+        return partial(call_generate_ginfer, url=f"{args.host}:{args.port}")
     elif args.backend == "outlines":
         return partial(call_generate_outlines, url=f"{args.host}:{args.port}/generate")
     elif args.backend == "guidance":
......
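_get_call_generate resolves the backend name once and hands back a callable with the endpoint already bound via functools.partial, so the benchmark loop only ever passes sampling arguments. A minimal sketch of that dispatch shape (two illustrative backends, not the full table):

from functools import partial

def call_generate_ginfer(prompt, temperature, max_tokens, stop=None, url=None):
    ...  # gRPC call as in the diff above

def call_generate_srt_raw(prompt, temperature, max_tokens, stop=None, url=None):
    ...  # HTTP call as in the diff above

def get_call_generate(backend, host, port):
    # The URL shape differs per backend, but callers never see it.
    table = {
        "ginfer": partial(call_generate_ginfer, url=f"{host}:{port}"),
        "srt-raw": partial(call_generate_srt_raw, url=f"{host}:{port}/generate"),
    }
    return table[backend]

call_generate = get_call_generate("ginfer", "localhost", 9988)
# call_generate("Say hello.", temperature=0.0, max_tokens=32)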