Commit 7a087fdc authored by PanZezhong

support max-token and max-batch args

parent 732c7f04
@@ -21,5 +21,7 @@ python jiuge.py [--cpu | --nvidia | --cambricon | --ascend | --metax | --moore] <path/to/model_dir> [n_device]
 - Deploy the model inference service
 ```bash
-python launch_server.py [--cpu | --nvidia | --cambricon | --ascend | --metax | --moore] <path/to/model_dir> [n_device]
+launch_server.py [-h] [--dev {cpu,nvidia,cambricon,ascend,metax,moore}]
+                 [--model-path MODEL_PATH] [--ndev NDEV] [--max-batch MAX_BATCH]
+                 [--max-tokens MAX_TOKENS]
 ```
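For reference, a full invocation exercising the new arguments might look like this (the device, path, and limit values below are illustrative, not defaults):

```bash
# Illustrative launch: 2 NVIDIA devices, up to 8 batched requests,
# context capped at 4096 tokens instead of the model-config value.
python launch_server.py --dev nvidia --model-path /path/to/model_dir \
    --ndev 2 --max-batch 8 --max-tokens 4096
```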
@@ -80,7 +80,7 @@ class LlamaWeightsNaming:
 class JiugeMetaFromLlama(JiugeMetaCStruct):
-    def __init__(self, config, dtype=torch.float16):
+    def __init__(self, config, dtype=torch.float16, max_tokens=None):
         if dtype == torch.float16:
             dt_ = DataType.INFINI_DTYPE_F16
         elif dtype == torch.float32:
@@ -99,7 +99,7 @@ class JiugeMetaFromLlama(JiugeMetaCStruct):
             ),
             dh=config["hidden_size"] // config["num_attention_heads"],
             di=config["intermediate_size"],
-            dctx=config["max_position_embeddings"],
+            dctx=config["max_position_embeddings"] if max_tokens is None else max_tokens,
             dvoc=config["vocab_size"],
             epsilon=config["rms_norm_eps"],
             theta=(config["rope_theta"] if "rope_theta" in config else 100000.0),
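The override is a one-line conditional buried in a long constructor call, so here is the same logic in isolation. `resolve_dctx` is a hypothetical helper written for this note, not a function in the codebase:

```python
# Hypothetical helper mirroring the dctx override in this hunk:
# an explicit max_tokens replaces the checkpoint's configured
# max_position_embeddings as the context length (dctx).
def resolve_dctx(config: dict, max_tokens=None) -> int:
    return config["max_position_embeddings"] if max_tokens is None else max_tokens

config = {"max_position_embeddings": 32768}
assert resolve_dctx(config) == 32768                  # default: follow the model config
assert resolve_dctx(config, max_tokens=4096) == 4096  # explicit cap wins
```

Since `dctx` bounds the sequence length the runtime must provision for, a smaller `--max-tokens` presumably shrinks per-request memory, which is what the server's out-of-memory hint below is about.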
@@ -352,7 +352,7 @@ class JiugeBatchedTask:
 class JiugeForCauslLM:
-    def __init__(self, model_dir_path, device=DeviceType.DEVICE_TYPE_CPU, ndev=1):
+    def __init__(self, model_dir_path, device=DeviceType.DEVICE_TYPE_CPU, ndev=1, max_tokens=None):
         def load_all_safetensors_from_dir(dir_path_: str):
             tensors_ = {}
             dir_path_ = Path(dir_path_)
@@ -381,7 +381,7 @@ class JiugeForCauslLM:
                 .cpu()
                 .half()
             )
-            self.meta = JiugeMetaFromLlama(config)
+            self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens)
             self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path)
             self.weights = JiugeWeightsImpl(
                 self.meta,
@@ -402,7 +402,7 @@ class JiugeForCauslLM:
                 map_location="cpu",
             )
             if LlamaWeightsNaming.match(state_dict):
-                self.meta = JiugeMetaFromLlama(config)
+                self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens)
                 self.weights = JiugeWeightsImpl(
                     self.meta,
                     LlamaWeightsNaming(),
@@ -427,7 +427,7 @@ class JiugeForCauslLM:
                 map_location="cpu",
             )
             if LlamaWeightsNaming.match(state_dict):
-                self.meta = JiugeMetaFromLlama(config)
+                self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens)
                 self.weights = JiugeWeightsImpl(
                     self.meta,
                     LlamaWeightsNaming(),
@@ -443,7 +443,7 @@ class JiugeForCauslLM:
         elif "qwen2" == config["model_type"]:
             state_dict = load_all_safetensors_from_dir(model_dir_path)
             if LlamaWeightsNaming.match(state_dict):
-                self.meta = JiugeMetaFromLlama(config)
+                self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens)
                 self.weights = JiugeWeightsImpl(
                     self.meta,
                     LlamaWeightsNaming(),
...
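Putting these changes together, constructing the model with a capped context would look roughly like this (a sketch based on the signatures in this diff; the module name and model directory are assumptions):

```python
from libinfinicore_infer import DeviceType
from jiuge import JiugeForCauslLM  # module name assumed from the repo's jiuge.py

# max_tokens=2048 caps dctx at 2048; omitting it falls back to the
# checkpoint's max_position_embeddings, exactly as before this commit.
model = JiugeForCauslLM(
    "/path/to/model_dir",  # illustrative path
    device=DeviceType.DEVICE_TYPE_CPU,
    ndev=1,
    max_tokens=2048,
)
```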
@@ -3,6 +3,7 @@ from libinfinicore_infer import DeviceType
 from infer_task import InferTask
 from kvcache_pool import KVCachePool
+import argparse
 import queue
 from fastapi import FastAPI, Request
 from fastapi.responses import StreamingResponse, JSONResponse
@@ -10,37 +11,65 @@ import contextlib
 import uvicorn
 import time
 import uuid
-import sys
 import json
 import threading
 import janus
 
-if len(sys.argv) < 3:
-    print(
-        "Usage: python launch_server.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] <path/to/model_dir> [n_device]"
-    )
-    sys.exit(1)
-model_path = sys.argv[2]
-device_type = DeviceType.DEVICE_TYPE_CPU
-if sys.argv[1] == "--cpu":
-    device_type = DeviceType.DEVICE_TYPE_CPU
-elif sys.argv[1] == "--nvidia":
-    device_type = DeviceType.DEVICE_TYPE_NVIDIA
-elif sys.argv[1] == "--cambricon":
-    device_type = DeviceType.DEVICE_TYPE_CAMBRICON
-elif sys.argv[1] == "--ascend":
-    device_type = DeviceType.DEVICE_TYPE_ASCEND
-elif sys.argv[1] == "--metax":
-    device_type = DeviceType.DEVICE_TYPE_METAX
-elif sys.argv[1] == "--moore":
-    device_type = DeviceType.DEVICE_TYPE_MOORE
-else:
-    print(
-        "Usage: python launch_server.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] <path/to/model_dir> [n_device]"
-    )
-    sys.exit(1)
-ndev = int(sys.argv[3]) if len(sys.argv) > 3 else 1
+DEVICE_TYPE_MAP = {
+    "cpu": DeviceType.DEVICE_TYPE_CPU,
+    "nvidia": DeviceType.DEVICE_TYPE_NVIDIA,
+    "cambricon": DeviceType.DEVICE_TYPE_CAMBRICON,
+    "ascend": DeviceType.DEVICE_TYPE_ASCEND,
+    "metax": DeviceType.DEVICE_TYPE_METAX,
+    "moore": DeviceType.DEVICE_TYPE_MOORE,
+}
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Launch the LLM inference server.")
+    parser.add_argument(
+        "--model-path",
+        type=str,
+        help="Path to the model directory",
+    )
+    parser.add_argument(
+        "--dev",
+        type=str,
+        choices=DEVICE_TYPE_MAP.keys(),
+        default="cpu",
+        help="Device type to run the model on (default: cpu)",
+    )
+    parser.add_argument(
+        "--ndev",
+        type=int,
+        default=1,
+        help="Number of devices to use (default: 1)",
+    )
+    parser.add_argument(
+        "--max-batch",
+        type=int,
+        default=3,
+        help="Maximum number of requests that can be batched together (default: 3)",
+    )
+    parser.add_argument(
+        "--max-tokens",
+        type=int,
+        required=False,
+        default=None,
+        help="Maximum token sequence length the model will handle (follows the model config if not provided)",
+    )
+    return parser.parse_args()
+
+
+args = parse_args()
+device_type = DEVICE_TYPE_MAP[args.dev]
+model_path = args.model_path
+ndev = args.ndev
+max_tokens = args.max_tokens
+MAX_BATCH = args.max_batch
+print(
+    f"Using MAX_BATCH={MAX_BATCH}. Try reducing this value if an out-of-memory error occurs."
+)
 
 
 def chunk_json(id_, content=None, role=None, finish_reason=None):
     delta = {}
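For anyone migrating launch scripts, the positional style is gone; an old invocation and its equivalent under the new parser look like this (the path is illustrative):

```bash
# Before this commit: device flag, positional model dir, optional device count
python launch_server.py --nvidia /path/to/model_dir 2

# After this commit: everything is a named argument, plus the new optional limits
python launch_server.py --dev nvidia --model-path /path/to/model_dir --ndev 2
```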
@@ -65,11 +94,6 @@ def chunk_json(id_, content=None, role=None, finish_reason=None):
     }
 
-MAX_BATCH = 3
-print(
-    f"Using MAX_BATCH={MAX_BATCH}. Try reduce this value if out of memory error occurs."
-)
-
 # A wrapper for InferTask that supports async output queue
 class AsyncInferTask(InferTask):
     def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens):
@@ -85,7 +109,7 @@ class AsyncInferTask(InferTask):
 @contextlib.asynccontextmanager
 async def lifespan(app: FastAPI):
     # Startup
-    app.state.model = JiugeForCauslLM(model_path, device_type, ndev)
+    app.state.model = JiugeForCauslLM(model_path, device_type, ndev, max_tokens=max_tokens)
     app.state.kv_cache_pool = KVCachePool(app.state.model, MAX_BATCH)
     app.state.request_queue = janus.Queue()
     worker_thread = threading.Thread(target=worker_loop, args=(app,), daemon=True)
...
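None of the HTTP routes change in this commit. Assuming the server exposes an OpenAI-style streaming chat endpoint on uvicorn's default port (both are guesses; the routes are outside this diff), client requests are unchanged and simply have to stay within the new --max-tokens budget:

```bash
# Hypothetical request; the route and port are assumptions, not shown in this diff.
curl -N http://localhost:8000/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"messages": [{"role": "user", "content": "Hello"}], "stream": true}'
```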