Unverified Commit 5623826f authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

[Minor] Improve logging and rename the health check endpoint name (#1180)

parent 83e23c69
...@@ -21,7 +21,6 @@ Each data parallel worker can manage multiple tensor parallel workers. ...@@ -21,7 +21,6 @@ Each data parallel worker can manage multiple tensor parallel workers.
import dataclasses import dataclasses
import logging import logging
import multiprocessing import multiprocessing
import os
from enum import Enum, auto from enum import Enum, auto
import numpy as np import numpy as np
......
...@@ -17,7 +17,6 @@ limitations under the License. ...@@ -17,7 +17,6 @@ limitations under the License.
import logging import logging
import multiprocessing import multiprocessing
import os
from typing import List from typing import List
import zmq import zmq
......
...@@ -39,6 +39,8 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) ...@@ -39,6 +39,8 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@dataclasses.dataclass @dataclasses.dataclass
class DecodeStatus: class DecodeStatus:
"""Store the status of incremental decoding."""
vid: int vid: int
decoded_text: str decoded_text: str
decode_ids: List[int] decode_ids: List[int]
...@@ -47,6 +49,8 @@ class DecodeStatus: ...@@ -47,6 +49,8 @@ class DecodeStatus:
class DetokenizerManager: class DetokenizerManager:
"""DetokenizerManager is a process that detokenizes the token ids."""
def __init__( def __init__(
self, self,
server_args: ServerArgs, server_args: ServerArgs,
......
...@@ -62,12 +62,16 @@ logger = logging.getLogger(__name__) ...@@ -62,12 +62,16 @@ logger = logging.getLogger(__name__)
@dataclasses.dataclass @dataclasses.dataclass
class ReqState: class ReqState:
"""Store the state a request."""
out_list: List out_list: List
finished: bool finished: bool
event: asyncio.Event event: asyncio.Event
class TokenizerManager: class TokenizerManager:
"""TokenizerManager is a process that tokenizes the text."""
def __init__( def __init__(
self, self,
server_args: ServerArgs, server_args: ServerArgs,
...@@ -481,11 +485,7 @@ class TokenizerManager: ...@@ -481,11 +485,7 @@ class TokenizerManager:
# Log requests # Log requests
if self.server_args.log_requests and state.finished: if self.server_args.log_requests and state.finished:
if obj.text is None: logger.info(f"in={obj}, out={out}")
in_obj = {"input_ids": obj.input_ids}
else:
in_obj = {"text": obj.text}
logger.info(f"in={in_obj}, out={out}")
state.out_list = [] state.out_list = []
if state.finished: if state.finished:
......
...@@ -92,11 +92,15 @@ app = FastAPI() ...@@ -92,11 +92,15 @@ app = FastAPI()
tokenizer_manager = None tokenizer_manager = None
@app.get("/v1/health") @app.get("/health")
async def health(request: Request) -> Response: async def health() -> Response:
""" """Check the health of the http server."""
Generate 1 token to verify the health of the inference service. return Response(status_code=200)
"""
@app.get("/health_generate")
async def health_generate(request: Request) -> Response:
"""Check the health of the inference server by generating one token."""
gri = GenerateReqInput( gri = GenerateReqInput(
text="s", sampling_params={"max_new_tokens": 1, "temperature": 0.7} text="s", sampling_params={"max_new_tokens": 1, "temperature": 0.7}
) )
...@@ -109,12 +113,6 @@ async def health(request: Request) -> Response: ...@@ -109,12 +113,6 @@ async def health(request: Request) -> Response:
return Response(status_code=503) return Response(status_code=503)
@app.get("/health")
async def health() -> Response:
"""Health check."""
return Response(status_code=200)
@app.get("/get_model_info") @app.get("/get_model_info")
async def get_model_info(): async def get_model_info():
result = { result = {
......
...@@ -422,13 +422,13 @@ class ServerArgs: ...@@ -422,13 +422,13 @@ class ServerArgs:
parser.add_argument( parser.add_argument(
"--enable-mla", "--enable-mla",
action="store_true", action="store_true",
help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2", help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
) )
parser.add_argument( parser.add_argument(
"--attention-reduce-in-fp32", "--attention-reduce-in-fp32",
action="store_true", action="store_true",
help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16." help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
"This only affects Triton attention kernels", "This only affects Triton attention kernels.",
) )
parser.add_argument( parser.add_argument(
"--efficient-weight-load", "--efficient-weight-load",
...@@ -452,15 +452,6 @@ class ServerArgs: ...@@ -452,15 +452,6 @@ class ServerArgs:
def url(self): def url(self):
return f"http://{self.host}:{self.port}" return f"http://{self.host}:{self.port}"
def print_mode_args(self):
return (
f"disable_flashinfer={self.disable_flashinfer}, "
f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, "
f"disable_radix_cache={self.disable_radix_cache}, "
f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
f"disable_disk_cache={self.disable_disk_cache}, "
)
def check_server_args(self): def check_server_args(self):
assert ( assert (
self.tp_size % self.nnodes == 0 self.tp_size % self.nnodes == 0
...@@ -469,7 +460,7 @@ class ServerArgs: ...@@ -469,7 +460,7 @@ class ServerArgs:
self.dp_size > 1 and self.node_rank is not None self.dp_size > 1 and self.node_rank is not None
), "multi-node data parallel is not supported" ), "multi-node data parallel is not supported"
if "gemma-2" in self.model_path.lower(): if "gemma-2" in self.model_path.lower():
logger.info(f"When using sliding window in gemma-2, turn on flashinfer.") logger.info("When using sliding window in gemma-2, turn on flashinfer.")
self.disable_flashinfer = False self.disable_flashinfer = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment