Unverified Commit 5623826f authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

[Minor] Improve logging and rename the health check endpoint name (#1180)

parent 83e23c69
......@@ -21,7 +21,6 @@ Each data parallel worker can manage multiple tensor parallel workers.
import dataclasses
import logging
import multiprocessing
import os
from enum import Enum, auto
import numpy as np
......
......@@ -17,7 +17,6 @@ limitations under the License.
import logging
import multiprocessing
import os
from typing import List
import zmq
......
......@@ -39,6 +39,8 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@dataclasses.dataclass
class DecodeStatus:
"""Store the status of incremental decoding."""
vid: int
decoded_text: str
decode_ids: List[int]
......@@ -47,6 +49,8 @@ class DecodeStatus:
class DetokenizerManager:
"""DetokenizerManager is a process that detokenizes the token ids."""
def __init__(
self,
server_args: ServerArgs,
......
......@@ -62,12 +62,16 @@ logger = logging.getLogger(__name__)
@dataclasses.dataclass
class ReqState:
    """Store the state of a request."""

    # Buffered outputs awaiting consumption; cleared after they are flushed
    # to the caller (see the streaming loop in TokenizerManager).
    out_list: List
    # True once the request has finished; used to decide when to log and
    # finalize the stream.
    finished: bool
    # presumably set to wake the waiting coroutine when new output arrives
    # in out_list — TODO confirm against the producer side.
    event: asyncio.Event
class TokenizerManager:
"""TokenizerManager is a process that tokenizes the text."""
def __init__(
self,
server_args: ServerArgs,
......@@ -481,11 +485,7 @@ class TokenizerManager:
# Log requests
if self.server_args.log_requests and state.finished:
if obj.text is None:
in_obj = {"input_ids": obj.input_ids}
else:
in_obj = {"text": obj.text}
logger.info(f"in={in_obj}, out={out}")
logger.info(f"in={obj}, out={out}")
state.out_list = []
if state.finished:
......
......@@ -92,11 +92,15 @@ app = FastAPI()
tokenizer_manager = None
@app.get("/v1/health")
async def health(request: Request) -> Response:
"""
Generate 1 token to verify the health of the inference service.
"""
@app.get("/health")
async def health() -> Response:
    """Liveness check: the HTTP server is answering, so report 200 OK."""
    ok = Response(status_code=200)
    return ok
@app.get("/health_generate")
async def health_generate(request: Request) -> Response:
"""Check the health of the inference server by generating one token."""
gri = GenerateReqInput(
text="s", sampling_params={"max_new_tokens": 1, "temperature": 0.7}
)
......@@ -109,12 +113,6 @@ async def health(request: Request) -> Response:
return Response(status_code=503)
@app.get("/health")
async def health() -> Response:
    """Health check: return HTTP 200 when the server process is responsive."""
    return Response(status_code=200)
@app.get("/get_model_info")
async def get_model_info():
result = {
......
......@@ -422,13 +422,13 @@ class ServerArgs:
parser.add_argument(
"--enable-mla",
action="store_true",
help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2",
help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
)
parser.add_argument(
"--attention-reduce-in-fp32",
action="store_true",
help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
"This only affects Triton attention kernels",
"This only affects Triton attention kernels.",
)
parser.add_argument(
"--efficient-weight-load",
......@@ -452,15 +452,6 @@ class ServerArgs:
def url(self):
    """Build the server's base HTTP URL from its configured host and port."""
    return "http://{}:{}".format(self.host, self.port)
def print_mode_args(self):
    """Summarize the server's feature-toggle flags as one comma-separated string.

    Each entry is rendered as ``name=value``; the result keeps the original
    trailing ``", "`` separator.
    """
    flags = [
        ("disable_flashinfer", self.disable_flashinfer),
        ("attention_reduce_in_fp32", self.attention_reduce_in_fp32),
        ("disable_radix_cache", self.disable_radix_cache),
        ("disable_regex_jump_forward", self.disable_regex_jump_forward),
        ("disable_disk_cache", self.disable_disk_cache),
    ]
    return "".join("{}={}, ".format(name, value) for name, value in flags)
def check_server_args(self):
assert (
self.tp_size % self.nnodes == 0
......@@ -469,7 +460,7 @@ class ServerArgs:
self.dp_size > 1 and self.node_rank is not None
), "multi-node data parallel is not supported"
if "gemma-2" in self.model_path.lower():
logger.info(f"When using sliding window in gemma-2, turn on flashinfer.")
logger.info("When using sliding window in gemma-2, turn on flashinfer.")
self.disable_flashinfer = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment