Commit c0ae70c8 authored by Lianmin Zheng's avatar Lianmin Zheng
Browse files

Improve logging & fix litellm dependency. (#512)

parent 87260b7b
...@@ -9,9 +9,9 @@ ...@@ -9,9 +9,9 @@
SGLang is a structured generation language designed for large language models (LLMs). SGLang is a structured generation language designed for large language models (LLMs).
It makes your interaction with LLMs faster and more controllable by co-designing the frontend language and the runtime system. It makes your interaction with LLMs faster and more controllable by co-designing the frontend language and the runtime system.
The core features of SGLang include: The core features include:
- **A Flexible Front-End Language**: This allows for easy programming of LLM applications with multiple chained generation calls, advanced prompting techniques, control flow, multiple modalities, parallelism, and external interaction. - **A Flexible Front-End Language**: This allows for easy programming of LLM applications with multiple chained generation calls, advanced prompting techniques, control flow, multiple modalities, parallelism, and external interaction.
- **A High-Performance Runtime with RadixAttention**: This feature significantly accelerates the execution of complex LLM programs by automatic KV cache reuse across multiple calls. It also supports other common techniques like continuous batching and tensor parallelism. - **A High-Performance Runtime with RadixAttention**: This feature significantly accelerates the execution of complex LLM programs by automatically reusing the KV cache across multiple calls. It can also be used as a standalone serving engine with all common techniques implemented, such as continuous batching and tensor parallelism.
## News ## News
- [2024/02] 🔥 SGLang enables **3x faster JSON decoding** with compressed finite state machine ([blog](https://lmsys.org/blog/2024-02-05-compressed-fsm/)). - [2024/02] 🔥 SGLang enables **3x faster JSON decoding** with compressed finite state machine ([blog](https://lmsys.org/blog/2024-02-05-compressed-fsm/)).
......
...@@ -6,7 +6,7 @@ Achieving a large batch size is the most important thing for attaining high thro ...@@ -6,7 +6,7 @@ Achieving a large batch size is the most important thing for attaining high thro
When the server is running at full load, look for the following in the log: When the server is running at full load, look for the following in the log:
```[gpu_id=0] #running-req: 233, #token: 370959, token usage: 0.82, gen throughput (token/s): 4594.01, #queue-req: 417``` ```[gpu_id=0] Decode batch. #running-req: 233, #token: 370959, token usage: 0.82, gen throughput (token/s): 4594.01, #queue-req: 417```
### Tune Your Request Submission Speed ### Tune Your Request Submission Speed
`#queue-req` indicates the number of requests in the queue. If you frequently see `#queue-req == 0`, it suggests you are bottlenecked by the request submission speed. `#queue-req` indicates the number of requests in the queue. If you frequently see `#queue-req == 0`, it suggests you are bottlenecked by the request submission speed.
......
...@@ -9,6 +9,7 @@ try: ...@@ -9,6 +9,7 @@ try:
import litellm import litellm
except ImportError as e: except ImportError as e:
litellm = e litellm = e
litellm.num_retries = 1
class LiteLLM(BaseBackend): class LiteLLM(BaseBackend):
......
...@@ -111,7 +111,10 @@ class ModelTpServer: ...@@ -111,7 +111,10 @@ class ModelTpServer:
f"context_len={self.model_config.context_len}, " f"context_len={self.model_config.context_len}, "
) )
if self.tp_rank == 0: if self.tp_rank == 0:
logger.info(f"server_args: {server_args.print_mode_args()}") logger.info(
f"[gpu_id={self.gpu_id}] "
f"server_args: {server_args.print_mode_args()}"
)
# Init cache # Init cache
self.tree_cache = RadixCache( self.tree_cache = RadixCache(
...@@ -226,7 +229,7 @@ class ModelTpServer: ...@@ -226,7 +229,7 @@ class ModelTpServer:
self.num_generated_tokens = 0 self.num_generated_tokens = 0
self.last_stats_tic = time.time() self.last_stats_tic = time.time()
logger.info( logger.info(
f"[gpu_id={self.gpu_id}] " f"[gpu_id={self.gpu_id}] Decode batch. "
f"#running-req: {len(self.running_batch.reqs)}, " f"#running-req: {len(self.running_batch.reqs)}, "
f"#token: {num_used}, " f"#token: {num_used}, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, " f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
...@@ -397,12 +400,13 @@ class ModelTpServer: ...@@ -397,12 +400,13 @@ class ModelTpServer:
self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"] self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
) )
logger.info( logger.info(
f"new fill batch. #seq: {len(can_run_list)}. " f"[gpu_id={self.gpu_id}] Prefil batch. "
f"#cached_token: {hit_tokens}. " f"#new-seq: {len(can_run_list)}, "
f"#new_token: {new_batch_input_tokens}. " f"#new-token: {new_batch_input_tokens}, "
f"#remaining_req: {len(self.forward_queue) - len(can_run_list)}. " f"#cached-token: {hit_tokens}, "
f"#running_req: {running_req}. " f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%. " f"#running-req: {running_req}, "
f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
) )
# logger.debug( # logger.debug(
# f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. " # f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment