Commit 7a087fdc authored by PanZezhong

support max-token and max-batch args

parent 732c7f04
@@ -21,5 +21,7 @@ python jiuge.py [--cpu | --nvidia | --cambricon | --ascend | --metax | --moore]
 - Deploy the model inference service
 ```bash
-python launch_server.py [--cpu | --nvidia | --cambricon | --ascend | --metax | --moore] <path/to/model_dir> [n_device]
+launch_server.py [-h] [--dev {cpu,nvidia,cambricon,ascend,metax,moore}]
+                 [--model-path MODEL_PATH] [--ndev NDEV] [--max-batch MAX_BATCH]
+                 [--max-tokens MAX_TOKENS]
 ```
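For example, a hypothetical invocation with the new flags (the model path and all values below are placeholders):

```bash
python launch_server.py --dev nvidia --model-path /path/to/model_dir \
    --ndev 2 --max-batch 8 --max-tokens 4096
```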
@@ -80,7 +80,7 @@ class LlamaWeightsNaming:
 class JiugeMetaFromLlama(JiugeMetaCStruct):
-    def __init__(self, config, dtype=torch.float16):
+    def __init__(self, config, dtype=torch.float16, max_tokens=None):
         if dtype == torch.float16:
             dt_ = DataType.INFINI_DTYPE_F16
         elif dtype == torch.float32:
@@ -99,7 +99,7 @@ class JiugeMetaFromLlama(JiugeMetaCStruct):
             ),
             dh=config["hidden_size"] // config["num_attention_heads"],
             di=config["intermediate_size"],
-            dctx=config["max_position_embeddings"],
+            dctx=config["max_position_embeddings"] if max_tokens is None else max_tokens,
             dvoc=config["vocab_size"],
             epsilon=config["rms_norm_eps"],
             theta=(config["rope_theta"] if "rope_theta" in config else 100000.0),
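The net effect: `dctx` follows the model's `max_position_embeddings` unless `--max-tokens` overrides it. A minimal sketch of that precedence (the config value below is a made-up example):

```python
# Mirrors the changed line: an explicit max_tokens wins over the model default.
def resolve_dctx(config, max_tokens=None):
    return config["max_position_embeddings"] if max_tokens is None else max_tokens

config = {"max_position_embeddings": 32768}  # illustrative value
assert resolve_dctx(config) == 32768         # no --max-tokens: follow the config
assert resolve_dctx(config, 4096) == 4096    # --max-tokens 4096: cap the context
```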
@@ -352,7 +352,7 @@ class JiugeBatchedTask:
 class JiugeForCauslLM:
-    def __init__(self, model_dir_path, device=DeviceType.DEVICE_TYPE_CPU, ndev=1):
+    def __init__(self, model_dir_path, device=DeviceType.DEVICE_TYPE_CPU, ndev=1, max_tokens=None):
         def load_all_safetensors_from_dir(dir_path_: str):
            tensors_ = {}
            dir_path_ = Path(dir_path_)
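A hypothetical caller-side sketch of the new constructor parameter (the import path, model path, and token cap are placeholders, assuming jiuge.py is importable as a module):

```python
from libinfinicore_infer import DeviceType
from jiuge import JiugeForCauslLM  # assumed import path for jiuge.py

# max_tokens=None (the default) keeps the model's max_position_embeddings;
# an explicit value caps the context length the runtime allocates for.
model = JiugeForCauslLM(
    "/path/to/model_dir",              # placeholder
    device=DeviceType.DEVICE_TYPE_CPU,
    ndev=1,
    max_tokens=4096,                   # placeholder cap
)
```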
@@ -381,7 +381,7 @@ class JiugeForCauslLM:
                .cpu()
                .half()
            )
-            self.meta = JiugeMetaFromLlama(config)
+            self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens)
            self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path)
            self.weights = JiugeWeightsImpl(
                self.meta,
@@ -402,7 +402,7 @@ class JiugeForCauslLM:
                    map_location="cpu",
                )
                if LlamaWeightsNaming.match(state_dict):
-                    self.meta = JiugeMetaFromLlama(config)
+                    self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens)
                    self.weights = JiugeWeightsImpl(
                        self.meta,
                        LlamaWeightsNaming(),
@@ -427,7 +427,7 @@ class JiugeForCauslLM:
                    map_location="cpu",
                )
                if LlamaWeightsNaming.match(state_dict):
-                    self.meta = JiugeMetaFromLlama(config)
+                    self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens)
                    self.weights = JiugeWeightsImpl(
                        self.meta,
                        LlamaWeightsNaming(),
@@ -443,7 +443,7 @@ class JiugeForCauslLM:
        elif "qwen2" == config["model_type"]:
            state_dict = load_all_safetensors_from_dir(model_dir_path)
            if LlamaWeightsNaming.match(state_dict):
-                self.meta = JiugeMetaFromLlama(config)
+                self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens)
                self.weights = JiugeWeightsImpl(
                    self.meta,
                    LlamaWeightsNaming(),
......
@@ -3,6 +3,7 @@ from libinfinicore_infer import DeviceType
 from infer_task import InferTask
 from kvcache_pool import KVCachePool
+import argparse
 import queue
 from fastapi import FastAPI, Request
 from fastapi.responses import StreamingResponse, JSONResponse
@@ -10,37 +11,65 @@ import contextlib
 import uvicorn
 import time
 import uuid
-import sys
 import json
 import threading
 import janus
 
-if len(sys.argv) < 3:
-    print(
-        "Usage: python launch_server.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] <path/to/model_dir> [n_device]"
-    )
-    sys.exit(1)
-model_path = sys.argv[2]
-device_type = DeviceType.DEVICE_TYPE_CPU
-if sys.argv[1] == "--cpu":
-    device_type = DeviceType.DEVICE_TYPE_CPU
-elif sys.argv[1] == "--nvidia":
-    device_type = DeviceType.DEVICE_TYPE_NVIDIA
-elif sys.argv[1] == "--cambricon":
-    device_type = DeviceType.DEVICE_TYPE_CAMBRICON
-elif sys.argv[1] == "--ascend":
-    device_type = DeviceType.DEVICE_TYPE_ASCEND
-elif sys.argv[1] == "--metax":
-    device_type = DeviceType.DEVICE_TYPE_METAX
-elif sys.argv[1] == "--moore":
-    device_type = DeviceType.DEVICE_TYPE_MOORE
-else:
-    print(
-        "Usage: python launch_server.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] <path/to/model_dir> [n_device]"
-    )
-    sys.exit(1)
-ndev = int(sys.argv[3]) if len(sys.argv) > 3 else 1
+DEVICE_TYPE_MAP = {
+    "cpu": DeviceType.DEVICE_TYPE_CPU,
+    "nvidia": DeviceType.DEVICE_TYPE_NVIDIA,
+    "cambricon": DeviceType.DEVICE_TYPE_CAMBRICON,
+    "ascend": DeviceType.DEVICE_TYPE_ASCEND,
+    "metax": DeviceType.DEVICE_TYPE_METAX,
+    "moore": DeviceType.DEVICE_TYPE_MOORE,
+}
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Launch the LLM inference server.")
+    parser.add_argument(
+        "--model-path",
+        type=str,
+        help="Path to the model directory",
+    )
+    parser.add_argument(
+        "--dev",
+        type=str,
+        choices=DEVICE_TYPE_MAP.keys(),
+        default="cpu",
+        help="Device type to run the model on (default: cpu)",
+    )
+    parser.add_argument(
+        "--ndev",
+        type=int,
+        default=1,
+        help="Number of devices to use (default: 1)",
+    )
+    parser.add_argument(
+        "--max-batch",
+        type=int,
+        default=3,
+        help="Maximum number of requests that can be batched together (default: 3)",
+    )
+    parser.add_argument(
+        "--max-tokens",
+        type=int,
+        required=False,
+        default=None,
+        help="Max token sequence length that the model will handle (follows model config if not provided)",
+    )
+    return parser.parse_args()
+
+
+args = parse_args()
+device_type = DEVICE_TYPE_MAP[args.dev]
+model_path = args.model_path
+ndev = args.ndev
+max_tokens = args.max_tokens
+MAX_BATCH = args.max_batch
+print(
+    f"Using MAX_BATCH={MAX_BATCH}. Try reducing this value if an out-of-memory error occurs."
+)
 
 
 def chunk_json(id_, content=None, role=None, finish_reason=None):
     delta = {}
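A quick sanity check of the new CLI surface (a sketch that assumes `parse_args`, `DEVICE_TYPE_MAP`, and `DeviceType` from above are in scope; the values are arbitrary):

```python
import sys

# Simulate a command line; every flag is optional and falls back to a default.
sys.argv = [
    "launch_server.py",
    "--dev", "nvidia",
    "--model-path", "/models/jiuge",  # placeholder path
    "--max-batch", "8",
    "--max-tokens", "4096",
]
args = parse_args()
assert DEVICE_TYPE_MAP[args.dev] == DeviceType.DEVICE_TYPE_NVIDIA
assert args.ndev == 1  # default
assert args.max_batch == 8 and args.max_tokens == 4096
```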
@@ -65,11 +94,6 @@ def chunk_json(id_, content=None, role=None, finish_reason=None):
     }
 
-
-MAX_BATCH = 3
-print(
-    f"Using MAX_BATCH={MAX_BATCH}. Try reduce this value if out of memory error occurs."
-)
-
 # A wrapper for InferTask that supports async output queue
 class AsyncInferTask(InferTask):
     def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens):
@@ -85,7 +109,7 @@ class AsyncInferTask:
 @contextlib.asynccontextmanager
 async def lifespan(app: FastAPI):
     # Startup
-    app.state.model = JiugeForCauslLM(model_path, device_type, ndev)
+    app.state.model = JiugeForCauslLM(model_path, device_type, ndev, max_tokens=max_tokens)
     app.state.kv_cache_pool = KVCachePool(app.state.model, MAX_BATCH)
     app.state.request_queue = janus.Queue()
     worker_thread = threading.Thread(target=worker_loop, args=(app,), daemon=True)
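Both knobs bound KV-cache memory at startup: the cache scales with the context cap times the batch size. A back-of-envelope sketch (every shape value below is an illustrative assumption, not read from this repository):

```python
# KV cache ≈ 2 (K and V) · nlayer · nkvh · dh · dtype_size · max_tokens · MAX_BATCH
def kv_cache_bytes(nlayer=32, nkvh=8, dh=128, dtype_size=2,  # fp16
                   max_tokens=4096, max_batch=3):
    return 2 * nlayer * nkvh * dh * dtype_size * max_tokens * max_batch

print(f"{kv_cache_bytes() / 2**30:.2f} GiB")  # 1.50 GiB under these assumptions
```

Halving either `--max-tokens` or `--max-batch` halves that footprint, which is what the startup hint about reducing MAX_BATCH on out-of-memory errors is getting at.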
......