Unverified commit 62b3812b, authored by Liangsheng Yin, committed by GitHub

Time cost utils (#355)

parent 550a4f78
@@ -9,9 +9,8 @@ from sglang.lang.interpreter import StreamExecutor
 from sglang.lang.ir import SglSamplingParams

 try:
-    import tiktoken
     import openai
+    import tiktoken
 except ImportError as e:
     openai = tiktoken = e
...
@@ -7,6 +7,7 @@ class FSMCache(BaseCache):
         super().__init__(enable=enable)

         from importlib.metadata import version

         if version("outlines") >= "0.0.35":
             from transformers import AutoTokenizer
...
@@ -53,7 +53,7 @@ from sglang.srt.managers.openai_protocol import (
 from sglang.srt.managers.router.manager import start_router_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import handle_port_init
+from sglang.srt.utils import enable_show_time_cost, handle_port_init
 from starlette.middleware.base import BaseHTTPMiddleware
 from starlette.responses import JSONResponse
@@ -503,6 +503,10 @@ def launch_server(server_args, pipe_finish_writer):
     global tokenizer_manager
     global chat_template_name

+    # start show time thread
+    if server_args.show_time_cost:
+        enable_show_time_cost()
+
     # disable disk cache if needed
     if server_args.disable_disk_cache:
         disable_cache()
...
@@ -26,13 +26,14 @@ class ServerArgs:
     disable_log_stats: bool = False
     log_stats_interval: int = 10
     log_level: str = "info"
+    api_key: str = ""
+    show_time_cost: bool = False

     # optional modes
     disable_radix_cache: bool = False
     enable_flashinfer: bool = False
     disable_regex_jump_forward: bool = False
     disable_disk_cache: bool = False
-    api_key: str = ""

     def __post_init__(self):
         if self.tokenizer_path is None:
@@ -181,6 +182,18 @@ class ServerArgs:
             default=ServerArgs.log_stats_interval,
             help="Log stats interval in second.",
         )
+        parser.add_argument(
+            "--api-key",
+            type=str,
+            default=ServerArgs.api_key,
+            help="Set API Key",
+        )
+        parser.add_argument(
+            "--show-time-cost",
+            action="store_true",
+            help="Show time cost of custom marks",
+        )

         # optional modes
         parser.add_argument(
             "--disable-radix-cache",
@@ -202,12 +215,6 @@ class ServerArgs:
             action="store_true",
             help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
         )
-        parser.add_argument(
-            "--api-key",
-            type=str,
-            default=ServerArgs.api_key,
-            help="Set API Key",
-        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
...
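For context, here is a rough sketch of how the relocated --api-key flag and the new --show-time-cost flag would flow from the command line into ServerArgs. Only from_cli_args appears in the diff above; the add_cli_args helper name and the --model-path placeholder are assumptions for illustration.

import argparse

from sglang.srt.server_args import ServerArgs

parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)  # assumed name of the method holding the add_argument calls above
args = parser.parse_args(
    ["--model-path", "<model>", "--show-time-cost", "--api-key", "my-secret"]
)
server_args = ServerArgs.from_cli_args(args)  # classmethod shown in the diff
assert server_args.show_time_cost is True
assert server_args.api_key == "my-secret"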
@@ -11,48 +11,56 @@ from typing import List, Optional
 import numpy as np
 import requests
 import torch
-import torch.distributed as dist

-is_show_cost_time = False
+show_time_cost = False
+time_infos = {}


-def mark_cost_time(func_name):
-    def inner_func(func):
-        def time_func(*args, **kwargs):
-            if dist.get_rank() in [0, 1] and is_show_cost_time:
-                torch.cuda.synchronize()
-                start_time = time.time()
-                ans = func(*args, **kwargs)
-                torch.cuda.synchronize()
-                print(func_name, "cost time:", (time.time() - start_time) * 1000)
-                return ans
-            else:
-                torch.cuda.synchronize()
-                ans = func(*args, **kwargs)
-                torch.cuda.synchronize()
-                return ans
-
-        return time_func
-
-    return inner_func
-
-
-time_mark = {}
+def enable_show_time_cost():
+    global show_time_cost
+    show_time_cost = True
+
+
+class TimeInfo:
+    def __init__(self, name, interval=0.1, color=0, indent=0):
+        self.name = name
+        self.interval = interval
+        self.color = color
+        self.indent = indent
+
+        self.acc_time = 0
+        self.last_acc_time = 0
+
+    def check(self):
+        if self.acc_time - self.last_acc_time > self.interval:
+            self.last_acc_time = self.acc_time
+            return True
+        return False
+
+    def pretty_print(self):
+        print(f"\x1b[{self.color}m", end="")
+        print("-" * self.indent * 2, end="")
+        print(f"{self.name}: {self.acc_time:.3f}s\x1b[0m")


-def mark_start(key):
+def mark_start(name, interval=0.1, color=0, indent=0):
+    global time_infos, show_time_cost
+    if not show_time_cost:
+        return
     torch.cuda.synchronize()
-    global time_mark
-    time_mark[key] = time.time()
-    return
+    if time_infos.get(name, None) is None:
+        time_infos[name] = TimeInfo(name, interval, color, indent)
+    time_infos[name].acc_time -= time.time()


-def mark_end(key, print_min_cost=0.0):
+def mark_end(name):
+    global time_infos, show_time_cost
+    if not show_time_cost:
+        return
     torch.cuda.synchronize()
-    global time_mark
-    cost_time = (time.time() - time_mark[key]) * 1000
-    if cost_time > print_min_cost:
-        print(f"cost {key}:", cost_time)
+    time_infos[name].acc_time += time.time()
+    if time_infos[name].check():
+        time_infos[name].pretty_print()


 def calculate_time(show=False, min_cost_ms=0.0):
...
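As a usage illustration, here is a minimal sketch of the rewritten timing helpers. It assumes a CUDA device is available, since mark_start/mark_end call torch.cuda.synchronize(); the "sample_step" label, the loop, and the sleep are invented for the example and stand in for real model-runner code guarded by --show-time-cost.

import time

from sglang.srt.utils import enable_show_time_cost, mark_end, mark_start

enable_show_time_cost()  # same switch launch_server flips for --show-time-cost

for _ in range(5):
    mark_start("sample_step", interval=0.1, color=31, indent=1)  # hypothetical mark name
    time.sleep(0.05)  # stand-in for the region being measured
    mark_end("sample_step")

Unlike the removed mark_cost_time decorator, which printed on every call and depended on the torch.distributed rank, the new helpers accumulate wall time per named mark and pretty-print (with the chosen color and indent) only once another `interval` seconds of measured time has been added.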
@@ -66,9 +66,9 @@ class BenchBatch:
             p_idx = prefix_req_idx[i // fork_num].item()
             n_idx = self.req_pool_indices[i].item()
             req_to_token[n_idx, :prefix_len] = req_to_token[p_idx, :prefix_len]
-            req_to_token[
-                n_idx, prefix_len : prefix_len + extend_len
-            ] = self.out_cache_loc[i * extend_len : (i + 1) * extend_len]
+            req_to_token[n_idx, prefix_len : prefix_len + extend_len] = (
+                self.out_cache_loc[i * extend_len : (i + 1) * extend_len]
+            )

     def update_decode(self, predict_ids, batch_size):
         assert predict_ids.shape[0] == batch_size
@@ -81,9 +81,9 @@ class BenchBatch:
             self.out_cache_cont_start,
             self.out_cache_cont_end,
         ) = self.token_to_kv_pool.alloc_contiguous(batch_size)
-        self.req_to_token_pool.req_to_token[
-            self.req_pool_indices, self.seq_lens
-        ] = self.out_cache_loc
+        self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = (
+            self.out_cache_loc
+        )
         self.seq_lens.add_(1)
...
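The two hunks above only reflow an advanced-indexing assignment into the parenthesized style black produces; behavior is unchanged. A toy illustration of the indexing pattern, with all shapes and values invented for the example:

import torch

# [num request slots, max tokens per request]; zeros as empty placeholders
req_to_token = torch.zeros((4, 8), dtype=torch.int64)
# flat KV-cache slot indices; 12 slots = 3 extends of 4 tokens each
out_cache_loc = torch.arange(100, 112)

n_idx, prefix_len, extend_len, i = 2, 3, 4, 1
req_to_token[n_idx, prefix_len : prefix_len + extend_len] = (
    out_cache_loc[i * extend_len : (i + 1) * extend_len]
)
print(req_to_token[n_idx])
# tensor([  0,   0,   0, 104, 105, 106, 107,   0])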