Commit 69f18760 authored by wooway777
Browse files

issue/204 - support graph in server scripts

parent 693d74d3
...@@ -50,6 +50,7 @@ class EngineConfig: ...@@ -50,6 +50,7 @@ class EngineConfig:
temperature: Default sampling temperature. temperature: Default sampling temperature.
top_p: Default top-p sampling parameter. top_p: Default top-p sampling parameter.
top_k: Default top-k sampling parameter. top_k: Default top-k sampling parameter.
enable_graph: Whether to enable graph compiling.
""" """
model_path: str model_path: str
...@@ -63,6 +64,7 @@ class EngineConfig: ...@@ -63,6 +64,7 @@ class EngineConfig:
temperature: float = 1.0 temperature: float = 1.0
top_p: float = 0.8 top_p: float = 0.8
top_k: int = 1 top_k: int = 1
enable_graph: bool = False
class LLMEngine: class LLMEngine:
...@@ -74,11 +76,18 @@ class LLMEngine: ...@@ -74,11 +76,18 @@ class LLMEngine:
# Initialize device and dtype # Initialize device and dtype
self._init_device() self._init_device()
# Initialize KV cache
cache_config = PagedKVCacheConfig(
num_blocks=config.num_blocks, block_size=config.block_size
)
# Initialize model engine # Initialize model engine
self.model_engine = InferEngine( self.model_engine = InferEngine(
model_path=config.model_path, model_path=config.model_path,
device=self.device, device=self.device,
distributed_config=DistConfig(config.tensor_parallel_size), distributed_config=DistConfig(config.tensor_parallel_size),
cache_config=cache_config,
enable_graph_compiling=config.enable_graph,
) )
# Load model weights # Load model weights
...@@ -92,12 +101,6 @@ class LLMEngine: ...@@ -92,12 +101,6 @@ class LLMEngine:
) )
self._fix_tokenizer_decoder() self._fix_tokenizer_decoder()
# Initialize KV cache
cache_config = PagedKVCacheConfig(
num_blocks=config.num_blocks, block_size=config.block_size
)
self.model_engine.reset_cache(cache_config)
# Initialize scheduler # Initialize scheduler
self.scheduler = Scheduler( self.scheduler = Scheduler(
max_batch_size=config.max_batch_size, max_batch_size=config.max_batch_size,
...@@ -113,6 +116,7 @@ class LLMEngine: ...@@ -113,6 +116,7 @@ class LLMEngine:
logger.info( logger.info(
f"LLMEngine initialized with model at {config.model_path} " f"LLMEngine initialized with model at {config.model_path} "
f"on device {config.device}" f"on device {config.device}"
f"enable_graph={config.enable_graph}"
) )
def _init_device(self): def _init_device(self):
...@@ -252,20 +256,22 @@ class LLMEngine: ...@@ -252,20 +256,22 @@ class LLMEngine:
for stop_str in stop_strings: for stop_str in stop_strings:
if decoded_text.endswith(stop_str): if decoded_text.endswith(stop_str):
# Remove the stop string from the end # Remove the stop string from the end
decoded_text = decoded_text[:-len(stop_str)] decoded_text = decoded_text[: -len(stop_str)]
req.generated_text = decoded_text req.generated_text = decoded_text
break break
holds_back_incomplete_utf8 = ( holds_back_incomplete_utf8 = bool(
bool(decoded_text) and decoded_text.endswith("\ufffd") decoded_text
) ) and decoded_text.endswith("\ufffd")
# vLLM-style: hold back only if we are not on the final chunk. # vLLM-style: hold back only if we are not on the final chunk.
# Suppress output when finish reason is LENGTH or STOP_STRING. # Suppress output when finish reason is LENGTH or STOP_STRING.
# Root cause fix: When STOP_STRING is detected, we suppress output for the token # Root cause fix: When STOP_STRING is detected, we suppress output for the token
# that completes the stop string, preventing additional tokens from being output. # that completes the stop string, preventing additional tokens from being output.
if (holds_back_incomplete_utf8 and not finished_now) or ( if (holds_back_incomplete_utf8 and not finished_now) or (
finished_now and req.finish_reason in (FinishReason.LENGTH, FinishReason.STOP_STRING) finished_now
and req.finish_reason
in (FinishReason.LENGTH, FinishReason.STOP_STRING)
): ):
token_text = "" token_text = ""
else: else:
...@@ -275,7 +281,9 @@ class LLMEngine: ...@@ -275,7 +281,9 @@ class LLMEngine:
req._stream_last_yielded_length = len(decoded_text) req._stream_last_yielded_length = len(decoded_text)
# For non-streaming, finish checks happen here. # For non-streaming, finish checks happen here.
if req._output_queue is None and self._check_request_finished(req, token_id): if req._output_queue is None and self._check_request_finished(
req, token_id
):
req.mark_finished(req.finish_reason) req.mark_finished(req.finish_reason)
# Remove stop string from generated_text if STOP_STRING finish reason # Remove stop string from generated_text if STOP_STRING finish reason
if req.finish_reason == FinishReason.STOP_STRING: if req.finish_reason == FinishReason.STOP_STRING:
...@@ -283,7 +291,7 @@ class LLMEngine: ...@@ -283,7 +291,7 @@ class LLMEngine:
for stop_str in stop_strings: for stop_str in stop_strings:
if req.generated_text.endswith(stop_str): if req.generated_text.endswith(stop_str):
# Remove the stop string from the end # Remove the stop string from the end
req.generated_text = req.generated_text[:-len(stop_str)] req.generated_text = req.generated_text[: -len(stop_str)]
break break
# Put output in queue if it exists (for async streaming) # Put output in queue if it exists (for async streaming)
...@@ -362,6 +370,7 @@ class LLM: ...@@ -362,6 +370,7 @@ class LLM:
temperature: float = 1.0, temperature: float = 1.0,
top_p: float = 0.8, top_p: float = 0.8,
top_k: int = 1, top_k: int = 1,
enable_graph: bool = False,
): ):
"""Initialize LLM. """Initialize LLM.
...@@ -377,6 +386,7 @@ class LLM: ...@@ -377,6 +386,7 @@ class LLM:
temperature: Default sampling temperature. temperature: Default sampling temperature.
top_p: Default top-p sampling parameter. top_p: Default top-p sampling parameter.
top_k: Default top-k sampling parameter. top_k: Default top-k sampling parameter.
enable_graph: Whether to enable graph compiling.
""" """
config = EngineConfig( config = EngineConfig(
model_path=model_path, model_path=model_path,
...@@ -390,6 +400,7 @@ class LLM: ...@@ -390,6 +400,7 @@ class LLM:
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
top_k=top_k, top_k=top_k,
enable_graph=enable_graph,
) )
self.engine = LLMEngine(config) self.engine = LLMEngine(config)
self.config = config self.config = config
...@@ -506,6 +517,7 @@ class AsyncLLMEngine: ...@@ -506,6 +517,7 @@ class AsyncLLMEngine:
temperature: float = 1.0, temperature: float = 1.0,
top_p: float = 0.8, top_p: float = 0.8,
top_k: int = 1, top_k: int = 1,
enable_graph: bool = False,
): ):
"""Initialize AsyncLLMEngine. """Initialize AsyncLLMEngine.
...@@ -521,6 +533,7 @@ class AsyncLLMEngine: ...@@ -521,6 +533,7 @@ class AsyncLLMEngine:
temperature: Default sampling temperature. temperature: Default sampling temperature.
top_p: Default top-p sampling parameter. top_p: Default top-p sampling parameter.
top_k: Default top-k sampling parameter. top_k: Default top-k sampling parameter.
enable_graph: Whether to enable graph compiling.
""" """
config = EngineConfig( config = EngineConfig(
model_path=model_path, model_path=model_path,
...@@ -534,6 +547,7 @@ class AsyncLLMEngine: ...@@ -534,6 +547,7 @@ class AsyncLLMEngine:
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
top_k=top_k, top_k=top_k,
enable_graph=enable_graph,
) )
self.engine = LLMEngine(config) self.engine = LLMEngine(config)
self.config = config self.config = config
......
...@@ -23,7 +23,9 @@ DEFAULT_STREAM_TIMEOUT = 100.0 ...@@ -23,7 +23,9 @@ DEFAULT_STREAM_TIMEOUT = 100.0
DEFAULT_REQUEST_TIMEOUT = 1000.0 DEFAULT_REQUEST_TIMEOUT = 1000.0
def chunk_json(id_, content=None, role=None, finish_reason=None, model: str = "unknown"): def chunk_json(
id_, content=None, role=None, finish_reason=None, model: str = "unknown"
):
"""Generate JSON chunk for streaming response.""" """Generate JSON chunk for streaming response."""
delta = {} delta = {}
if content: if content:
...@@ -66,6 +68,7 @@ class InferenceServer: ...@@ -66,6 +68,7 @@ class InferenceServer:
top_k: int = 1, top_k: int = 1,
host: str = "0.0.0.0", host: str = "0.0.0.0",
port: int = 8000, port: int = 8000,
enable_graph: bool = False,
): ):
"""Initialize inference server. """Initialize inference server.
...@@ -83,6 +86,7 @@ class InferenceServer: ...@@ -83,6 +86,7 @@ class InferenceServer:
top_k: Default top-k sampling parameter. top_k: Default top-k sampling parameter.
host: Server host address. host: Server host address.
port: Server port number. port: Server port number.
enable_graph: Whether to enable graph compiling.
""" """
self.model_path = model_path self.model_path = model_path
# vLLM-like served model id: directory name of model_path # vLLM-like served model id: directory name of model_path
...@@ -99,6 +103,7 @@ class InferenceServer: ...@@ -99,6 +103,7 @@ class InferenceServer:
self.top_k = top_k self.top_k = top_k
self.host = host self.host = host
self.port = port self.port = port
self.enable_graph = enable_graph
self.engine: AsyncLLMEngine = None self.engine: AsyncLLMEngine = None
...@@ -126,9 +131,11 @@ class InferenceServer: ...@@ -126,9 +131,11 @@ class InferenceServer:
temperature=self.temperature, temperature=self.temperature,
top_p=self.top_p, top_p=self.top_p,
top_k=self.top_k, top_k=self.top_k,
enable_graph=self.enable_graph,
) )
self.engine.start() self.engine.start()
logger.info(f"Engine initialized with model at {self.model_path}") logger.info(f"Engine initialized with model at {self.model_path}")
logger.info(f" enable_graph: {self.enable_graph}")
yield yield
self.engine.stop() self.engine.stop()
...@@ -233,7 +240,6 @@ class InferenceServer: ...@@ -233,7 +240,6 @@ class InferenceServer:
if isinstance(stop, str): if isinstance(stop, str):
stop = [stop] stop = [stop]
return SamplingParams( return SamplingParams(
temperature=float(pick("temperature", self.temperature)), temperature=float(pick("temperature", self.temperature)),
top_p=float(pick("top_p", self.top_p)), top_p=float(pick("top_p", self.top_p)),
...@@ -291,15 +297,15 @@ class InferenceServer: ...@@ -291,15 +297,15 @@ class InferenceServer:
# Skip EOS token text for OpenAI API compatibility # Skip EOS token text for OpenAI API compatibility
# Check if this token is an EOS token by comparing token_id with eos_token_ids # Check if this token is an EOS token by comparing token_id with eos_token_ids
eos_token_ids = self.engine.engine.eos_token_ids eos_token_ids = self.engine.engine.eos_token_ids
is_eos_token = ( is_eos_token = eos_token_ids and token_output.token_id in eos_token_ids
eos_token_ids and token_output.token_id in eos_token_ids
)
if not is_eos_token and token_output.token_text: if not is_eos_token and token_output.token_text:
# Send token # Send token
chunk = json.dumps( chunk = json.dumps(
chunk_json( chunk_json(
request_id, content=token_output.token_text, model=self.model_id request_id,
content=token_output.token_text,
model=self.model_id,
), ),
ensure_ascii=False, ensure_ascii=False,
) )
...@@ -379,9 +385,7 @@ class InferenceServer: ...@@ -379,9 +385,7 @@ class InferenceServer:
# Skip EOS token text for OpenAI API compatibility # Skip EOS token text for OpenAI API compatibility
# Check if this token is an EOS token by comparing token_id with eos_token_ids # Check if this token is an EOS token by comparing token_id with eos_token_ids
eos_token_ids = self.engine.engine.eos_token_ids eos_token_ids = self.engine.engine.eos_token_ids
is_eos_token = ( is_eos_token = eos_token_ids and token_output.token_id in eos_token_ids
eos_token_ids and token_output.token_id in eos_token_ids
)
if not is_eos_token: if not is_eos_token:
output_text += token_output.token_text output_text += token_output.token_text
...@@ -483,6 +487,11 @@ def parse_args(): ...@@ -483,6 +487,11 @@ def parse_args():
parser.add_argument("--moore", action="store_true", help="Use Moore device") parser.add_argument("--moore", action="store_true", help="Use Moore device")
parser.add_argument("--iluvatar", action="store_true", help="Use Iluvatar device") parser.add_argument("--iluvatar", action="store_true", help="Use Iluvatar device")
parser.add_argument("--cambricon", action="store_true", help="Use Cambricon device") parser.add_argument("--cambricon", action="store_true", help="Use Cambricon device")
parser.add_argument(
"--enable-graph",
action="store_true",
help="Enable graph compiling",
)
parser.add_argument( parser.add_argument(
"--log_level", "--log_level",
type=str, type=str,
...@@ -518,6 +527,8 @@ def main(): ...@@ -518,6 +527,8 @@ def main():
"\n" "\n"
"Example: python infinilm.server.inference_server --nvidia --model_path=/data/shared/models/9G7B_MHA/ " "Example: python infinilm.server.inference_server --nvidia --model_path=/data/shared/models/9G7B_MHA/ "
"--max_tokens=100 --max_batch_size=32 --tp=1 --temperature=1.0 --top_p=0.8 --top_k=1" "--max_tokens=100 --max_batch_size=32 --tp=1 --temperature=1.0 --top_p=0.8 --top_k=1"
"\n"
"Optional: --enable-paged-attn --enable-graph"
) )
sys.exit(1) sys.exit(1)
...@@ -535,6 +546,7 @@ def main(): ...@@ -535,6 +546,7 @@ def main():
top_k=args.top_k, top_k=args.top_k,
host=args.host, host=args.host,
port=args.port, port=args.port,
enable_graph=args.enable_graph,
) )
server.start() server.start()
......
Markdown is supported
Attach a file by drag &amp; drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment