Unverified commit a4ced800 authored by thatPepe, committed by GitHub

Merge pull request #205 from InfiniTensor/demo131

Demo-131 CUDA graph with optimized paged attention
parents 96ecf490 04c37f3f
......@@ -25,5 +25,11 @@ class AutoConfig:
config_dict["model_type"] == "qwen2" or config_dict["model_type"] == "qwen3"
):
return LlamaConfig(**config_dict)
elif config_dict["model_type"] == "minicpm":
return LlamaConfig(**config_dict)
elif config_dict["model_type"] == "fm9g":
return LlamaConfig(**config_dict)
elif config_dict["model_type"] == "fm9g7b":
return LlamaConfig(**config_dict)
raise ValueError(f"Unsupported model type `{config_dict['model_type']}`.")
......@@ -28,19 +28,28 @@ class InferEngine(_infinilm.InferEngine):
device=None,
distributed_config=DistConfig(1),
cache_config=None,
enable_graph_compiling=False,
):
self.config = AutoConfig.from_pretrained(model_path)
if device is None:
device = infinicore.device()
# super().__init__(
# self.config,
# distributed_config._underlying,
# device._underlying.type,
# cache_config,
# enable_graph_compiling,
# )
super().__init__(
self.config,
model_path,
distributed_config._underlying,
device._underlying.type,
cache_config,
enable_graph_compiling,
)
self.use_cache = False
self.enable_paged_attn = isinstance(cache_config, PagedKVCacheConfig)
......@@ -121,6 +130,22 @@ class InferEngine(_infinilm.InferEngine):
if _measure_and_log_time:
time_measurements = []
block_tables = None
max_blocks_per_batch = 0
if self.enable_paged_attn:
max_blocks_per_batch = (
initial_seqlen + generation_config.max_new_tokens + paged_block_size - 1
) // paged_block_size
block_tables_list = [
range(i * max_blocks_per_batch, (i + 1) * max_blocks_per_batch)
for i in range(batch_size)
]
block_tables = infinicore.from_list(
block_tables_list,
dtype=infinicore.int64,
)
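# Illustration (not part of the diff): with initial_seqlen=5, max_new_tokens=3
# and paged_block_size=4, max_blocks_per_batch = (5 + 3 + 4 - 1) // 4 = 2, so a
# batch of two sequences gets block_tables_list == [range(0, 2), range(2, 4)].
# Each sequence owns a fixed, contiguous run of blocks, so block_tables stays
# constant for the whole generation and can be reused across decode steps
# (e.g. inside a captured CUDA graph).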
for iter in range(0, generation_config.max_new_tokens):
if _measure_and_log_time:
start_time = time.perf_counter()
......@@ -133,28 +158,28 @@ class InferEngine(_infinilm.InferEngine):
list(range(past_seq_len, past_seq_len + seq_len)) * batch_size,
dtype=infinicore.int64,
)
block_tables_list = [
    [
        i * batch_size + b
        for i in range(
            (past_seq_len + seq_len + paged_block_size - 1)
            // paged_block_size
        )
    ]
    for b in range(batch_size)
]
slot_mapping_list = [
    (((past_seq_len + i) // paged_block_size) * batch_size + b)
    * paged_block_size
    + (past_seq_len + i) % paged_block_size
    for b in range(batch_size)
    for i in range(seq_len)
]
block_tables = infinicore.from_list(
    block_tables_list,
    dtype=infinicore.int64,
)
if iter == 0:
    slot_mapping_list = []
    for b in range(batch_size):
        slot_mapping_list.extend(
            [
                b * max_blocks_per_batch * paged_block_size + i
                for i in range(seq_len)
            ]
        )
else:
    slot_mapping_list = [
        i
        for i in range(
            past_seq_len,
            max_blocks_per_batch
            * paged_block_size
            * initial_batch_size,
            max_blocks_per_batch * paged_block_size,
        )
    ]
slot_mapping = infinicore.from_list(
slot_mapping_list,
dtype=infinicore.int64,
......@@ -168,7 +193,6 @@ class InferEngine(_infinilm.InferEngine):
dtype=infinicore.int64,
)
block_tables = None
slot_mapping = None
past_kv_lengths = infinicore.from_list(
......@@ -205,9 +229,9 @@ class InferEngine(_infinilm.InferEngine):
):
break
input_ids = infinicore.from_list(
[[output_id] for output_id in output_id.to_numpy().tolist()]
)
# start_prepare_time = time.perf_counter()
input_ids = output_id.view([batch_size, 1])
past_seq_len = past_seq_len + seq_len
if _measure_and_log_time:
......
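For reference, a small self-contained sketch of the per-sequence slot layout used in the loop above; the sizes are invented, and initial_batch_size is assumed here to equal batch_size:

# Hypothetical sizes, chosen only to make the arithmetic easy to follow.
batch_size = 2
paged_block_size = 4
initial_seqlen = 5
max_new_tokens = 3
max_blocks_per_batch = (
    initial_seqlen + max_new_tokens + paged_block_size - 1
) // paged_block_size  # == 2, so each sequence owns 2 * 4 = 8 slots

# Prefill (iter == 0): sequence b writes its prompt at the start of its region.
prefill_slots = [
    b * max_blocks_per_batch * paged_block_size + i
    for b in range(batch_size)
    for i in range(initial_seqlen)
]
assert prefill_slots == [0, 1, 2, 3, 4, 8, 9, 10, 11, 12]

# Decode: one new slot per sequence, strided by one full per-sequence region.
past_seq_len = initial_seqlen
decode_slots = list(
    range(
        past_seq_len,
        max_blocks_per_batch * paged_block_size * batch_size,
        max_blocks_per_batch * paged_block_size,
    )
)
assert decode_slots == [5, 13]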
......@@ -18,6 +18,7 @@ from infinilm.llm.llm import (
EngineConfig,
)
from infinilm.llm.scheduler import Scheduler, SchedulerOutput
from infinilm.llm.static_scheduler import StaticScheduler, StaticSchedulerOutput
from infinilm.llm.cache_manager import BlockManager, Block
__all__ = [
......@@ -38,6 +39,8 @@ __all__ = [
# Internal (for advanced use)
"Scheduler",
"SchedulerOutput",
"StaticScheduler",
"StaticSchedulerOutput",
"BlockManager",
"Block",
]
......@@ -23,10 +23,11 @@ from infinilm.llm.request import (
)
from infinilm.llm.sampling_params import SamplingParams
from infinilm.llm.scheduler import Scheduler
from infinilm.llm.static_scheduler import StaticScheduler
from infinilm.distributed import DistConfig
from infinilm.infer_engine import InferEngine
from infinilm.cache.cache import PagedKVCacheConfig
from infinilm.cache.cache import PagedKVCacheConfig, StaticKVCacheConfig
from infinilm.modeling_utils import load_model_state_dict_by_file
from transformers import AutoTokenizer
from tokenizers import decoders as _dec
......@@ -43,26 +44,32 @@ class EngineConfig:
device: Device type string ('cpu', 'cuda', 'mlu', etc.).
dtype: Data type string ('float16', 'bfloat16', 'float32').
tensor_parallel_size: Number of devices for tensor parallelism.
max_batch_size: Maximum batch size for inference.
cache_type: Cache type ('paged' or 'static').
max_batch_size: Maximum batch size for inference (only for paged cache).
max_tokens: Default maximum tokens to generate.
num_blocks: Number of KV cache blocks.
block_size: Size of each KV cache block.
num_blocks: Number of KV cache blocks (only for paged cache).
block_size: Size of each KV cache block (only for paged cache).
max_cache_len: Maximum sequence length (only for static cache).
temperature: Default sampling temperature.
top_p: Default top-p sampling parameter.
top_k: Default top-k sampling parameter.
enable_graph: Whether to enable graph compiling.
"""
model_path: str
device: str = "cuda"
dtype: str = "float16"
tensor_parallel_size: int = 1
cache_type: str = "paged" # "paged" or "static"
max_batch_size: int = 16
max_tokens: int = 4096
num_blocks: int = 8 * 1024
block_size: int = 16
max_cache_len: int = 4096
temperature: float = 1.0
top_p: float = 0.8
top_k: int = 1
enable_graph: bool = False
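As an illustration of the two cache modes described in the docstring above (the model path is a placeholder):

# Sketch only: one config per cache type, using the fields defined above.
paged_cfg = EngineConfig(
    model_path="/path/to/model",
    cache_type="paged",
    max_batch_size=16,
    num_blocks=8 * 1024,
    block_size=16,
    enable_graph=True,
)
static_cfg = EngineConfig(
    model_path="/path/to/model",
    cache_type="static",
    max_cache_len=4096,
    enable_graph=True,
)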
class LLMEngine:
......@@ -79,6 +86,7 @@ class LLMEngine:
model_path=config.model_path,
device=self.device,
distributed_config=DistConfig(config.tensor_parallel_size),
enable_graph_compiling=config.enable_graph,
)
# Load model weights
......@@ -92,18 +100,30 @@ class LLMEngine:
)
self._fix_tokenizer_decoder()
# Initialize KV cache
cache_config = PagedKVCacheConfig(
num_blocks=config.num_blocks, block_size=config.block_size
)
self.model_engine.reset_cache(cache_config)
# Initialize KV cache based on cache type
if config.cache_type == "static":
cache_config = StaticKVCacheConfig(
max_batch_size=1, max_cache_len=config.max_cache_len
)
self.scheduler = StaticScheduler(max_cache_len=config.max_cache_len)
logger.info(
f"Using Static KV Cache with max_cache_len={config.max_cache_len}"
)
elif config.cache_type == "paged":
cache_config = PagedKVCacheConfig(
num_blocks=config.num_blocks, block_size=config.block_size
)
self.scheduler = Scheduler(
max_batch_size=config.max_batch_size,
num_blocks=config.num_blocks,
block_size=config.block_size,
)
logger.info(f"Using Paged KV Cache with num_blocks={config.num_blocks}")
else:
raise ValueError(f"Unsupported cache_type: {config.cache_type}")
# Initialize scheduler
self.scheduler = Scheduler(
max_batch_size=config.max_batch_size,
num_blocks=config.num_blocks,
block_size=config.block_size,
)
self.model_engine.reset_cache(cache_config)
self.cache_type = config.cache_type
# Get EOS token IDs from model config
self.eos_token_ids = self.model_engine.config.eos_token_id or []
......@@ -113,11 +133,12 @@ class LLMEngine:
logger.info(
f"LLMEngine initialized with model at {config.model_path} "
f"on device {config.device}"
f"enable_graph={config.enable_graph}"
)
def _init_device(self):
"""Initialize infinicore device and dtype."""
supported_devices = ["cpu", "cuda", "mlu", "moore"]
supported_devices = ["cpu", "cuda", "mlu", "musa"]
device_str = self.config.device
if device_str not in supported_devices:
raise ValueError(
......@@ -198,19 +219,21 @@ class LLMEngine:
"""Convert model input dict to infinicore tensors."""
model_input = {}
for key, value in model_input_dict.items():
if key == "input_ids":
model_input[key] = infinicore.from_list([value], dtype=infinicore.int64)
if value is None:
# Skip None values (block_tables/slot_mapping for static cache)
model_input[key] = None
elif key in [
"input_ids",
"position_ids",
"past_kv_lengths",
"total_kv_lengths",
"input_offsets",
"slot_mapping",
"block_tables",
]:
model_input[key] = infinicore.from_list(value, dtype=infinicore.int64)
elif key == "block_tables":
model_input[key] = infinicore.from_list(value, dtype=infinicore.int64)
else:
# temperature, top_k, top_p, etc.
model_input[key] = value
return model_input
......@@ -221,7 +244,8 @@ class LLMEngine:
sampled_tokens: List[int],
):
"""Update request status after inference step."""
if is_prefill:
# Only reset req blocks for paged cache
if is_prefill and self.cache_type == "paged":
self.scheduler.cache_manager.reset_req_blocks()
for req, token_id in zip(requests, sampled_tokens):
......@@ -252,20 +276,22 @@ class LLMEngine:
for stop_str in stop_strings:
if decoded_text.endswith(stop_str):
# Remove the stop string from the end
decoded_text = decoded_text[:-len(stop_str)]
decoded_text = decoded_text[: -len(stop_str)]
req.generated_text = decoded_text
break
holds_back_incomplete_utf8 = (
bool(decoded_text) and decoded_text.endswith("\ufffd")
)
holds_back_incomplete_utf8 = bool(
decoded_text
) and decoded_text.endswith("\ufffd")
# vLLM-style: hold back only if we are not on the final chunk.
# Suppress output when finish reason is LENGTH or STOP_STRING.
# Root cause fix: When STOP_STRING is detected, we suppress output for the token
# that completes the stop string, preventing additional tokens from being output.
if (holds_back_incomplete_utf8 and not finished_now) or (
finished_now and req.finish_reason in (FinishReason.LENGTH, FinishReason.STOP_STRING)
finished_now
and req.finish_reason
in (FinishReason.LENGTH, FinishReason.STOP_STRING)
):
token_text = ""
else:
......@@ -275,7 +301,9 @@ class LLMEngine:
req._stream_last_yielded_length = len(decoded_text)
# For non-streaming, finish checks happen here.
if req._output_queue is None and self._check_request_finished(req, token_id):
if req._output_queue is None and self._check_request_finished(
req, token_id
):
req.mark_finished(req.finish_reason)
# Remove stop string from generated_text if STOP_STRING finish reason
if req.finish_reason == FinishReason.STOP_STRING:
......@@ -283,9 +311,8 @@ class LLMEngine:
for stop_str in stop_strings:
if req.generated_text.endswith(stop_str):
# Remove the stop string from the end
req.generated_text = req.generated_text[:-len(stop_str)]
req.generated_text = req.generated_text[: -len(stop_str)]
break
# Put output in queue if it exists (for async streaming)
if req._output_queue is not None:
output = TokenOutput(
......@@ -355,13 +382,16 @@ class LLM:
device: str = "cuda",
dtype: str = "float16",
tensor_parallel_size: int = 1,
cache_type: str = "paged",
max_batch_size: int = 16,
max_tokens: int = 4096,
num_blocks: int = 8 * 1024,
block_size: int = 16,
max_cache_len: int = 4096,
temperature: float = 1.0,
top_p: float = 0.8,
top_k: int = 1,
enable_graph: bool = False,
):
"""Initialize LLM.
......@@ -370,26 +400,32 @@ class LLM:
device: Device type ('cpu', 'cuda', 'mlu', 'moore').
dtype: Data type ('float16', 'bfloat16', 'float32').
tensor_parallel_size: Number of devices for tensor parallelism.
max_batch_size: Maximum batch size for inference.
cache_type: Cache type ('paged' or 'static').
max_batch_size: Maximum batch size (only for paged cache).
max_tokens: Default maximum tokens to generate.
num_blocks: Number of KV cache blocks.
block_size: Size of each KV cache block.
num_blocks: Number of KV cache blocks (only for paged cache).
block_size: Size of each KV cache block (only for paged cache).
max_cache_len: Maximum sequence length (only for static cache).
temperature: Default sampling temperature.
top_p: Default top-p sampling parameter.
top_k: Default top-k sampling parameter.
enable_graph: Whether to enable graph compiling.
"""
config = EngineConfig(
model_path=model_path,
device=device,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
cache_type=cache_type,
max_batch_size=max_batch_size,
max_tokens=max_tokens,
num_blocks=num_blocks,
block_size=block_size,
max_cache_len=max_cache_len,
temperature=temperature,
top_p=top_p,
top_k=top_k,
enable_graph=enable_graph,
)
self.engine = LLMEngine(config)
self.config = config
......@@ -499,13 +535,16 @@ class AsyncLLMEngine:
device: str = "cuda",
dtype: str = "float16",
tensor_parallel_size: int = 1,
cache_type: str = "paged",
max_batch_size: int = 16,
max_tokens: int = 512,
num_blocks: int = 8 * 1024,
block_size: int = 16,
max_cache_len: int = 4096,
temperature: float = 1.0,
top_p: float = 0.8,
top_k: int = 1,
enable_graph: bool = False,
):
"""Initialize AsyncLLMEngine.
......@@ -514,26 +553,32 @@ class AsyncLLMEngine:
device: Device type ('cpu', 'cuda', 'mlu', 'moore').
dtype: Data type ('float16', 'bfloat16', 'float32').
tensor_parallel_size: Number of devices for tensor parallelism.
max_batch_size: Maximum batch size for inference.
cache_type: Cache type ('paged' or 'static').
max_batch_size: Maximum batch size (only for paged cache).
max_tokens: Default maximum tokens to generate.
num_blocks: Number of KV cache blocks.
block_size: Size of each KV cache block.
num_blocks: Number of KV cache blocks (only for paged cache).
block_size: Size of each KV cache block (only for paged cache).
max_cache_len: Maximum sequence length (only for static cache).
temperature: Default sampling temperature.
top_p: Default top-p sampling parameter.
top_k: Default top-k sampling parameter.
enable_graph: Whether to enable graph compiling.
"""
config = EngineConfig(
model_path=model_path,
device=device,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
cache_type=cache_type,
max_batch_size=max_batch_size,
max_tokens=max_tokens,
num_blocks=num_blocks,
block_size=block_size,
max_cache_len=max_cache_len,
temperature=temperature,
top_p=top_p,
top_k=top_k,
enable_graph=enable_graph,
)
self.engine = LLMEngine(config)
self.config = config
......
......@@ -103,7 +103,7 @@ class SchedulerOutput:
block_tables.append(padded_block_table)
return {
"input_ids": tokens,
"input_ids": [tokens],
"position_ids": position_ids,
"past_kv_lengths": cached_lens,
"total_kv_lengths": seq_lens,
......@@ -154,6 +154,10 @@ class Scheduler:
req = self.waiting_queue.sync_q.get_nowait()
except queue.Empty:
break
# Skip requests that were already finished (e.g., timed out/canceled while waiting)
if req.is_finished():
self.complete_requests([req])
continue
if not self.can_accept_request(req):
self.waiting_queue.sync_q.put(req)
......
"""
Static Scheduler - Single-batch request scheduling for Static KV Cache.
"""
import logging
import queue
import janus
from typing import List, Optional
from infinilm.llm.request import RequestStatus, InferenceRequest, FinishReason
logger = logging.getLogger(__name__)
class StaticSchedulerOutput:
"""Static scheduler output containing single request and execution phase info."""
def __init__(
self,
scheduled_requests: List[InferenceRequest],
is_prefill: bool = False,
):
self.scheduled_requests = scheduled_requests
self.num_requests = len(scheduled_requests)
self.is_prefill = is_prefill
def build_model_inputs(
self, temperature: float = 1.0, top_p: float = 0.8, top_k: int = 1
):
"""Construct model inputs for prefill or decode phase.
Static cache model inputs:
Prefill phase:
- input_ids: All prompt tokens [1, prompt_length]
- position_ids: [0, 1, 2, ..., prompt_length-1]
- past_kv_lengths: [0] (no cached tokens initially)
- total_kv_lengths: [prompt_length]
Decode phase:
- input_ids: Only the last generated token [1, 1]
- position_ids: [current_position] (position in full sequence)
- past_kv_lengths: [num_cached_tokens]
- total_kv_lengths: [total_tokens]
"""
req = self.scheduled_requests[0]
if self.is_prefill:
# Prefill: send all prompt tokens
tokens = req.get_input_tokens()
input_ids = [tokens]
position_ids = [list(range(len(tokens)))]
past_kv_len = 0
total_kv_len = len(tokens)
input_offsets = [0, len(tokens)]
else:
# Decode: send only the last generated token
last_token = req.generated_token_ids[-1]
current_position = req.get_total_length() - 1
input_ids = [[last_token]]
position_ids = [[current_position]]
past_kv_len = current_position
total_kv_len = req.get_total_length()
input_offsets = [0, 1]
return {
"input_ids": input_ids,
"position_ids": position_ids,
"past_kv_lengths": [past_kv_len],
"total_kv_lengths": [total_kv_len],
"input_offsets": input_offsets,
"block_tables": None,
"slot_mapping": None,
"temperature": temperature,
"top_k": top_k,
"top_p": top_p,
}
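# Worked illustration of the two phases (token ids are invented): a 4-token
# prompt [11, 12, 13, 14] yields, at prefill,
#   input_ids=[[11, 12, 13, 14]], position_ids=[[0, 1, 2, 3]],
#   past_kv_lengths=[0], total_kv_lengths=[4], input_offsets=[0, 4];
# after the first generated token (say 99), the decode step sends
#   input_ids=[[99]], position_ids=[[4]],
#   past_kv_lengths=[4], total_kv_lengths=[5], input_offsets=[0, 1].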
class StaticScheduler:
"""Request scheduler for Static KV Cache with batch_size=1.
Simplified scheduling logic:
- Only handles one request at a time
- No cache block management needed
- Simple waiting queue for incoming requests
"""
def __init__(self, max_cache_len: int = 4096):
self.waiting_queue = janus.Queue()
self.running_request: Optional[InferenceRequest] = None
self.max_cache_len = max_cache_len
def add_request(self, request: InferenceRequest):
if request is not None:
request.status = RequestStatus.WAITING
self.waiting_queue.sync_q.put(request)
def schedule(self) -> Optional[StaticSchedulerOutput]:
"""Schedule and return single request to execute."""
while True:
# Case 1: Continue running request (decode phase)
if self.running_request is not None:
req = self.running_request
if req.is_finished():
self.running_request = None
continue
if req.get_total_length() > self.max_cache_len:
logger.warning(
f"Request {req.request_id} exceeds max_cache_len={self.max_cache_len}, "
"completing request."
)
self.running_request = None
req.mark_failed(FinishReason.LENGTH)
continue
return StaticSchedulerOutput(scheduled_requests=[req], is_prefill=False)
# Case 2: Get new request from waiting queue (prefill phase)
try:
req = self.waiting_queue.sync_q.get_nowait()
except queue.Empty:
return None
if req.is_finished():
continue
prompt_len = req.get_prompt_length()
if prompt_len > self.max_cache_len:
logger.error(
f"Request {req.request_id} prompt length {prompt_len} "
f"exceeds max_cache_len={self.max_cache_len}. Request rejected."
)
req.mark_failed(FinishReason.LENGTH)
continue
req.status = RequestStatus.RUNNING
self.running_request = req
return StaticSchedulerOutput(scheduled_requests=[req], is_prefill=True)
def complete_requests(self, requests: List[InferenceRequest]):
"""Handle completed requests."""
for req in requests:
if req.is_finished() and req == self.running_request:
self.running_request = None
logger.debug(f"Completed request {req.request_id}")
def get_cache_stats(self) -> dict:
"""Get cache statistics."""
return {
"max_cache_len": self.max_cache_len,
"running_request": (
self.running_request.request_id if self.running_request else None
),
"waiting_queue_size": self.waiting_queue.sync_q.qsize(),
}
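A hedged usage sketch of the scheduler above; how the InferenceRequest is constructed and how the model step is run are elided and assumed to follow the existing request API:

# Sketch only: single-request scheduling loop for the static cache.
scheduler = StaticScheduler(max_cache_len=4096)
scheduler.add_request(req)  # req: an InferenceRequest from the waiting client
while (output := scheduler.schedule()) is not None:
    inputs = output.build_model_inputs(temperature=1.0, top_p=0.8, top_k=1)
    # ... run one model step on `inputs` and record the sampled token on `req` ...
    if req.is_finished():
        scheduler.complete_requests([req])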
......@@ -75,7 +75,7 @@ def load_state_dict(
)
for k in f.keys():
state_dict[k] = f.get_tensor(k).to(device=device, dtype=dtype)
state_dict[k] = f.get_tensor(k).to(device=device)
return state_dict
......@@ -155,7 +155,6 @@ def load_model_state_dict_by_file(
model_param_infini = {}
for key in model_param.keys():
model_param_infini[key] = infinicore.from_torch(model_param[key])
model.load_state_dict(model_param_infini, strict=False)
infinicore.sync_device()
......@@ -168,7 +167,6 @@ def load_model_state_dict_by_file(
model_param_infini[key] = infinicore.from_torch(
model_params[key].to(dtype=torch_dtype)
)
already_loaded_keys.append(key)
model.load_state_dict(model_param_infini, strict=True)
......
......@@ -23,7 +23,9 @@ DEFAULT_STREAM_TIMEOUT = 100.0
DEFAULT_REQUEST_TIMEOUT = 1000.0
def chunk_json(id_, content=None, role=None, finish_reason=None, model: str = "unknown"):
def chunk_json(
id_, content=None, role=None, finish_reason=None, model: str = "unknown"
):
"""Generate JSON chunk for streaming response."""
delta = {}
if content:
......@@ -48,6 +50,42 @@ def chunk_json(id_, content=None, role=None, finish_reason=None, model: str = "u
}
def completion_json(
id_,
content,
role="assistant",
finish_reason="stop",
model: str = "unknown",
prompt_tokens: int = 0,
completion_tokens: int = 0,
total_tokens: int = 0,
):
"""Generate JSON response for non-streaming completion."""
return {
"id": id_,
"object": "chat.completion",
"created": int(time.time()),
"model": model,
"system_fingerprint": None,
"choices": [
{
"index": 0,
"message": {
"role": role,
"content": content,
},
"logprobs": None,
"finish_reason": finish_reason,
}
],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
},
}
class InferenceServer:
"""HTTP server for LLM inference."""
......@@ -57,15 +95,18 @@ class InferenceServer:
device: str = "cuda",
dtype: str = "float16",
tensor_parallel_size: int = 1,
cache_type: str = "paged",
max_tokens: int = 4096,
max_batch_size: int = 16,
num_blocks: int = 8 * 1024,
block_size: int = 16,
max_cache_len: int = 4096,
temperature: float = 1.0,
top_p: float = 0.8,
top_k: int = 1,
host: str = "0.0.0.0",
port: int = 8000,
enable_graph: bool = False,
):
"""Initialize inference server.
......@@ -74,15 +115,18 @@ class InferenceServer:
device: Device type ('cpu', 'cuda', 'mlu', 'moore').
dtype: Data type ('float16', 'bfloat16', 'float32').
tensor_parallel_size: Number of devices for tensor parallelism.
cache_type: Cache type ('paged' or 'static').
max_tokens: Default maximum tokens to generate.
max_batch_size: Maximum batch size for inference.
num_blocks: Number of KV cache blocks.
block_size: Size of each KV cache block.
max_batch_size: Maximum batch size for inference (only for paged cache).
num_blocks: Number of KV cache blocks (only for paged cache).
block_size: Size of each KV cache block (only for paged cache).
max_cache_len: Maximum sequence length (only for static cache).
temperature: Default sampling temperature.
top_p: Default top-p sampling parameter.
top_k: Default top-k sampling parameter.
host: Server host address.
port: Server port number.
enable_graph: Whether to enable graph compiling.
"""
self.model_path = model_path
# vLLM-like served model id: directory name of model_path
......@@ -90,15 +134,18 @@ class InferenceServer:
self.device = device
self.dtype = dtype
self.tensor_parallel_size = tensor_parallel_size
self.cache_type = cache_type
self.max_tokens = max_tokens
self.max_batch_size = max_batch_size
self.num_blocks = num_blocks
self.block_size = block_size
self.max_cache_len = max_cache_len
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.host = host
self.port = port
self.enable_graph = enable_graph
self.engine: AsyncLLMEngine = None
......@@ -119,16 +166,20 @@ class InferenceServer:
device=self.device,
dtype=self.dtype,
tensor_parallel_size=self.tensor_parallel_size,
cache_type=self.cache_type,
max_batch_size=self.max_batch_size,
max_tokens=self.max_tokens,
num_blocks=self.num_blocks,
block_size=self.block_size,
max_cache_len=self.max_cache_len,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
enable_graph=self.enable_graph,
)
self.engine.start()
logger.info(f"Engine initialized with model at {self.model_path}")
logger.info(f" enable_graph: {self.enable_graph}")
yield
self.engine.stop()
......@@ -159,6 +210,9 @@ class InferenceServer:
else:
data["messages"] = [{"role": "user", "content": data.get("prompt")}]
# Normalize messages to handle multimodal content (list format)
data["messages"] = self._normalize_messages(data.get("messages", []))
stream = data.get("stream", False)
request_id = f"cmpl-{uuid.uuid4().hex}"
......@@ -206,6 +260,39 @@ class InferenceServer:
async def list_models_legacy():
return _models_payload()
def _normalize_messages(self, messages: list) -> list:
"""Normalize messages to handle multimodal content (list format).
Converts content from list format [{"type": "text", "text": "..."}]
to string format for chat template compatibility.
"""
normalized = []
for msg in messages:
if not isinstance(msg, dict):
normalized.append(msg)
continue
content = msg.get("content")
if isinstance(content, list):
# Extract text from multimodal content list
text_parts = []
for part in content:
if isinstance(part, dict):
if part.get("type") == "text" and "text" in part:
text_parts.append(part["text"])
elif isinstance(part, str):
text_parts.append(part)
elif isinstance(part, str):
text_parts.append(part)
# Join all text parts
normalized_msg = msg.copy()
normalized_msg["content"] = "".join(text_parts) if text_parts else ""
normalized.append(normalized_msg)
else:
normalized.append(msg)
return normalized
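# Example (illustrative): a message with list-format content such as
#   {"role": "user", "content": [{"type": "text", "text": "Hello "}, "world"]}
# is normalized to
#   {"role": "user", "content": "Hello world"}
# so the tokenizer's chat template always receives a plain string.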
def _build_sampling_params(self, data: dict) -> SamplingParams:
"""Build SamplingParams from request data."""
# Support both:
......@@ -233,7 +320,6 @@ class InferenceServer:
if isinstance(stop, str):
stop = [stop]
return SamplingParams(
temperature=float(pick("temperature", self.temperature)),
top_p=float(pick("top_p", self.top_p)),
......@@ -291,15 +377,15 @@ class InferenceServer:
# Skip EOS token text for OpenAI API compatibility
# Check if this token is an EOS token by comparing token_id with eos_token_ids
eos_token_ids = self.engine.engine.eos_token_ids
is_eos_token = (
eos_token_ids and token_output.token_id in eos_token_ids
)
is_eos_token = eos_token_ids and token_output.token_id in eos_token_ids
if not is_eos_token and token_output.token_text:
# Send token
chunk = json.dumps(
chunk_json(
request_id, content=token_output.token_text, model=self.model_id
request_id,
content=token_output.token_text,
model=self.model_id,
),
ensure_ascii=False,
)
......@@ -379,9 +465,7 @@ class InferenceServer:
# Skip EOS token text for OpenAI API compatibility
# Check if this token is an EOS token by comparing token_id with eos_token_ids
eos_token_ids = self.engine.engine.eos_token_ids
is_eos_token = (
eos_token_ids and token_output.token_id in eos_token_ids
)
is_eos_token = eos_token_ids and token_output.token_id in eos_token_ids
if not is_eos_token:
output_text += token_output.token_text
......@@ -392,12 +476,15 @@ class InferenceServer:
output_text = output_text.strip()
finish_reason = self._convert_finish_reason(req.finish_reason)
response = chunk_json(
response = completion_json(
request_id,
content=output_text,
role="assistant",
finish_reason=finish_reason or "stop",
model=self.model_id,
prompt_tokens=req.get_prompt_length(),
completion_tokens=req.get_num_generated_tokens(),
total_tokens=req.get_total_length(),
)
return response
......@@ -446,6 +533,13 @@ def parse_args():
"--model_path", type=str, required=True, help="Path to model directory"
)
parser.add_argument("--tp", type=int, default=1, help="Tensor parallelism degree")
parser.add_argument(
"--cache_type",
type=str,
default="paged",
choices=["paged", "static"],
help="Cache type: paged or static",
)
parser.add_argument(
"--max_tokens",
type=int,
......@@ -453,13 +547,28 @@ def parse_args():
help="Maximum number of tokens to generate",
)
parser.add_argument(
"--max_batch_size", type=int, default=8, help="Maximum batch size"
"--max_batch_size",
type=int,
default=8,
help="Maximum batch size (paged cache only)",
)
parser.add_argument(
"--num_blocks",
type=int,
default=8 * 1024,
help="Number of blocks for KV cache (paged cache only)",
)
parser.add_argument(
"--num_blocks", type=int, default=8 * 1024, help="Number of blocks for KV cache"
"--block_size",
type=int,
default=16,
help="Block size for KV cache (paged cache only)",
)
parser.add_argument(
"--block_size", type=int, default=16, help="Block size for KV cache"
"--max_cache_len",
type=int,
default=4096,
help="Maximum sequence length (static cache only)",
)
parser.add_argument(
"--dtype",
......@@ -479,10 +588,17 @@ def parse_args():
parser.add_argument("--port", type=int, default=8000, help="Server port")
parser.add_argument("--cpu", action="store_true", help="Use CPU")
parser.add_argument("--nvidia", action="store_true", help="Use NVIDIA GPU")
parser.add_argument("--qy", action="store_true", help="Use QY GPU")
parser.add_argument("--metax", action="store_true", help="Use MetaX device")
parser.add_argument("--moore", action="store_true", help="Use Moore device")
parser.add_argument("--iluvatar", action="store_true", help="Use Iluvatar device")
parser.add_argument("--cambricon", action="store_true", help="Use Cambricon device")
parser.add_argument("--ali", action="store_true", help="Use Ali PPU device")
parser.add_argument(
"--enable-graph",
action="store_true",
help="Enable graph compiling",
)
parser.add_argument(
"--log_level",
type=str,
......@@ -503,21 +619,27 @@ def main():
device = "cpu"
elif args.nvidia:
device = "cuda"
elif args.qy:
device = "cuda"
elif args.metax:
device = "cuda"
elif args.moore:
device = "moore"
device = "musa"
elif args.iluvatar:
device = "cuda"
elif args.cambricon:
device = "mlu"
elif args.ali:
device = "cuda"
else:
print(
"Usage: python infinilm.server.inference_server [--cpu | --nvidia | --metax | --moore | --iluvatar | --cambricon] "
"Usage: python infinilm.server.inference_server [--cpu | --nvidia | --qy | --metax | --moore | --iluvatar | --cambricon | --ali] "
"--model_path=<path/to/model_dir> --max_tokens=MAX_TOKENS --max_batch_size=MAX_BATCH_SIZE"
"\n"
"Example: python infinilm.server.inference_server --nvidia --model_path=/data/shared/models/9G7B_MHA/ "
"--max_tokens=100 --max_batch_size=32 --tp=1 --temperature=1.0 --top_p=0.8 --top_k=1"
"\n"
"Optional: --enable-paged-attn --enable-graph"
)
sys.exit(1)
......@@ -526,15 +648,18 @@ def main():
device=device,
dtype=args.dtype,
tensor_parallel_size=args.tp,
cache_type=args.cache_type,
max_tokens=args.max_tokens,
max_batch_size=args.max_batch_size,
num_blocks=args.num_blocks,
block_size=args.block_size,
max_cache_len=args.max_cache_len,
temperature=args.temperature,
top_p=args.top_p,
top_k=args.top_k,
host=args.host,
port=args.port,
enable_graph=args.enable_graph,
)
server.start()
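For illustration, a launch line following the usage string printed earlier in main() (model path and sizes are placeholders):

python infinilm.server.inference_server --nvidia --model_path=/path/to/model_dir --cache_type=static --max_cache_len=4096 --enable-graph --port=8000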
......
......@@ -860,9 +860,11 @@ def test():
device_type = DeviceType.DEVICE_TYPE_KUNLUN
elif sys.argv[1] == "--hygon":
device_type = DeviceType.DEVICE_TYPE_HYGON
elif sys.argv[1] == "--ali":
device_type = DeviceType.DEVICE_TYPE_ALI
else:
print(
"Usage: python jiuge.py [--cpu | --nvidia| --qy| --cambricon | --ascend | --metax | --moore | --iluvatar | --kunlun | --hygon] <path/to/model_dir> [n_device] [--verbose]"
"Usage: python jiuge.py [--cpu | --nvidia| --qy| --cambricon | --ascend | --metax | --moore | --iluvatar | --kunlun | --hygon | --ali] <path/to/model_dir> [n_device] [--verbose]"
)
sys.exit(1)
......
......@@ -37,6 +37,7 @@ class DeviceType(ctypes.c_int):
DEVICE_TYPE_KUNLUN = 7
DEVICE_TYPE_HYGON = 8
DEVICE_TYPE_QY = 9
DEVICE_TYPE_ALI = 10
class KVCacheCStruct(ctypes.Structure):
......
......@@ -81,7 +81,6 @@ std::shared_ptr<Tensor> Loader::get(const std::string &name, int rank) {
__C void
loadModelWeight(struct ModelWeights *weights_, const char *name, void *data) {
std::string name_str(name);
// std::cout << "Loading weight: " << name_str << std::endl;
auto weights = reinterpret_cast<infinicore::weights::Loader *>(weights_);
weights->load(name_str, data);
}
import sys
import os
import argparse
import time
import re
import csv
......@@ -8,9 +7,8 @@ import numpy as np
import infinicore
from infinilm.modeling_utils import load_model_state_dict_by_file
from infinilm.distributed import DistConfig
from infinilm.cache import StaticKVCacheConfig
from infinilm.cache import StaticKVCacheConfig, PagedKVCacheConfig
from infinilm.infer_engine import GenerationConfig, InferEngine
from infinilm.cache import StaticKVCacheConfig
from datasets import load_dataset, Dataset
from abc import ABC, abstractmethod
......@@ -56,6 +54,7 @@ class InfiniLMBenchmark(BaseBenchmark):
ndev=1,
backend="cpp",
benchmark="ceval",
enable_paged_attn=False,
):
import transformers
......@@ -124,7 +123,9 @@ class InfiniLMBenchmark(BaseBenchmark):
model_dir_path,
device=self.device,
distributed_config=DistConfig(ndev),
cache_config=StaticKVCacheConfig(),
cache_config=(
PagedKVCacheConfig(128) if enable_paged_attn else StaticKVCacheConfig()
),
)
# Enable KV cache for generation
......@@ -673,6 +674,7 @@ def test():
max_new_tokens = 500
output_csv = None
cache_dir = None
enable_paged_attn = False
i = 3
while i < len(sys.argv):
......@@ -703,6 +705,9 @@ def test():
elif sys.argv[i] == "--cache_dir" and i + 1 < len(sys.argv):
cache_dir = sys.argv[i + 1]
i += 2
elif sys.argv[i] == "--enable_paged_attn":
enable_paged_attn = True
i += 1
else:
i += 1
......@@ -757,16 +762,13 @@ def test():
subject_list = ["all"]
# Create model based on backend (create once, reuse for all subjects)
if backend != "010":
if backend == "torch":
model = TorchBenchmark(model_path, device_type_str, benchmark)
else:
model = InfiniLMBenchmark(
model_path, device_type_str, ndev, backend, benchmark
)
if backend == "torch":
model = TorchBenchmark(model_path, device_type_str, benchmark)
else:
print(f"test 010 backend by scripts/test_ceval.py")
exit(0)
model = InfiniLMBenchmark(
model_path, device_type_str, ndev, backend, benchmark, enable_paged_attn
)
# Define helper functions for loading datasets
if benchmark == "ceval":
......
Subproject commit 5ed07097faa6c50199c4a3b66e5ed37d4fbfccc2
......@@ -6,6 +6,7 @@ set_toolchains("gcc")
-- Add spdlog from third_party directory
add_includedirs("third_party/spdlog/include")
add_includedirs("third_party/json/single_include/")
target("infinicore_infer")
set_kind("shared")
......