Unverified Commit 2d346a57 authored by Lianmin Zheng, committed by GitHub

Fix padding in the cuda graph (#1469)

parent 446ea332
@@ -11,7 +11,7 @@
 --------------------------------------------------------------------------------
 
-| [**Blog**](https://lmsys.org/blog/2024-07-25-sglang-llama3/) | [**Paper**](https://arxiv.org/abs/2312.07104) | [**join Slack!**](https://join.slack.com/t/sgl-fru7574/shared_invite/zt-2ngly9muu-t37XiH87qvD~6rVBTkTEHw) | [**join Development Meeting!**](https://calendar.app.google/v2Tw3kuHkKYyp8VV7) |
+| [**Blog**](https://lmsys.org/blog/2024-07-25-sglang-llama3/) | [**Paper**](https://arxiv.org/abs/2312.07104) | [**Join Slack**](https://join.slack.com/t/sgl-fru7574/shared_invite/zt-2ngly9muu-t37XiH87qvD~6rVBTkTEHw) | [**Join Weekly Development Meeting**](https://calendar.app.google/v2Tw3kuHkKYyp8VV7) |
 
 SGLang is a fast serving framework for large language models and vision language models.
 It makes your interaction with models faster and more controllable by co-designing the backend runtime and frontend language.
@@ -20,7 +20,7 @@ The core features include:
 - **Fast Backend Runtime**: Provides efficient serving with RadixAttention for prefix caching, jump-forward constrained decoding, continuous batching, token attention (paged attention), tensor parallelism, FlashInfer kernels, chunked prefill, and quantization (INT4/FP8/AWQ/GPTQ).
 - **Flexible Frontend Language**: Offers an intuitive interface for programming LLM applications, including chained generation calls, advanced prompting, control flow, multi-modal inputs, parallelism, and external interactions.
 - **Extensive Model Support**: Supports a wide range of generative models (Llama 3, Gemma 2, Mistral, QWen, DeepSeek, LLaVA, etc.) and embedding models (e5-mistral), with easy extensibility for integrating new models.
-- **Active Community**: SGLang is open-source and backed by an active community with industry adoption, welcoming contributions to improve LLM and VLM serving.
+- **Active Community**: SGLang is open-source and backed by an active community with industry adoption.
 
 ## News
 - [2024/09] 🔥 SGLang v0.3 Release: 7x Faster DeepSeek MLA, 1.5x Faster torch.compile, Multi-Image/Video LLaVA-OneVision ([blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/)).
......
@@ -108,6 +108,10 @@ class CudaGraphRunner:
             self.capture_bs = list(range(1, 32)) + [64, 128]
         else:
             self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+        self.capture_bs = [
+            bs for bs in self.capture_bs if bs <= model_runner.req_to_token_pool.size
+        ]
+
         self.compile_bs = (
             [
                 bs
@@ -118,21 +122,8 @@ class CudaGraphRunner:
             else []
         )
 
-        # Common inputs
-        self.max_bs = max(self.capture_bs)
-        self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
-        self.req_pool_indices = torch.zeros(
-            (self.max_bs,), dtype=torch.int32, device="cuda"
-        )
-        self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
-        self.position_ids_offsets = torch.ones(
-            (self.max_bs,), dtype=torch.int32, device="cuda"
-        )
-        self.out_cache_loc = torch.zeros(
-            (self.max_bs,), dtype=torch.int32, device="cuda"
-        )
-
         # Attention backend
+        self.max_bs = max(self.capture_bs)
         self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
@@ -141,6 +132,16 @@ class CudaGraphRunner:
         if self.use_torch_compile:
             set_torch_compile_config()
 
+        # Common inputs
+        with torch.device("cuda"):
+            self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32)
+            self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
+            self.seq_lens = torch.full(
+                (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
+            )
+            self.position_ids_offsets = torch.ones((self.max_bs,), dtype=torch.int32)
+            self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)
+
         # Capture
         try:
             self.capture()
......
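For context on what the diff changes: CUDA graphs are captured for a fixed set of batch sizes, and a live batch is padded up to the nearest captured size at replay. The commit makes the pad slots of `seq_lens` start at the attention backend's `seq_len_fill_value` instead of an arbitrary 1 (by allocating the buffers only after that value is known), and drops any capture size larger than `req_to_token_pool.size`. Below is a minimal sketch of that padding pattern under those assumptions; `StaticBuffers`, `pad_and_fill`, and `pool_size` are illustrative names, not SGLang's actual API.

```python
import bisect

import torch


class StaticBuffers:
    """Static CUDA-graph input buffers, sized for the largest captured batch."""

    def __init__(self, capture_bs, pool_size, seq_len_fill_value):
        # Mirror the fix: only capture batch sizes the request pool can hold.
        self.capture_bs = sorted(bs for bs in capture_bs if bs <= pool_size)
        self.max_bs = max(self.capture_bs)
        self.seq_len_fill_value = seq_len_fill_value
        with torch.device("cuda"):
            self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32)
            self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
            # Pad slots start at the backend's "inert" sequence length, so a
            # graph replayed with padding never sees a bogus length of 1.
            self.seq_lens = torch.full(
                (self.max_bs,), seq_len_fill_value, dtype=torch.int32
            )
            self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)

    def pad_and_fill(self, input_ids, req_pool_indices, seq_lens, out_cache_loc):
        """Copy a raw batch into the static buffers, padded to a captured size."""
        raw_bs = input_ids.shape[0]
        assert raw_bs <= self.max_bs, "batch exceeds the largest captured graph"
        # Replay the smallest captured graph that fits this batch.
        padded_bs = self.capture_bs[bisect.bisect_left(self.capture_bs, raw_bs)]
        # Reset the pad region in case a previous, larger batch left stale values.
        self.seq_lens[raw_bs:padded_bs].fill_(self.seq_len_fill_value)
        self.input_ids[:raw_bs].copy_(input_ids)
        self.req_pool_indices[:raw_bs].copy_(req_pool_indices)
        self.seq_lens[:raw_bs].copy_(seq_lens)
        self.out_cache_loc[:raw_bs].copy_(out_cache_loc)
        return padded_bs
```

This also explains why the diff moves the buffer allocation below the attention-backend setup: `seq_len_fill_value` comes from `get_cuda_graph_seq_len_fill_value()`, so it has to be queried before `seq_lens` can be initialized with it.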