Commit 69f18760 authored by wooway777

issue/204 - support graph in server scripts

parent 693d74d3
......@@ -50,6 +50,7 @@ class EngineConfig:
temperature: Default sampling temperature.
top_p: Default top-p sampling parameter.
top_k: Default top-k sampling parameter.
enable_graph: Whether to enable graph compiling.
"""
model_path: str
......@@ -63,6 +64,7 @@ class EngineConfig:
temperature: float = 1.0
top_p: float = 0.8
top_k: int = 1
enable_graph: bool = False
class LLMEngine:
......@@ -74,11 +76,18 @@ class LLMEngine:
# Initialize device and dtype
self._init_device()
# Initialize KV cache
cache_config = PagedKVCacheConfig(
num_blocks=config.num_blocks, block_size=config.block_size
)
# Initialize model engine
self.model_engine = InferEngine(
model_path=config.model_path,
device=self.device,
distributed_config=DistConfig(config.tensor_parallel_size),
cache_config=cache_config,
enable_graph_compiling=config.enable_graph,
)
# Load model weights
......@@ -92,12 +101,6 @@ class LLMEngine:
)
self._fix_tokenizer_decoder()
# Initialize KV cache
cache_config = PagedKVCacheConfig(
num_blocks=config.num_blocks, block_size=config.block_size
)
self.model_engine.reset_cache(cache_config)
# Initialize scheduler
self.scheduler = Scheduler(
max_batch_size=config.max_batch_size,
......@@ -113,6 +116,7 @@ class LLMEngine:
logger.info(
f"LLMEngine initialized with model at {config.model_path} "
f"on device {config.device}"
f"enable_graph={config.enable_graph}"
)
def _init_device(self):
......@@ -252,20 +256,22 @@ class LLMEngine:
for stop_str in stop_strings:
if decoded_text.endswith(stop_str):
# Remove the stop string from the end
decoded_text = decoded_text[:-len(stop_str)]
decoded_text = decoded_text[: -len(stop_str)]
req.generated_text = decoded_text
break
holds_back_incomplete_utf8 = (
bool(decoded_text) and decoded_text.endswith("\ufffd")
)
holds_back_incomplete_utf8 = bool(
decoded_text
) and decoded_text.endswith("\ufffd")
# vLLM-style: hold back only if we are not on the final chunk.
# Suppress output when finish reason is LENGTH or STOP_STRING.
# Root cause fix: When STOP_STRING is detected, we suppress output for the token
# that completes the stop string, preventing additional tokens from being output.
if (holds_back_incomplete_utf8 and not finished_now) or (
finished_now and req.finish_reason in (FinishReason.LENGTH, FinishReason.STOP_STRING)
finished_now
and req.finish_reason
in (FinishReason.LENGTH, FinishReason.STOP_STRING)
):
token_text = ""
else:
......@@ -275,7 +281,9 @@ class LLMEngine:
req._stream_last_yielded_length = len(decoded_text)
# For non-streaming, finish checks happen here.
if req._output_queue is None and self._check_request_finished(req, token_id):
if req._output_queue is None and self._check_request_finished(
req, token_id
):
req.mark_finished(req.finish_reason)
# Remove stop string from generated_text if STOP_STRING finish reason
if req.finish_reason == FinishReason.STOP_STRING:
......@@ -283,7 +291,7 @@ class LLMEngine:
for stop_str in stop_strings:
if req.generated_text.endswith(stop_str):
# Remove the stop string from the end
req.generated_text = req.generated_text[:-len(stop_str)]
req.generated_text = req.generated_text[: -len(stop_str)]
break
# Put output in queue if it exists (for async streaming)
......@@ -362,6 +370,7 @@ class LLM:
temperature: float = 1.0,
top_p: float = 0.8,
top_k: int = 1,
enable_graph: bool = False,
):
"""Initialize LLM.
......@@ -377,6 +386,7 @@ class LLM:
temperature: Default sampling temperature.
top_p: Default top-p sampling parameter.
top_k: Default top-k sampling parameter.
enable_graph: Whether to enable graph compiling.
"""
config = EngineConfig(
model_path=model_path,
......@@ -390,6 +400,7 @@ class LLM:
temperature=temperature,
top_p=top_p,
top_k=top_k,
enable_graph=enable_graph,
)
self.engine = LLMEngine(config)
self.config = config
......@@ -506,6 +517,7 @@ class AsyncLLMEngine:
temperature: float = 1.0,
top_p: float = 0.8,
top_k: int = 1,
enable_graph: bool = False,
):
"""Initialize AsyncLLMEngine.
......@@ -521,6 +533,7 @@ class AsyncLLMEngine:
temperature: Default sampling temperature.
top_p: Default top-p sampling parameter.
top_k: Default top-k sampling parameter.
enable_graph: Whether to enable graph compiling.
"""
config = EngineConfig(
model_path=model_path,
......@@ -534,6 +547,7 @@ class AsyncLLMEngine:
temperature=temperature,
top_p=top_p,
top_k=top_k,
enable_graph=enable_graph,
)
self.engine = LLMEngine(config)
self.config = config
......
......@@ -23,7 +23,9 @@ DEFAULT_STREAM_TIMEOUT = 100.0
DEFAULT_REQUEST_TIMEOUT = 1000.0
def chunk_json(id_, content=None, role=None, finish_reason=None, model: str = "unknown"):
def chunk_json(
id_, content=None, role=None, finish_reason=None, model: str = "unknown"
):
"""Generate JSON chunk for streaming response."""
delta = {}
if content:
......@@ -66,6 +68,7 @@ class InferenceServer:
top_k: int = 1,
host: str = "0.0.0.0",
port: int = 8000,
enable_graph: bool = False,
):
"""Initialize inference server.
......@@ -83,6 +86,7 @@ class InferenceServer:
top_k: Default top-k sampling parameter.
host: Server host address.
port: Server port number.
enable_graph: Whether to enable graph compiling.
"""
self.model_path = model_path
# vLLM-like served model id: directory name of model_path
......@@ -99,6 +103,7 @@ class InferenceServer:
self.top_k = top_k
self.host = host
self.port = port
self.enable_graph = enable_graph
self.engine: AsyncLLMEngine = None
......@@ -126,9 +131,11 @@ class InferenceServer:
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
enable_graph=self.enable_graph,
)
self.engine.start()
logger.info(f"Engine initialized with model at {self.model_path}")
logger.info(f" enable_graph: {self.enable_graph}")
yield
self.engine.stop()
......@@ -233,7 +240,6 @@ class InferenceServer:
if isinstance(stop, str):
stop = [stop]
return SamplingParams(
temperature=float(pick("temperature", self.temperature)),
top_p=float(pick("top_p", self.top_p)),
......@@ -291,15 +297,15 @@ class InferenceServer:
# Skip EOS token text for OpenAI API compatibility
# Check if this token is an EOS token by comparing token_id with eos_token_ids
eos_token_ids = self.engine.engine.eos_token_ids
is_eos_token = (
eos_token_ids and token_output.token_id in eos_token_ids
)
is_eos_token = eos_token_ids and token_output.token_id in eos_token_ids
if not is_eos_token and token_output.token_text:
# Send token
chunk = json.dumps(
chunk_json(
request_id, content=token_output.token_text, model=self.model_id
request_id,
content=token_output.token_text,
model=self.model_id,
),
ensure_ascii=False,
)
......@@ -379,9 +385,7 @@ class InferenceServer:
# Skip EOS token text for OpenAI API compatibility
# Check if this token is an EOS token by comparing token_id with eos_token_ids
eos_token_ids = self.engine.engine.eos_token_ids
is_eos_token = (
eos_token_ids and token_output.token_id in eos_token_ids
)
is_eos_token = eos_token_ids and token_output.token_id in eos_token_ids
if not is_eos_token:
output_text += token_output.token_text
......@@ -483,6 +487,11 @@ def parse_args():
parser.add_argument("--moore", action="store_true", help="Use Moore device")
parser.add_argument("--iluvatar", action="store_true", help="Use Iluvatar device")
parser.add_argument("--cambricon", action="store_true", help="Use Cambricon device")
parser.add_argument(
"--enable-graph",
action="store_true",
help="Enable graph compiling",
)
parser.add_argument(
"--log_level",
type=str,
......@@ -518,6 +527,8 @@ def main():
"\n"
"Example: python infinilm.server.inference_server --nvidia --model_path=/data/shared/models/9G7B_MHA/ "
"--max_tokens=100 --max_batch_size=32 --tp=1 --temperature=1.0 --top_p=0.8 --top_k=1"
"\n"
"Optional: --enable-paged-attn --enable-graph"
)
sys.exit(1)
......@@ -535,6 +546,7 @@ def main():
top_k=args.top_k,
host=args.host,
port=args.port,
enable_graph=args.enable_graph,
)
server.start()
......
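
Note: below is a minimal usage sketch of the new enable_graph option, using only names introduced in this diff. The "infinilm" import path is inferred from the server example in main() and is an assumption, not a documented API; treat the snippet as illustrative.

    # Illustrative sketch -- the "infinilm" package/import path is assumed.
    from infinilm import LLM

    # enable_graph defaults to False, so existing callers keep the old behavior.
    # When set, it is forwarded to EngineConfig.enable_graph and then to
    # InferEngine(enable_graph_compiling=...), as shown in the diff above.
    llm = LLM(
        model_path="/data/shared/models/9G7B_MHA/",
        enable_graph=True,
    )

The same switch is exposed on the server entry point as --enable-graph, e.g.:

    python infinilm.server.inference_server --nvidia --model_path=/data/shared/models/9G7B_MHA/ --enable-graph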