Commit 3bc01ac1 authored by Lianmin Zheng

[Minor] improve code style

parent 9f009261
@@ -149,12 +149,12 @@ async def send_request(
             "inputs": prompt,
             "parameters": params,
         }
-    elif backend == "xinfer":
+    elif backend == "ginfer":
         pass
     else:
         raise ValueError(f"Unknown backend: {backend}")

-    if backend != "xinfer":
+    if backend != "ginfer":
         timeout = aiohttp.ClientTimeout(total=3 * 3600)
         async with aiohttp.ClientSession(timeout=timeout) as session:
             while True:
@@ -172,7 +172,7 @@ async def send_request(
                 print(output)
     else:
         import grpc
-        from xlm.proto import sampler_pb2, sampler_pb2_grpc
+        from ginfer import sampler_pb2, sampler_pb2_grpc

         api_url = api_url.replace("http://", "").replace("/generate", "")
         sampler_channel = grpc.aio.insecure_channel(api_url)
@@ -283,7 +283,7 @@ if __name__ == "__main__":
         "--backend",
         type=str,
         default="srt",
-        choices=["vllm", "tgi", "srt", "lightllm", "xinfer"],
+        choices=["vllm", "tgi", "srt", "lightllm", "ginfer"],
     )
     parser.add_argument("--host", type=str, default="localhost")
     parser.add_argument("--port", type=int, default=30000)
...
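The `ginfer` backend bypasses the shared `aiohttp` path above and talks to the server over gRPC. A minimal sketch of that call flow, assuming the `ginfer` proto stubs are installed; the RPC name and message fields below (`SampleTextRequest`, `prompt`, `max_len`, `SampleText`, `text`) are illustrative placeholders, since the actual sampler.proto is not part of this diff:

```python
import asyncio

import grpc
from ginfer import sampler_pb2, sampler_pb2_grpc


async def sample_once(api_url: str, prompt: str, max_tokens: int) -> str:
    # gRPC dials a bare host:port target, so strip the HTTP scheme and route.
    target = api_url.replace("http://", "").replace("/generate", "")
    async with grpc.aio.insecure_channel(target) as channel:
        sampler = sampler_pb2_grpc.SamplerStub(channel)
        # Placeholder message and field names; the real ones are defined
        # by ginfer's sampler.proto.
        request = sampler_pb2.SampleTextRequest(prompt=prompt, max_len=max_tokens)
        reply = await sampler.SampleText(request)
        return reply.text


# asyncio.run(sample_once("http://localhost:9988/generate", "Hello", 32))
```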
@@ -18,7 +18,7 @@ if __name__ == "__main__":
         args.port = 21000
     elif args.backend == "lightllm":
         args.port = 22000
-    elif args.backend == "xinfer":
+    elif args.backend == "ginfer":
         args.port = 9988
     else:
         raise ValueError(f"Invalid backend: {args.backend}")
@@ -60,9 +60,9 @@ if __name__ == "__main__":
                 "max_tokens": max_new_tokens,
             },
         )
-    elif args.backend == "xinfer":
+    elif args.backend == "ginfer":
         import grpc
-        from xlm.proto import sampler_pb2, sampler_pb2_grpc
+        from ginfer import sampler_pb2, sampler_pb2_grpc

         sampler_channel = grpc.insecure_channel(url.replace("http://", ""))
         sampler = sampler_pb2_grpc.SamplerStub(sampler_channel)
...
@@ -3,7 +3,8 @@
 import json
 import os
 import warnings
-from typing import List, Optional, Tuple, Union
+import functools
+from typing import Optional, Union, AbstractSet, Collection, Literal

 from huggingface_hub import snapshot_download
 from transformers import (
@@ -177,10 +178,57 @@ def get_processor(

 class TiktokenTokenizer:
     def __init__(self, tokenizer_path):
-        import xlm.tokenizers.tiktoken_wrapper as tiktoken_wrapper
-        tokenizer = tiktoken_wrapper.Encoding.from_xtok_json("xtok-json", tokenizer_path)
+        import tiktoken
+
+        PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
+
+        # Read JSON
+        name = "tmp-json"
+        with open(tokenizer_path, "rb") as fin:
+            tok_dict = json.load(fin)
+
+        mergeable_ranks = {
+            bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"]
+        }
+        special_tokens = {
+            bytes(item["bytes"]).decode(): item["token"] for item in tok_dict["special_tokens"]
+        }
+        assert tok_dict["word_split"] == "V1"
+
+        kwargs = {
+            "name": name,
+            "pat_str": tok_dict.get("pat_str", PAT_STR_B),
+            "mergeable_ranks": mergeable_ranks,
+            "special_tokens": special_tokens,
+        }
+        if "default_allowed_special" in tok_dict:
+            default_allowed_special = set(
+                [bytes(bytes_list).decode() for bytes_list in tok_dict["default_allowed_special"]]
+            )
+        else:
+            default_allowed_special = None
+        if "vocab_size" in tok_dict:
+            kwargs["explicit_n_vocab"] = tok_dict["vocab_size"]
+
+        tokenizer = tiktoken.Encoding(**kwargs)
+        tokenizer._default_allowed_special = default_allowed_special or set()
+
+        def encode_patched(
+            self,
+            text: str,
+            *,
+            allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),  # noqa: B006
+            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
+        ) -> list[int]:
+            if isinstance(allowed_special, set):
+                allowed_special |= self._default_allowed_special
+            return tiktoken.Encoding.encode(
+                self, text, allowed_special=allowed_special, disallowed_special=disallowed_special
+            )
+
+        tokenizer.encode = functools.partial(encode_patched, tokenizer)
+
+        # Convert to HF interface
         self.tokenizer = tokenizer
-        self.eos_token_id = tokenizer.eos_token
+        self.eos_token_id = tokenizer._special_tokens["<|eos|>"]
         self.vocab_size = tokenizer.n_vocab

     def encode(self, x, add_special_tokens=False):
@@ -190,6 +238,8 @@ class TiktokenTokenizer:
         return self.tokenizer.decode(x)

     def batch_decode(self, batch, skip_special_tokens=True, spaces_between_special_tokens=False):
+        if isinstance(batch[0], int):
+            batch = [[x] for x in batch]
         return self.tokenizer.decode_batch(batch)

     def convert_ids_to_tokens(self, index):
...
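The rewritten constructor builds a `tiktoken.Encoding` from the JSON dump and then rebinds `encode` through `functools.partial`, so tokens in `_default_allowed_special` are accepted without every caller passing `allowed_special` explicitly. The same instance-level patch can be tried on a stock encoding; this sketch uses `cl100k_base` and `<|endoftext|>` as stand-ins for the model-specific tokenizer and specials from the JSON:

```python
import functools
from typing import AbstractSet, Collection, Literal, Union

import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
# Stand-in for the default_allowed_special set loaded from the JSON dump.
enc._default_allowed_special = {"<|endoftext|>"}


def encode_patched(
    self,
    text: str,
    *,
    allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),  # noqa: B006
    disallowed_special: Union[Literal["all"], Collection[str]] = "all",
) -> list[int]:
    # Merge the instance-level defaults into the per-call allow set (building a
    # new set rather than mutating the shared default argument), then defer to
    # the original unbound method.
    if isinstance(allowed_special, set):
        allowed_special = allowed_special | self._default_allowed_special
    return tiktoken.Encoding.encode(
        self, text, allowed_special=allowed_special, disallowed_special=disallowed_special
    )


# functools.partial pre-binds the instance, mimicking a bound method.
enc.encode = functools.partial(encode_patched, enc)

print(enc.encode("hi <|endoftext|>"))  # would raise ValueError without the patch
```

This works because with `disallowed_special="all"`, tiktoken disallows exactly the special tokens not present in `allowed_special`, so merging the defaults into the allow set is sufficient to let them through.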
@@ -88,9 +88,9 @@ def call_generate_srt_raw(prompt, temperature, max_tokens, stop=None, url=None):
     return pred


-def call_generate_xinfer(prompt, temperature, max_tokens, stop=None, url=None):
+def call_generate_ginfer(prompt, temperature, max_tokens, stop=None, url=None):
     import grpc
-    from xlm.proto import sampler_pb2, sampler_pb2_grpc
+    from ginfer import sampler_pb2, sampler_pb2_grpc

     sampler_channel = grpc.insecure_channel(url.replace("http://", ""))
     sampler = sampler_pb2_grpc.SamplerStub(sampler_channel)
@@ -255,7 +255,7 @@ def add_common_other_args_and_parse(parser):
             "vllm",
             "outlines",
             "lightllm",
-            "xinfer",
+            "ginfer",
             "guidance",
             "lmql",
             "srt-raw",
@@ -276,7 +276,7 @@ def add_common_other_args_and_parse(parser):
         "lightllm": 22000,
         "lmql": 23000,
         "srt-raw": 30000,
-        "xinfer": 9988,
+        "ginfer": 9988,
     }
     args.port = default_port.get(args.backend, None)
     return args
@@ -312,8 +312,8 @@ def _get_call_generate(args):
         return partial(call_generate_vllm, url=f"{args.host}:{args.port}/generate")
     elif args.backend == "srt-raw":
         return partial(call_generate_srt_raw, url=f"{args.host}:{args.port}/generate")
-    elif args.backend == "xinfer":
-        return partial(call_generate_xinfer, url=f"{args.host}:{args.port}")
+    elif args.backend == "ginfer":
+        return partial(call_generate_ginfer, url=f"{args.host}:{args.port}")
     elif args.backend == "outlines":
         return partial(call_generate_outlines, url=f"{args.host}:{args.port}/generate")
     elif args.backend == "guidance":
...
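For context, `_get_call_generate` hands back a `functools.partial` with the backend URL pre-bound, so benchmark code invokes every backend through the same `(prompt, temperature, max_tokens)` signature. A rough usage sketch, assuming `call_generate_ginfer` (the function renamed in this commit) is importable from this utils module, whose name the diff does not show:

```python
from functools import partial
from types import SimpleNamespace

# Stand-in for parsed CLI args; 9988 matches the ginfer default port above.
args = SimpleNamespace(backend="ginfer", host="localhost", port=9988)

# Mirrors the ginfer branch of _get_call_generate: bind the URL once...
call_generate = partial(call_generate_ginfer, url=f"{args.host}:{args.port}")

# ...then every call site is backend-agnostic.
output = call_generate("Say hi.", temperature=0.0, max_tokens=16)
```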