"git@developer.sourcefind.cn:tsoc/superbenchmark.git" did not exist on "97f7b1df8688eac14b524c2be51340d4b48809fe"
Unverified Commit c73ff203 authored by PanZezhong1725, committed by GitHub

issue/189: add inference server support to InfiniLM (#190)

parents de3e6b95 97870d3e
@@ -88,6 +88,28 @@ python scripts/test_ppl.py --model-path MODEL_PATH [--ndev NDEV] [--max-batch MA
python examples/jiuge.py --nvidia --model_path=/models/9G7B_MHA/ --backend=cpp --tp=4 --batch_size=16
```
- Inference server testing
- Launch the inference server (an example API request is shown below):
```bash
python python/infinilm/server/inference_server.py [--cpu | --nvidia | --metax | --moore | --iluvatar | --cambricon] --model_path=<path/to/model_dir> --max_tokens=MAX_TOKENS --max_batch_size=MAX_BATCH --tp=NDEV --temperature=TEMP --top_p=TOP_P --top_k=TOP_K --host=HOST --port=PORT
```
- Single-GPU example:
```bash
CUDA_VISIBLE_DEVICES=0 python python/infinilm/server/inference_server.py --nvidia --model_path=/models/9G7B_MHA/ --max_tokens=100 --max_batch_size=32 --tp=1 --temperature=1.0 --top_p=0.8 --top_k=1
```
- Multi-GPU (distributed) example:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python python/infinilm/server/inference_server.py --nvidia --model_path=/models/9G7B_MHA/ --max_tokens=100 --max_batch_size=32 --tp=4 --temperature=1.0 --top_p=0.8 --top_k=1
```
- Test inference server performance:
```bash
python scripts/test_perf.py --verbose
```
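- Example request (a minimal curl sketch against the server's `/chat/completions` route; host, port, and the prompt are illustrative):
```bash
curl -s http://127.0.0.1:8000/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"messages": [{"role": "user", "content": "Hello"}], "max_tokens": 64, "stream": false}'
```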
- Run inference benchmarks (C-Eval/MMLU)
```bash
...
from .models import AutoLlamaModel
from . import distributed
from . import cache
from . import llm
__all__ = ["AutoLlamaModel", "distributed", "cache"] from .llm import (
LLM,
AsyncLLMEngine,
SamplingParams,
RequestOutput,
TokenOutput,
)
__all__ = [
"AutoLlamaModel",
"distributed",
"cache",
"llm",
# LLM classes
"LLM",
"AsyncLLMEngine",
"SamplingParams",
"RequestOutput",
"TokenOutput",
]
"""
InfiniLM Engine - High-performance LLM inference engine with batch generation and streaming support.
"""
from infinilm.llm.sampling_params import SamplingParams
from infinilm.llm.request import (
RequestStatus,
FinishReason,
RequestOutput,
CompletionOutput,
TokenOutput,
InferenceRequest,
)
from infinilm.llm.llm import (
LLM,
LLMEngine,
AsyncLLMEngine,
EngineConfig,
)
from infinilm.llm.scheduler import Scheduler, SchedulerOutput
from infinilm.llm.cache_manager import BlockManager, Block
__all__ = [
# Main classes
"LLM",
"AsyncLLMEngine",
"LLMEngine",
"EngineConfig",
# Parameters
"SamplingParams",
# Request and Output
"InferenceRequest",
"RequestOutput",
"CompletionOutput",
"TokenOutput",
"RequestStatus",
"FinishReason",
# Internal (for advanced use)
"Scheduler",
"SchedulerOutput",
"BlockManager",
"Block",
]
"""
KV Cache Manager - Paged Attention block-based cache allocation and management.
"""
from collections import deque
from typing import List, Dict, Set
import xxhash
import numpy as np
class Block:
"""KV Cache Block with reference counting and hash-based reuse support."""
def __init__(self, block_id: int):
self.block_id = block_id
self.ref_count = 0
self.hash = -1
self.token_ids: List[int] = []
def update(self, hash_value: int, token_ids: List[int]) -> None:
self.hash = hash_value
self.token_ids = token_ids.copy()
def reset(self) -> None:
self.ref_count = 1
self.hash = -1
self.token_ids = []
def free(self) -> None:
self.ref_count = 0
self.hash = -1
self.token_ids = []
def __repr__(self) -> str:
return f"Block(id={self.block_id}, ref={self.ref_count}, hash={self.hash})"
class BlockManager:
"""Manages Paged KV Cache allocation with prefix caching support.
Features:
- Block allocation/deallocation with reference counting
- Hash-based prefix caching for token sequence reuse
- Slot mapping generation for physical-to-logical position mapping
"""
def __init__(self, num_blocks: int, block_size: int):
assert (
num_blocks > 0 and block_size > 0
), "num_blocks and block_size must be positive"
self.num_blocks = num_blocks
self.block_size = block_size
self.blocks: List[Block] = [Block(i) for i in range(num_blocks)]
self.hash_to_block_id: Dict[int, int] = {}
self.free_block_ids: deque = deque(range(num_blocks))
self.used_block_ids: Set[int] = set()
self.req_block_ids: Set[int] = set()
def reset_req_blocks(self) -> None:
"""Move blocks from prefill stage to used blocks and update hash mappings."""
for block_id in self.req_block_ids:
self.used_block_ids.add(block_id)
block = self.blocks[block_id]
prefix_hash = block.hash
self.hash_to_block_id[prefix_hash] = block_id
self.req_block_ids.clear()
@classmethod
def compute_hash(cls, token_ids: List[int], prefix_hash: int = -1) -> int:
"""Compute hash for token sequence with optional prefix chaining."""
h = xxhash.xxh64()
if prefix_hash != -1:
h.update(prefix_hash.to_bytes(8, "little"))
h.update(np.array(token_ids, dtype=np.int32).tobytes())
return h.intdigest()
def _allocate_partial_block(self, block_id: int) -> Block:
"""Allocate an incomplete block and add to used blocks."""
assert block_id in self.free_block_ids, f"Block {block_id} not in free list"
block = self.blocks[block_id]
assert block.ref_count == 0, f"Block {block_id} ref_count not zero"
block.reset()
self.free_block_ids.remove(block_id)
self.used_block_ids.add(block_id)
return block
def _allocate_full_block(self, block_id: int) -> Block:
"""Allocate a complete block and add to request blocks."""
assert block_id in self.free_block_ids, f"Block {block_id} not in free list"
block = self.blocks[block_id]
assert block.ref_count == 0, f"Block {block_id} ref_count not zero"
block.reset()
self.free_block_ids.remove(block_id)
self.req_block_ids.add(block_id)
return block
def _deallocate_block(self, block_id: int):
"""Deallocate a block and return it to free list."""
block = self.blocks[block_id]
assert (
block.ref_count == 0
), f"Block {block_id} ref_count not zero, cannot deallocate"
if block.hash != -1 and self.hash_to_block_id.get(block.hash) == block_id:
del self.hash_to_block_id[block.hash]
block.free()
self.used_block_ids.remove(block_id)
self.free_block_ids.append(block_id)
def can_allocate(self, num_required_blocks: int) -> bool:
return len(self.free_block_ids) >= num_required_blocks
def allocate_blocks(
self, token_ids: List[int], block_table: List[int] = None
) -> tuple[List[int], List[int], int]:
"""Allocate cache blocks for new request with prefix caching support.
Args:
token_ids: Input token sequence
block_table: Existing block_table (for decode phase)
Returns:
Tuple of (block_table, slot_mapping, num_cached_tokens)
"""
if block_table is None:
block_table = []
num_tokens = len(token_ids)
num_blocks = (num_tokens + self.block_size - 1) // self.block_size
slot_mapping = []
num_cached_tokens = 0
prefix_hash = -1
cache_miss = False
for block_idx in range(num_blocks):
start_idx = block_idx * self.block_size
end_idx = min(start_idx + self.block_size, num_tokens)
block_tokens = token_ids[start_idx:end_idx]
# Only full blocks can be hashed for reuse
if len(block_tokens) == self.block_size:
prefix_hash = self.compute_hash(block_tokens, prefix_hash)
# Try to reuse existing block
if not cache_miss:
cached_block_id = self.hash_to_block_id.get(prefix_hash, -1)
if (
cached_block_id != -1
and self.blocks[cached_block_id].token_ids == block_tokens
):
# If reusing this block would cover the entire prompt, force a cache miss
# so the last block is recomputed and prefill still produces a next-token logit
if num_cached_tokens + self.block_size == len(token_ids):
cache_miss = True
else:
# Reuse successful
block = self.blocks[cached_block_id]
block.ref_count += 1
block_table.append(cached_block_id)
num_cached_tokens += self.block_size
continue
else:
cache_miss = True
else:
prefix_hash = -1
# Cannot reuse, allocate new block
if not self.free_block_ids:
raise RuntimeError("No available cache blocks")
new_block_id = self.free_block_ids[0]
if prefix_hash != -1:
block = self._allocate_full_block(new_block_id)
block.update(prefix_hash, block_tokens)
else:
block = self._allocate_partial_block(new_block_id)
block_table.append(new_block_id)
# Generate slot_mapping
for i in range(len(block_tokens)):
slot_mapping.append(new_block_id * self.block_size + i)
return block_table, slot_mapping, num_cached_tokens
def append_slot(
self, block_table: List[int], num_tokens: int, total_token_ids: List[int] = None
) -> tuple[List[int], int]:
"""Append slot for decode phase (generate one new token).
Args:
block_table: Current block_table
num_tokens: Current total token count (including newly generated token)
total_token_ids: All token sequence (for updating block hash)
Returns:
Tuple of (block_table, slot_id)
"""
assert len(block_table) > 0, "block_table cannot be empty"
assert num_tokens > 0, "num_tokens must be greater than 0"
if num_tokens % self.block_size == 1:
# Previous block is full, update its hash for future prefix caching
last_block_id = block_table[-1]
last_block = self.blocks[last_block_id]
# Only update if block's token_ids is empty (avoid duplicate updates)
if len(last_block.token_ids) == 0:
block_start_idx = num_tokens - self.block_size - 1
block_end_idx = num_tokens - 1
block_tokens = total_token_ids[block_start_idx:block_end_idx]
# Compute prefix_hash using previous block's hash if available
if len(block_table) > 1:
prev_block = self.blocks[block_table[-2]]
prefix_hash = prev_block.hash
else:
prefix_hash = -1
current_hash = self.compute_hash(block_tokens, prefix_hash)
last_block.update(current_hash, block_tokens)
self.hash_to_block_id[current_hash] = last_block_id
# Need new block
if not self.free_block_ids:
if not self.try_free_blocks(1):
raise RuntimeError("No available cache blocks")
new_block_id = self.free_block_ids[0]
self._allocate_partial_block(new_block_id)
block_table.append(new_block_id)
# Calculate slot
last_block_id = block_table[-1]
offset = (num_tokens - 1) % self.block_size
slot_id = last_block_id * self.block_size + offset
return block_table, slot_id
def free_blocks(self, block_table: List[int]):
"""Decrease reference count for all blocks. Blocks with ref_count=0 are not
immediately freed to allow reuse."""
for block_id in reversed(block_table):
block = self.blocks[block_id]
block.ref_count -= 1
def try_free_blocks(self, num_required: int) -> bool:
"""Try to free blocks with ref_count=0."""
to_free = [
bid for bid in self.used_block_ids if self.blocks[bid].ref_count == 0
]
for block_id in to_free:
self._deallocate_block(block_id)
if self.can_allocate(num_required):
return True
return self.can_allocate(num_required)
def get_num_free_blocks(self) -> int:
return len(self.free_block_ids)
def __repr__(self):
return (
f"BlockManager(blocks={self.num_blocks}, block_size={self.block_size}, "
f"free={len(self.free_block_ids)}, used={len(self.used_block_ids)})"
)
"""
LLM Engine - Main interface for LLM inference.
This module provides:
- LLM class for batch generation (offline use)
- AsyncLLMEngine class for asynchronous streaming (server use)
"""
import time
import uuid
import logging
import threading
from typing import Any, List, Optional, Union, AsyncIterator
from dataclasses import dataclass
import infinicore
from infinilm.llm.request import (
InferenceRequest,
RequestOutput,
TokenOutput,
FinishReason,
)
from infinilm.llm.sampling_params import SamplingParams
from infinilm.llm.scheduler import Scheduler
from infinilm.distributed import DistConfig
from infinilm.infer_engine import InferEngine
from infinilm.cache.cache import PagedKVCacheConfig
from infinilm.modeling_utils import load_model_state_dict_by_file
from transformers import AutoTokenizer
from tokenizers import decoders as _dec
logger = logging.getLogger(__name__)
@dataclass
class EngineConfig:
"""Configuration for LLM Engine.
Attributes:
model_path: Path to the model directory.
device: Device type string ('cpu', 'cuda', 'mlu', etc.).
dtype: Data type string ('float16', 'bfloat16', 'float32').
tensor_parallel_size: Number of devices for tensor parallelism.
max_batch_size: Maximum batch size for inference.
max_tokens: Default maximum tokens to generate.
num_blocks: Number of KV cache blocks.
block_size: Size of each KV cache block.
temperature: Default sampling temperature.
top_p: Default top-p sampling parameter.
top_k: Default top-k sampling parameter.
"""
model_path: str
device: str = "cuda"
dtype: str = "float16"
tensor_parallel_size: int = 1
max_batch_size: int = 16
max_tokens: int = 4096
num_blocks: int = 8 * 1024
block_size: int = 16
temperature: float = 1.0
top_p: float = 0.8
top_k: int = 1
class LLMEngine:
"""Low-level LLM engine that handles inference execution."""
def __init__(self, config: EngineConfig):
self.config = config
# Initialize device and dtype
self._init_device()
# Initialize model engine
self.model_engine = InferEngine(
model_path=config.model_path,
device=self.device,
distributed_config=DistConfig(config.tensor_parallel_size),
)
# Load model weights
load_model_state_dict_by_file(
self.model_engine, config.model_path, dtype=self.model_engine.config.dtype
)
# Initialize tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(
config.model_path, trust_remote_code=True
)
self._fix_tokenizer_decoder()
# Initialize KV cache
cache_config = PagedKVCacheConfig(
num_blocks=config.num_blocks, block_size=config.block_size
)
self.model_engine.reset_cache(cache_config)
# Initialize scheduler
self.scheduler = Scheduler(
max_batch_size=config.max_batch_size,
num_blocks=config.num_blocks,
block_size=config.block_size,
)
# Get EOS token IDs from model config
self.eos_token_ids = self.model_engine.config.eos_token_id or []
if isinstance(self.eos_token_ids, int):
self.eos_token_ids = [self.eos_token_ids]
logger.info(
f"LLMEngine initialized with model at {config.model_path} "
f"on device {config.device}"
)
def _init_device(self):
"""Initialize infinicore device and dtype."""
supported_devices = ["cpu", "cuda", "mlu", "moore"]
device_str = self.config.device
if device_str not in supported_devices:
raise ValueError(
f"Unsupported device: '{device_str}'. "
f"Supported devices: {supported_devices}"
)
self.device = infinicore.device(device_str, 0)
dtype_map = {
"float32": infinicore.float32,
"float16": infinicore.float16,
"bfloat16": infinicore.bfloat16,
}
if self.config.dtype not in dtype_map:
raise ValueError(
f"Unsupported dtype: '{self.config.dtype}'. "
f"Supported dtypes: {list(dtype_map.keys())}"
)
self.dtype = dtype_map[self.config.dtype]
def _fix_tokenizer_decoder(self):
"""Fix tokenizer decoder for llama models."""
if "llama" in self.model_engine.config.model_type.lower():
backend = getattr(self.tokenizer, "backend_tokenizer", None)
target = getattr(backend, "_tokenizer", backend)
norm = getattr(target, "normalizer", None)
dec = getattr(target, "decoder", None)
sn = repr(norm)[:800] if norm is not None else ""
sd = repr(dec)[:800] if dec is not None else ""
has_prepend = "Prepend" in sn
has_strip = "Strip" in sd
if has_prepend and has_strip:
target.decoder = _dec.Sequence(
[
_dec.Replace("▁", " "),
_dec.ByteFallback(),
_dec.Fuse(),
]
)
def add_request(self, request: InferenceRequest):
"""Add a request to the scheduler."""
self.scheduler.add_request(request)
def step(self) -> List[InferenceRequest]:
"""Run one inference step.
Returns:
List of requests that were processed in this step.
"""
# Schedule requests
scheduler_output = self.scheduler.schedule()
if scheduler_output is None or not scheduler_output.scheduled_requests:
return []
# Build model inputs
model_input_dict = scheduler_output.build_model_inputs(
self.config.temperature, self.config.top_p, self.config.top_k
)
model_input = self._prepare_model_input(model_input_dict)
# Run inference
sampled_tokens = self.model_engine.forward(**model_input)
sampled_tokens_list = sampled_tokens.to_numpy().tolist()
# Update request status
self._update_requests(
scheduler_output.is_prefill,
scheduler_output.scheduled_requests,
sampled_tokens_list,
)
return scheduler_output.scheduled_requests
def _prepare_model_input(self, model_input_dict: dict) -> dict:
"""Convert model input dict to infinicore tensors."""
model_input = {}
for key, value in model_input_dict.items():
if key == "input_ids":
model_input[key] = infinicore.from_list([value], dtype=infinicore.int64)
elif key in [
"position_ids",
"past_kv_lengths",
"total_kv_lengths",
"input_offsets",
"slot_mapping",
]:
model_input[key] = infinicore.from_list(value, dtype=infinicore.int64)
elif key == "block_tables":
model_input[key] = infinicore.from_list(value, dtype=infinicore.int64)
else:
model_input[key] = value
return model_input
def _update_requests(
self,
is_prefill: bool,
requests: List[InferenceRequest],
sampled_tokens: List[int],
):
"""Update request status after inference step."""
if is_prefill:
self.scheduler.cache_manager.reset_req_blocks()
for req, token_id in zip(requests, sampled_tokens):
req.generated_token_ids.append(token_id)
if req.is_prefill:
req.is_prefill = False
token_text = self.tokenizer.decode(token_id)
req.generated_text += token_text
if self._check_request_finished(req, token_id):
req.mark_finished(req.finish_reason)
# Put output in queue if it exists (for async streaming)
if req._output_queue is not None:
output = TokenOutput(
request_id=req.request_id,
token_id=token_id,
token_text=token_text,
finished=req.is_finished(),
finish_reason=req.finish_reason,
generated_text=req.generated_text,
)
req.output_queue.sync_q.put(output)
self.scheduler.complete_requests(requests)
def _check_request_finished(self, req: InferenceRequest, token_id: int) -> bool:
"""Check if request generation is finished."""
max_tokens = req.sampling_params.max_tokens
if max_tokens and req.get_num_generated_tokens() >= max_tokens:
req.finish_reason = FinishReason.LENGTH
return True
# Check EOS token
eos_ids = req.eos_token_ids or self.eos_token_ids
if eos_ids and token_id in eos_ids:
req.finish_reason = FinishReason.EOS_TOKEN
return True
# Check stop strings
stop_strings = req.sampling_params.stop or []
for stop_str in stop_strings:
if req.generated_text.endswith(stop_str):
req.finish_reason = FinishReason.STOP_STRING
return True
return False
def tokenize(self, text: str) -> List[int]:
"""Tokenize text to token IDs."""
return self.tokenizer.encode(text)
def detokenize(self, token_ids: List[int]) -> str:
"""Detokenize token IDs to text."""
return self.tokenizer.decode(token_ids)
def apply_chat_template(
self,
messages: List[dict],
add_generation_prompt: bool = True,
) -> str:
"""Apply chat template to messages."""
return self.tokenizer.apply_chat_template(
conversation=messages,
add_generation_prompt=add_generation_prompt,
tokenize=False,
)
class LLM:
"""High-level LLM interface for batch generation."""
def __init__(
self,
model_path: str,
device: str = "cuda",
dtype: str = "float16",
tensor_parallel_size: int = 1,
max_batch_size: int = 16,
max_tokens: int = 4096,
num_blocks: int = 8 * 1024,
block_size: int = 16,
temperature: float = 1.0,
top_p: float = 0.8,
top_k: int = 1,
):
"""Initialize LLM.
Args:
model_path: Path to the model directory.
device: Device type ('cpu', 'cuda', 'mlu', 'moore').
dtype: Data type ('float16', 'bfloat16', 'float32').
tensor_parallel_size: Number of devices for tensor parallelism.
max_batch_size: Maximum batch size for inference.
max_tokens: Default maximum tokens to generate.
num_blocks: Number of KV cache blocks.
block_size: Size of each KV cache block.
temperature: Default sampling temperature.
top_p: Default top-p sampling parameter.
top_k: Default top-k sampling parameter.
"""
config = EngineConfig(
model_path=model_path,
device=device,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
max_batch_size=max_batch_size,
max_tokens=max_tokens,
num_blocks=num_blocks,
block_size=block_size,
temperature=temperature,
top_p=top_p,
top_k=top_k,
)
self.engine = LLMEngine(config)
self.config = config
def generate(
self,
prompts: Union[str, List[str]],
sampling_params: Optional[SamplingParams] = None,
use_tqdm: bool = True,
) -> List[RequestOutput]:
"""Generate completions for the given prompts.
Args:
prompts: A single prompt string or list of prompt strings.
sampling_params: Sampling parameters for generation.
use_tqdm: Whether to show progress bar.
Returns:
List of RequestOutput objects containing generated text.
"""
if isinstance(prompts, str):
prompts = [prompts]
if sampling_params is None:
sampling_params = SamplingParams(max_tokens=self.config.max_tokens)
elif sampling_params.max_tokens is None:
sampling_params = sampling_params.clone()
sampling_params.max_tokens = self.config.max_tokens
requests = []
for prompt in prompts:
request_id = f"cmpl-{uuid.uuid4().hex}"
token_ids = self.engine.tokenize(prompt)
req = InferenceRequest(
request_id=request_id,
prompt=prompt,
prompt_token_ids=token_ids,
sampling_params=sampling_params,
eos_token_ids=self.engine.eos_token_ids,
)
requests.append(req)
self.engine.add_request(req)
# Run inference until all requests are finished
if use_tqdm:
try:
from tqdm import tqdm
pbar = tqdm(total=len(requests), desc="Generating")
except ImportError:
pbar = None
use_tqdm = False
else:
pbar = None
finished_count = 0
while finished_count < len(requests):
self.engine.step()
new_finished = sum(1 for req in requests if req.is_finished())
if use_tqdm and pbar and new_finished > finished_count:
pbar.update(new_finished - finished_count)
finished_count = new_finished
if pbar:
pbar.close()
outputs = [req.to_request_output() for req in requests]
return outputs
def chat(
self,
messages: Union[List[dict], List[List[dict]]],
sampling_params: Optional[SamplingParams] = None,
use_tqdm: bool = True,
) -> List[RequestOutput]:
"""Generate chat completions for the given messages.
Args:
messages: A single conversation (list of message dicts) or
a list of conversations.
sampling_params: Sampling parameters for generation.
use_tqdm: Whether to show progress bar.
Returns:
List of RequestOutput objects containing generated responses.
"""
if messages and isinstance(messages[0], dict):
messages = [messages]
prompts = []
for conversation in messages:
prompt = self.engine.apply_chat_template(
conversation, add_generation_prompt=True
)
prompts.append(prompt)
return self.generate(prompts, sampling_params, use_tqdm)
class AsyncLLMEngine:
"""Asynchronous LLM engine for server use with streaming support."""
def __init__(
self,
model_path: str,
device: str = "cuda",
dtype: str = "float16",
tensor_parallel_size: int = 1,
max_batch_size: int = 16,
max_tokens: int = 512,
num_blocks: int = 8 * 1024,
block_size: int = 16,
temperature: float = 1.0,
top_p: float = 0.8,
top_k: int = 1,
):
"""Initialize AsyncLLMEngine.
Args:
model_path: Path to the model directory.
device: Device type ('cpu', 'cuda', 'mlu', 'moore').
dtype: Data type ('float16', 'bfloat16', 'float32').
tensor_parallel_size: Number of devices for tensor parallelism.
max_batch_size: Maximum batch size for inference.
max_tokens: Default maximum tokens to generate.
num_blocks: Number of KV cache blocks.
block_size: Size of each KV cache block.
temperature: Default sampling temperature.
top_p: Default top-p sampling parameter.
top_k: Default top-k sampling parameter.
"""
config = EngineConfig(
model_path=model_path,
device=device,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
max_batch_size=max_batch_size,
max_tokens=max_tokens,
num_blocks=num_blocks,
block_size=block_size,
temperature=temperature,
top_p=top_p,
top_k=top_k,
)
self.engine = LLMEngine(config)
self.config = config
self._running = False
self._step_thread: Optional[threading.Thread] = None
def start(self):
"""Start the background inference loop."""
if self._running:
logger.warning("AsyncLLMEngine is already running")
return
self._running = True
self._step_thread = threading.Thread(
target=self._step_loop, daemon=True, name="AsyncLLMEngineStepThread"
)
self._step_thread.start()
logger.info("AsyncLLMEngine started")
def stop(self):
"""Stop the background inference loop."""
if not self._running:
logger.warning("AsyncLLMEngine is not running")
return
self._running = False
if self._step_thread:
self._step_thread.join(timeout=5)
logger.info("AsyncLLMEngine stopped")
def _step_loop(self):
"""Background loop that runs inference steps."""
while self._running:
try:
requests = self.engine.step()
if not requests:
time.sleep(0.01)
except Exception as e:
logger.error(f"Error in step loop: {e}", exc_info=True)
self._running = False
break
def add_request(
self,
prompt: Optional[str] = None,
prompt_token_ids: Optional[List[int]] = None,
sampling_params: Optional[SamplingParams] = None,
request_id: Optional[str] = None,
# For server use
request_data: Optional[dict] = None,
http_request: Optional[Any] = None,
) -> InferenceRequest:
"""Add a request to the engine.
Args:
prompt: Text prompt for generation.
prompt_token_ids: Pre-tokenized prompt.
sampling_params: Sampling parameters.
request_id: Optional request ID.
request_data: Optional request data dict (for server use).
http_request: Optional HTTP request object (for server use).
Returns:
The created InferenceRequest object.
"""
if request_id is None:
request_id = f"cmpl-{uuid.uuid4().hex}"
if prompt_token_ids is None and prompt is not None:
prompt_token_ids = self.engine.tokenize(prompt)
if sampling_params is None:
sampling_params = SamplingParams(max_tokens=self.config.max_tokens)
elif sampling_params.max_tokens is None:
sampling_params = sampling_params.clone()
sampling_params.max_tokens = self.config.max_tokens
request = InferenceRequest(
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
eos_token_ids=self.engine.eos_token_ids,
request_data=request_data,
http_request=http_request,
)
# Initialize output queue for streaming
_ = request.output_queue
self.engine.add_request(request)
return request
def add_chat_request(
self,
messages: List[dict],
sampling_params: Optional[SamplingParams] = None,
request_id: Optional[str] = None,
request_data: Optional[dict] = None,
http_request: Optional[Any] = None,
) -> InferenceRequest:
"""Add a chat request to the engine.
Args:
messages: List of message dicts (chat conversation).
sampling_params: Sampling parameters.
request_id: Optional request ID.
request_data: Optional request data dict.
http_request: Optional HTTP request object.
Returns:
The created InferenceRequest object.
"""
prompt = self.engine.apply_chat_template(messages, add_generation_prompt=True)
return self.add_request(
prompt=prompt,
sampling_params=sampling_params,
request_id=request_id,
request_data=request_data,
http_request=http_request,
)
async def stream_request(
self,
request: InferenceRequest,
timeout: float = 100.0,
) -> AsyncIterator[TokenOutput]:
"""Stream tokens from a request.
Args:
request: The inference request to stream from.
timeout: Timeout for waiting on each token.
Yields:
TokenOutput objects for each generated token.
"""
import asyncio
while True:
if request.is_finished() and request.output_queue.async_q.empty():
break
try:
token_output = await asyncio.wait_for(
request.output_queue.async_q.get(), timeout=timeout
)
request.output_queue.async_q.task_done()
yield token_output
if token_output.finished:
break
except asyncio.TimeoutError:
if request.is_finished():
break
continue
except asyncio.CancelledError:
request.mark_canceled()
break
except Exception as e:
logger.error(f"Error streaming request {request.request_id}: {e}")
await asyncio.sleep(0.01)
"""
Request and Output - Data structures for inference requests and outputs.
"""
from enum import Enum
from dataclasses import dataclass, field
from typing import List, Optional, Any
import time
import janus
from infinilm.llm.sampling_params import SamplingParams
class RequestStatus(Enum):
"""Status of an inference request."""
WAITING = "waiting"
RUNNING = "running"
FINISHED = "finished"
CANCELED = "canceled"
FAILED = "failed"
TIMEOUT = "timeout"
class FinishReason(Enum):
"""Reason for finishing generation."""
STOP = "stop"
LENGTH = "length"
EOS_TOKEN = "eos_token"
STOP_STRING = "stop_string"
TIMEOUT = "timeout"
CANCELED = "canceled"
ERROR = "error"
@dataclass
class RequestOutput:
"""Output from a single generation request.
Attributes:
request_id: Unique identifier for the request.
prompt: Original prompt text.
prompt_token_ids: Token IDs of the prompt.
outputs: List of generated outputs (for beam search, multiple outputs possible).
finished: Whether generation is complete.
finish_reason: Reason for finishing.
"""
request_id: str
prompt: Optional[str] = None
prompt_token_ids: Optional[List[int]] = None
outputs: List["CompletionOutput"] = field(default_factory=list)
finished: bool = False
finish_reason: Optional[FinishReason] = None
@dataclass
class CompletionOutput:
"""Single completion output.
Attributes:
index: Index of this output (for beam search).
text: Generated text.
token_ids: Generated token IDs.
finish_reason: Reason for finishing.
"""
index: int = 0
text: str = ""
token_ids: List[int] = field(default_factory=list)
finish_reason: Optional[FinishReason] = None
@dataclass
class TokenOutput:
"""Output for a single generated token.
Attributes:
request_id: Unique identifier for the request.
token_id: Generated token ID.
token_text: Decoded text of the token.
finished: Whether generation is complete.
finish_reason: Reason for finishing.
generated_text: Full generated text so far.
"""
request_id: str
token_id: int
token_text: str
finished: bool = False
finish_reason: Optional[FinishReason] = None
generated_text: str = ""
class InferenceRequest:
"""Internal inference request object for managing generation state and resources."""
def __init__(
self,
request_id: str,
prompt: Optional[str] = None,
prompt_token_ids: Optional[List[int]] = None,
sampling_params: Optional[SamplingParams] = None,
eos_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
# For server use
request_data: Optional[dict] = None,
http_request: Optional[Any] = None,
):
# Request metadata
self.request_id: str = request_id
self.prompt: Optional[str] = prompt
self.prompt_token_ids: List[int] = prompt_token_ids or []
self.prompt_length: int = len(self.prompt_token_ids)
self.arrival_time: float = arrival_time or time.time()
self.finished_time: Optional[float] = None
# Sampling parameters
self.sampling_params: SamplingParams = sampling_params or SamplingParams()
# EOS token IDs (from model config)
self.eos_token_ids: List[int] = eos_token_ids or []
# Generation state
self.generated_token_ids: List[int] = []
self.generated_text: str = ""
self.is_prefill: bool = True
self.status: RequestStatus = RequestStatus.WAITING
self.finish_reason: Optional[FinishReason] = None
self.priority: int = 0
# KV cache management
self.cache_id: Optional[int] = None
self.block_table: List[int] = []
self.slot_mapping: List[int] = []
self.num_cached_tokens: int = 0
self.num_blocks: int = 0
# For server use
self.request_data: Optional[dict] = request_data
self.http_request: Optional[Any] = http_request
# Output management (for async streaming)
self._output_queue: Optional[janus.Queue] = None
@property
def output_queue(self) -> janus.Queue:
"""Lazy initialization of output queue."""
if self._output_queue is None:
self._output_queue = janus.Queue()
return self._output_queue
def get_prompt_length(self) -> int:
return self.prompt_length
def get_input_tokens(self) -> List[int]:
return self.prompt_token_ids
def get_num_generated_tokens(self) -> int:
return len(self.generated_token_ids)
def get_total_length(self) -> int:
return self.prompt_length + len(self.generated_token_ids)
def get_all_token_ids(self) -> List[int]:
return self.prompt_token_ids + self.generated_token_ids
def get_num_blocks_required(self, block_size: int) -> int:
total_tokens = self.get_total_length()
return (total_tokens + block_size - 1) // block_size
def get_max_tokens(self) -> Optional[int]:
return self.sampling_params.max_tokens
def is_finished(self) -> bool:
return self.status in [
RequestStatus.FINISHED,
RequestStatus.CANCELED,
RequestStatus.FAILED,
RequestStatus.TIMEOUT,
]
def mark_finished(self, reason: FinishReason):
"""Mark the request as finished with the given reason."""
self.status = RequestStatus.FINISHED
self.finish_reason = reason
self.finished_time = time.time()
def mark_failed(self, reason: FinishReason = FinishReason.ERROR):
"""Mark the request as failed."""
self.status = RequestStatus.FAILED
self.finish_reason = reason
self.finished_time = time.time()
def mark_canceled(self):
"""Mark the request as canceled."""
self.status = RequestStatus.CANCELED
self.finish_reason = FinishReason.CANCELED
self.finished_time = time.time()
def mark_timeout(self):
"""Mark the request as timed out."""
self.status = RequestStatus.TIMEOUT
self.finish_reason = FinishReason.TIMEOUT
self.finished_time = time.time()
async def close(self):
"""Close the output queue and clean up resources."""
if self._output_queue is not None:
await self._output_queue.async_q.join()
self._output_queue.close()
await self._output_queue.wait_closed()
def to_request_output(self) -> RequestOutput:
"""Convert to RequestOutput for external use."""
return RequestOutput(
request_id=self.request_id,
prompt=self.prompt,
prompt_token_ids=self.prompt_token_ids,
outputs=[
CompletionOutput(
index=0,
text=self.generated_text,
token_ids=self.generated_token_ids.copy(),
finish_reason=self.finish_reason,
)
],
finished=self.is_finished(),
finish_reason=self.finish_reason,
)
"""
Sampling Parameters - Configuration for text generation sampling.
"""
from dataclasses import dataclass
from typing import List, Optional
@dataclass
class SamplingParams:
"""Sampling parameters for text generation."""
temperature: float = 1.0
top_p: float = 0.8
top_k: int = 1
max_tokens: Optional[int] = None
stop: Optional[List[str]] = None
stop_token_ids: Optional[List[int]] = None
def __post_init__(self):
if self.stop is None:
self.stop = []
if self.stop_token_ids is None:
self.stop_token_ids = []
def clone(self) -> "SamplingParams":
"""Create a copy of this SamplingParams instance."""
return SamplingParams(
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
max_tokens=self.max_tokens,
stop=self.stop.copy() if self.stop else None,
stop_token_ids=self.stop_token_ids.copy() if self.stop_token_ids else None,
)
"""
Scheduler - Request scheduling and batch management with Paged Attention KV Cache.
"""
import queue
import janus
import logging
from typing import List, Optional
from infinilm.llm.request import RequestStatus, InferenceRequest
from infinilm.llm.cache_manager import BlockManager
logger = logging.getLogger(__name__)
class SchedulerOutput:
"""Scheduler output containing scheduled requests and execution phase info."""
def __init__(
self,
scheduled_requests: List[InferenceRequest],
is_prefill: bool = False,
):
self.scheduled_requests = scheduled_requests
self.num_requests = len(scheduled_requests)
self.is_prefill = is_prefill
def build_model_inputs(
self, temperature: float = 1.0, top_p: float = 0.8, top_k: int = 1
):
"""Construct model inputs for prefill or decode phase.
Prefill phase:
- input_ids: Flattened token list (excluding cached tokens)
- position_ids: Position IDs for new tokens in complete sequence
- past_kv_lengths: Number of cached tokens per request
- total_kv_lengths: Total tokens (cached + new) per request
- input_offsets: Start position of each request in flattened array
- block_tables: Padded block_table for each request
- slot_mapping: Token to slot mappings
Decode phase:
- input_ids: Only last generated token per request
- position_ids: Position of last token in complete sequence
- past_kv_lengths: Number of cached tokens per request
- total_kv_lengths: Total sequence length per request
- input_offsets: Offsets for each request
- block_tables: Padded block_table for each request
- slot_mapping: Single slot per request
"""
if not self.scheduled_requests:
raise RuntimeError(
"build_model_inputs called with empty scheduled_requests"
)
tokens = []
seq_lens = []
seq_offsets = [0]
block_tables = []
slot_mapping = []
cached_lens = []
position_ids = []
max_block_table_len = max(
len(req.block_table) for req in self.scheduled_requests
)
current_offset = 0
for req in self.scheduled_requests:
num_cached = req.num_cached_tokens
if self.is_prefill:
# Prefill phase
req_tokens = req.get_input_tokens()
tokens_to_compute = req_tokens[num_cached:]
tokens.extend(tokens_to_compute)
seq_len = len(tokens_to_compute)
seq_lens.append(len(req_tokens))
current_offset += seq_len
seq_offsets.append(current_offset)
slot_mapping.extend(req.slot_mapping)
cached_lens.append(num_cached)
position_ids.extend(range(num_cached, num_cached + seq_len))
else:
# Decode phase
last_token = req.generated_token_ids[-1]
tokens.append(last_token)
seq_lens.append(req.get_total_length())
current_offset += 1
seq_offsets.append(current_offset)
slot_mapping.extend(req.slot_mapping)
cached_lens.append(num_cached)
position_ids.append(req.get_total_length() - 1)
# Pad block_table to same length
padded_block_table = req.block_table + [-1] * (
max_block_table_len - len(req.block_table)
)
block_tables.append(padded_block_table)
return {
"input_ids": tokens,
"position_ids": position_ids,
"past_kv_lengths": cached_lens,
"total_kv_lengths": seq_lens,
"input_offsets": seq_offsets,
"block_tables": block_tables,
"slot_mapping": slot_mapping,
"temperature": temperature,
"top_k": top_k,
"top_p": top_p,
}
class Scheduler:
"""Request scheduler with integrated BlockManager for KV cache management.
Scheduling logic:
1. Running queue: Check for new blocks needed, update slot_mapping
2. Waiting queue: Try block reuse (prefix caching), allocate new blocks
3. Reference counting: Free blocks when requests complete
"""
def __init__(
self,
max_batch_size: int = 16,
num_blocks: int = 8 * 1024,
block_size: int = 16,
):
self.waiting_queue = janus.Queue()
self.running_queue = janus.Queue()
self.max_batch_size = max_batch_size
self.cache_manager = BlockManager(num_blocks=num_blocks, block_size=block_size)
self.block_size = block_size
def add_request(self, request: InferenceRequest):
if request is not None:
request.status = RequestStatus.WAITING
self.waiting_queue.sync_q.put(request)
def schedule(self) -> Optional[SchedulerOutput]:
"""Schedule and return batch of requests to execute."""
scheduled_requests = []
is_prefill = False
# Process Waiting queue (prefill phase)
while len(scheduled_requests) < self.max_batch_size:
try:
req = self.waiting_queue.sync_q.get_nowait()
except queue.Empty:
break
req_tokens = req.get_input_tokens()
num_required_blocks = req.get_num_blocks_required(self.block_size)
if not self.cache_manager.can_allocate(num_required_blocks):
if not self.cache_manager.try_free_blocks(num_required_blocks):
raise RuntimeError("No available cache blocks")
# Allocate blocks with automatic prefix caching support
req.block_table, req.slot_mapping, req.num_cached_tokens = (
self.cache_manager.allocate_blocks(req_tokens, req.block_table)
)
req.num_blocks = len(req.block_table)
req.status = RequestStatus.RUNNING
scheduled_requests.append(req)
# Return prefill batch if any waiting requests were scheduled
if scheduled_requests:
is_prefill = True
return SchedulerOutput(
scheduled_requests=scheduled_requests,
is_prefill=is_prefill,
)
# Process Running queue (decode phase)
while len(scheduled_requests) < self.max_batch_size:
try:
req = self.running_queue.sync_q.get_nowait()
except queue.Empty:
break
# Decode phase: allocate slot for newly generated token
try:
req.block_table, new_slot = self.cache_manager.append_slot(
req.block_table, req.get_total_length(), req.get_all_token_ids()
)
req.slot_mapping = [new_slot]
req.num_blocks = len(req.block_table)
req.num_cached_tokens = req.get_total_length() - 1
scheduled_requests.append(req)
except RuntimeError as e:
raise RuntimeError("No available cache blocks") from e
# Return decode batch if any running requests were scheduled
if scheduled_requests:
is_prefill = False
return SchedulerOutput(
scheduled_requests=scheduled_requests,
is_prefill=is_prefill,
)
return None
def complete_requests(self, requests: List[InferenceRequest]):
"""Handle completed requests and free their blocks."""
for req in requests:
if req.status in [
RequestStatus.FINISHED,
RequestStatus.CANCELED,
RequestStatus.FAILED,
RequestStatus.TIMEOUT,
]:
if req.block_table:
self.cache_manager.free_blocks(req.block_table)
if req.status == RequestStatus.CANCELED:
logger.info(
f"Request {req.request_id[:8]}... canceled: {req.finish_reason}"
)
elif req.status == RequestStatus.FAILED:
logger.error(
f"Request {req.request_id[:8]}... failed: {req.finish_reason}"
)
elif req.status == RequestStatus.TIMEOUT:
logger.error(
f"Request {req.request_id[:8]}... timed out: {req.finish_reason}"
)
else:
# Still running, put back in running queue
self.running_queue.sync_q.put(req)
def get_cache_stats(self) -> dict:
"""Get cache statistics."""
return {
"num_blocks": self.cache_manager.num_blocks,
"block_size": self.cache_manager.block_size,
"num_free_blocks": self.cache_manager.get_num_free_blocks(),
"num_req_blocks": len(self.cache_manager.req_block_ids),
"num_used_blocks": len(self.cache_manager.used_block_ids),
}
"""
Inference Server - HTTP API server for LLM inference.
"""
from contextlib import asynccontextmanager
import sys
import time
import json
import uuid
import argparse
import uvicorn
import logging
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from infinilm.llm import AsyncLLMEngine, SamplingParams, FinishReason
logger = logging.getLogger(__name__)
DEFAULT_STREAM_TIMEOUT = 100.0
DEFAULT_REQUEST_TIMEOUT = 1000.0
def chunk_json(id_, content=None, role=None, finish_reason=None):
"""Generate JSON chunk for streaming response."""
delta = {}
if content:
delta["content"] = content
if role:
delta["role"] = role
return {
"id": id_,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": "jiuge",
"system_fingerprint": None,
"choices": [
{
"index": 0,
"text": content,
"delta": delta,
"logprobs": None,
"finish_reason": finish_reason,
}
],
}
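# Example chunk (Python dict as returned; json.dumps renders None as null, and the
# id/timestamp values here are illustrative):
#   {"id": "cmpl-abc", "object": "chat.completion.chunk", "created": 1700000000,
#    "model": "jiuge", "system_fingerprint": None, "choices": [{"index": 0,
#    "text": "Hi", "delta": {"content": "Hi"}, "logprobs": None, "finish_reason": None}]}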
class InferenceServer:
"""HTTP server for LLM inference."""
def __init__(
self,
model_path: str,
device: str = "cuda",
dtype: str = "float16",
tensor_parallel_size: int = 1,
max_tokens: int = 4096,
max_batch_size: int = 16,
num_blocks: int = 8 * 1024,
block_size: int = 16,
temperature: float = 1.0,
top_p: float = 0.8,
top_k: int = 1,
host: str = "0.0.0.0",
port: int = 8000,
):
"""Initialize inference server.
Args:
model_path: Path to the model directory.
device: Device type ('cpu', 'cuda', 'mlu', 'moore').
dtype: Data type ('float16', 'bfloat16', 'float32').
tensor_parallel_size: Number of devices for tensor parallelism.
max_tokens: Default maximum tokens to generate.
max_batch_size: Maximum batch size for inference.
num_blocks: Number of KV cache blocks.
block_size: Size of each KV cache block.
temperature: Default sampling temperature.
top_p: Default top-p sampling parameter.
top_k: Default top-k sampling parameter.
host: Server host address.
port: Server port number.
"""
self.model_path = model_path
self.device = device
self.dtype = dtype
self.tensor_parallel_size = tensor_parallel_size
self.max_tokens = max_tokens
self.max_batch_size = max_batch_size
self.num_blocks = num_blocks
self.block_size = block_size
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.host = host
self.port = port
self.engine: AsyncLLMEngine = None
def start(self):
"""Start the HTTP server."""
app = self._create_app()
logger.info(f"Starting API Server at {self.host}:{self.port}...")
uvicorn.run(app, host=self.host, port=self.port)
logger.info("Inference Server stopped")
def _create_app(self):
"""Create FastAPI application."""
@asynccontextmanager
async def lifespan(app: FastAPI):
self.engine = AsyncLLMEngine(
model_path=self.model_path,
device=self.device,
dtype=self.dtype,
tensor_parallel_size=self.tensor_parallel_size,
max_batch_size=self.max_batch_size,
max_tokens=self.max_tokens,
num_blocks=self.num_blocks,
block_size=self.block_size,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
)
self.engine.start()
logger.info(f"Engine initialized with model at {self.model_path}")
yield
self.engine.stop()
app = FastAPI(lifespan=lifespan)
self._register_routes(app)
return app
def _register_routes(self, app: FastAPI):
"""Register API routes."""
@app.post("/chat/completions")
async def chat_completions(request: Request):
try:
data = await request.json()
logger.debug(f"Received request data: {data}")
except Exception as e:
logger.error(f"Failed to parse request JSON: {e}")
return JSONResponse(content={"error": "Invalid JSON"}, status_code=400)
if not data.get("messages"):
if not data.get("prompt"):
return JSONResponse(
content={"error": "No message provided"}, status_code=400
)
else:
data["messages"] = [{"role": "user", "content": data.get("prompt")}]
stream = data.get("stream", False)
request_id = f"cmpl-{uuid.uuid4().hex}"
if stream:
return StreamingResponse(
self._stream_chat(request_id, data, request),
media_type="text/event-stream",
)
else:
response = await self._chat(request_id, data, request)
if isinstance(response, JSONResponse):
return response
return JSONResponse(content=response)
@app.get("/health")
async def health():
return {"status": "healthy"}
@app.get("/v1/models")
async def list_models():
return {
"object": "list",
"data": [
{
"id": "jiuge",
"object": "model",
"created": int(time.time()),
"owned_by": "infinilm",
}
],
}
def _build_sampling_params(self, data: dict) -> SamplingParams:
"""Build SamplingParams from request data."""
return SamplingParams(
temperature=data.get("temperature", self.temperature),
top_p=data.get("top_p", self.top_p),
top_k=data.get("top_k", self.top_k),
max_tokens=data.get("max_tokens", self.max_tokens),
stop=data.get("stop"),
)
async def _stream_chat(self, request_id: str, data: dict, http_request: Request):
"""Handle streaming chat request."""
req = None
start_time = time.time()
try:
messages = data.get("messages", [])
sampling_params = self._build_sampling_params(data)
req = self.engine.add_chat_request(
messages=messages,
sampling_params=sampling_params,
request_id=request_id,
request_data=data,
http_request=http_request,
)
async for token_output in self.engine.stream_request(
req, timeout=DEFAULT_STREAM_TIMEOUT
):
# Check timeout
if time.time() - start_time > DEFAULT_REQUEST_TIMEOUT:
logger.warning(
f"Request {request_id} timed out after {DEFAULT_REQUEST_TIMEOUT}s"
)
req.mark_timeout()
error_chunk = json.dumps(
chunk_json(
request_id,
content="[Request timeout]",
finish_reason="timeout",
),
ensure_ascii=False,
)
yield f"data: {error_chunk}\n\n"
break
# Check client disconnect
if await http_request.is_disconnected():
logger.info(f"Client disconnected for request {request_id}")
req.mark_canceled()
break
# Send token
chunk = json.dumps(
chunk_json(request_id, content=token_output.token_text),
ensure_ascii=False,
)
yield f"data: {chunk}\n\n"
if token_output.finished:
finish_reason = self._convert_finish_reason(
token_output.finish_reason
)
chunk = json.dumps(
chunk_json(request_id, finish_reason=finish_reason),
ensure_ascii=False,
)
yield f"data: {chunk}\n\n"
break
except Exception as e:
logger.error(f"Stream error for {request_id}: {e}", exc_info=True)
if req:
req.mark_failed()
error_chunk = json.dumps(
chunk_json(
request_id, content=f"[Error: {str(e)}]", finish_reason="error"
),
ensure_ascii=False,
)
yield f"data: {error_chunk}\n\n"
finally:
if req and not req.is_finished():
req.mark_canceled()
if req:
await req.close()
yield "data: [DONE]\n\n"
async def _chat(self, request_id: str, data: dict, http_request: Request):
"""Handle non-streaming chat request."""
req = None
start_time = time.time()
try:
messages = data.get("messages", [])
sampling_params = self._build_sampling_params(data)
req = self.engine.add_chat_request(
messages=messages,
sampling_params=sampling_params,
request_id=request_id,
request_data=data,
http_request=http_request,
)
# Collect all generated tokens
output_text = ""
async for token_output in self.engine.stream_request(
req, timeout=DEFAULT_STREAM_TIMEOUT
):
# Check timeout
if time.time() - start_time > DEFAULT_REQUEST_TIMEOUT:
logger.warning(f"Request {request_id} timed out")
req.mark_timeout()
break
# Check client disconnect
if await http_request.is_disconnected():
logger.info(f"Client disconnected for request {request_id}")
req.mark_canceled()
break
output_text += token_output.token_text
if token_output.finished:
break
output_text = output_text.strip()
finish_reason = self._convert_finish_reason(req.finish_reason)
response = chunk_json(
request_id,
content=output_text,
role="assistant",
finish_reason=finish_reason or "stop",
)
return response
except Exception as e:
logger.error(f"Chat error for {request_id}: {e}", exc_info=True)
if req:
req.mark_failed()
return JSONResponse(content={"error": str(e)}, status_code=500)
finally:
if req and not req.is_finished():
req.mark_canceled()
if req:
await req.close()
def _convert_finish_reason(self, reason: FinishReason) -> str:
"""Convert FinishReason enum to string."""
if reason is None:
return None
if reason in (FinishReason.EOS_TOKEN, FinishReason.STOP_STRING):
return "stop"
return reason.value
def setup_logging(log_level: str = "INFO"):
"""Configure logging system with proper formatting and handlers."""
log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
date_format = "%Y-%m-%d %H:%M:%S"
logging.basicConfig(
level=getattr(logging, log_level.upper(), logging.INFO),
format=log_format,
datefmt=date_format,
handlers=[
logging.StreamHandler(sys.stdout),
],
force=True,
)
def parse_args():
"""Parse command line arguments."""
parser = argparse.ArgumentParser(description="InfiniLM Inference Server")
parser.add_argument(
"--model_path", type=str, required=True, help="Path to model directory"
)
parser.add_argument("--tp", type=int, default=1, help="Tensor parallelism degree")
parser.add_argument(
"--max_tokens",
type=int,
default=512,
help="Maximum number of tokens to generate",
)
parser.add_argument(
"--max_batch_size", type=int, default=8, help="Maximum batch size"
)
parser.add_argument(
"--num_blocks", type=int, default=8 * 1024, help="Number of blocks for KV cache"
)
parser.add_argument(
"--block_size", type=int, default=16, help="Block size for KV cache"
)
parser.add_argument(
"--dtype",
type=str,
default="float16",
choices=["float32", "float16", "bfloat16"],
help="Data type",
)
parser.add_argument(
"--temperature", type=float, default=1.0, help="Sampling temperature"
)
parser.add_argument(
"--top_p", type=float, default=0.8, help="Top-p sampling parameter"
)
parser.add_argument("--top_k", type=int, default=1, help="Top-k sampling parameter")
parser.add_argument("--host", type=str, default="0.0.0.0", help="Server host")
parser.add_argument("--port", type=int, default=8000, help="Server port")
parser.add_argument("--cpu", action="store_true", help="Use CPU")
parser.add_argument("--nvidia", action="store_true", help="Use NVIDIA GPU")
parser.add_argument("--metax", action="store_true", help="Use MetaX device")
parser.add_argument("--moore", action="store_true", help="Use Moore device")
parser.add_argument("--iluvatar", action="store_true", help="Use Iluvatar device")
parser.add_argument("--cambricon", action="store_true", help="Use Cambricon device")
parser.add_argument(
"--log_level",
type=str,
default="INFO",
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
help="Logging level",
)
return parser.parse_args()
def main():
args = parse_args()
setup_logging(args.log_level)
if args.cpu:
device = "cpu"
elif args.nvidia:
device = "cuda"
elif args.metax:
device = "cuda"
elif args.moore:
device = "moore"
elif args.iluvatar:
device = "cuda"
elif args.cambricon:
device = "mlu"
else:
print(
"Usage: python infinilm.server.inference_server [--cpu | --nvidia | --metax | --moore | --iluvatar | --cambricon] "
"--model_path=<path/to/model_dir> --max_tokens=MAX_TOKENS --max_batch_size=MAX_BATCH_SIZE"
"\n"
"Example: python infinilm.server.inference_server --nvidia --model_path=/data/shared/models/9G7B_MHA/ "
"--max_tokens=100 --max_batch_size=32 --tp=1 --temperature=1.0 --top_p=0.8 --top_k=1"
)
sys.exit(1)
server = InferenceServer(
model_path=args.model_path,
device=device,
dtype=args.dtype,
tensor_parallel_size=args.tp,
max_tokens=args.max_tokens,
max_batch_size=args.max_batch_size,
num_blocks=args.num_blocks,
block_size=args.block_size,
temperature=args.temperature,
top_p=args.top_p,
top_k=args.top_k,
host=args.host,
port=args.port,
)
server.start()
if __name__ == "__main__":
main()