Unverified Commit 46094e0c authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Deprecate --disable-flashinfer and introduce --attention-backend (#1380)

parent 3a6e8b6d
......@@ -139,7 +139,7 @@ sky status --endpoint 30000 sglang
### Common Notes
- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) is currently one of the dependencies that must be installed for SGLang. It only supports sm75 and above. If you encounter any FlashInfer-related issues on sm75+ devices (e.g., T4, A10, A100, L4, L40S, H100), consider using Triton's kernel by `--disable-flashinfer --disable-flashinfer-sampling` and raise an issue.
- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) is the default attention kernel backend. It only supports sm75 and above. If you encounter any FlashInfer-related issues on sm75+ devices (e.g., T4, A10, A100, L4, L40S, H100), please disable it by adding `--disable-flashinfer --disable-flashinfer-sampling` and open an issue on GitHub.
- If you only need to use the OpenAI backend, you can avoid installing other dependencies by using `pip install "sglang[openai]"`.
## Backend: SGLang Runtime (SRT)
......
......@@ -92,5 +92,5 @@ sky status --endpoint 30000 sglang
</details>
### Common Notes
- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) is currently one of the dependencies that must be installed for SGLang. It only supports sm75 and above. If you encounter any FlashInfer-related issues on sm75+ devices (e.g., T4, A10, A100, L4, L40S, H100), consider using Triton's kernel by `--disable-flashinfer --disable-flashinfer-sampling` and raise an issue.
- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) is the default attention kernel backend. It only supports sm75 and above. If you encounter any FlashInfer-related issues on sm75+ devices (e.g., T4, A10, A100, L4, L40S, H100), please disable it by adding `--disable-flashinfer --disable-flashinfer-sampling` and open an issue on GitHub.
- If you only need to use the OpenAI backend, you can avoid installing other dependencies by using `pip install "sglang[openai]"`.
......@@ -61,14 +61,18 @@ class RadixAttention(nn.Module):
# Choose backend
if (
not global_server_args_dict.get("disable_flashinfer", False)
global_server_args_dict["attention_backend"] == "flashinfer"
and self.qk_head_dim == self.v_head_dim
):
self.extend_forward = self.extend_forward_flashinfer
self.decode_forward = self.decode_forward_flashinfer
else:
elif global_server_args_dict["attention_backend"] == "triton":
self.extend_forward = self.extend_forward_triton
self.decode_forward = self.decode_forward_triton
else:
raise ValueError(
f"Invalid attention backend: {global_server_args_dict['attention_backend']}"
)
def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
if self.qk_head_dim != self.v_head_dim:
......
......@@ -78,7 +78,7 @@ class Sampler(CustomOp):
probs = self._get_probs(logits, sampling_info)
if not global_server_args_dict["disable_flashinfer_sampling"]:
if global_server_args_dict["sampling_backend"] == "flashinfer":
max_top_k_round, batch_size = 32, probs.shape[0]
uniform_samples = torch.rand(
(max_top_k_round, batch_size), device=probs.device
......@@ -93,11 +93,15 @@ class Sampler(CustomOp):
batch_next_token_ids, success = flashinfer_top_k_top_p(
probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps
)
else:
elif global_server_args_dict["sampling_backend"] == "pytorch":
# Here we provide a slower fallback implementation.
batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
)
else:
raise ValueError(
f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
)
return SampleOutput(success, probs, batch_next_token_ids)
......
......@@ -31,6 +31,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
if TYPE_CHECKING:
from sglang.srt.layers.sampler import SampleOutput
......@@ -40,10 +41,11 @@ INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
# Put some global args for easy access
global_server_args_dict = {
"disable_flashinfer": False,
"disable_flashinfer_sampling": False,
"triton_attention_reduce_in_fp32": False,
"enable_mla": False,
"attention_backend": ServerArgs.attention_backend,
"sampling_backend": ServerArgs.sampling_backend,
"triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
"enable_mla": ServerArgs.enable_mla,
"torchao_config": ServerArgs.torchao_config,
}
......
......@@ -128,7 +128,7 @@ class ModelTpServer:
if server_args.max_running_requests is None
else server_args.max_running_requests
),
self.model_runner.req_to_token_pool.size - 1,
self.model_runner.req_to_token_pool.size,
)
self.max_req_input_len = min(
self.model_config.context_len - 1,
......
......@@ -203,17 +203,17 @@ class InputMetadata:
ret.compute_extend_infos(batch)
fm = batch.forward_mode
if not fm.is_decode() or model_runner.server_args.disable_flashinfer:
if not fm.is_decode() or model_runner.server_args.attention_backend == "triton":
ret.total_num_tokens = int(torch.sum(ret.seq_lens))
if not fm.is_decode():
ret.init_multimuldal_info(batch)
if model_runner.server_args.disable_flashinfer:
if model_runner.server_args.attention_backend == "triton":
ret.init_triton_args(batch)
flashinfer_use_ragged = False
if not model_runner.server_args.disable_flashinfer:
if model_runner.server_args.attention_backend == "flashinfer":
if (
not fm.is_decode()
and int(torch.sum(ret.seq_lens)) > 4096
......
......@@ -53,7 +53,7 @@ from sglang.srt.mem_cache.memory_pool import (
MLATokenToKVPool,
ReqToTokenPool,
)
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
get_available_gpu_memory,
......@@ -92,8 +92,8 @@ class ModelRunner:
)
global_server_args_dict.update(
{
"disable_flashinfer": server_args.disable_flashinfer,
"disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
"attention_backend": server_args.attention_backend,
"sampling_backend": server_args.sampling_backend,
"triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
"enable_mla": server_args.enable_mla,
"torchao_config": server_args.torchao_config,
......@@ -111,7 +111,7 @@ class ModelRunner:
self.load_model()
self.init_memory_pool(
min_per_gpu_memory,
server_args.max_num_reqs,
server_args.max_running_requests,
server_args.max_total_tokens,
)
self.init_cublas()
......@@ -344,8 +344,8 @@ class ModelRunner:
def init_memory_pool(
self,
total_gpu_memory: int,
max_num_reqs: int = None,
max_total_tokens: int = None,
max_num_reqs: Optional[int] = None,
max_total_tokens: Optional[int] = None,
):
if self.server_args.kv_cache_dtype == "auto":
self.kv_cache_dtype = self.dtype
......@@ -379,7 +379,7 @@ class ModelRunner:
),
2048,
),
5120,
4096,
)
self.req_to_token_pool = ReqToTokenPool(
......@@ -399,7 +399,7 @@ class ModelRunner:
)
logger.info("using MLA Triton implementaion, flashinfer is disabled")
# FIXME: temporarily only Triton MLA is supported
self.server_args.disable_flashinfer = True
self.server_args.attention_backend = "triton"
else:
self.token_to_kv_pool = MHATokenToKVPool(
self.max_total_num_tokens,
......@@ -424,7 +424,7 @@ class ModelRunner:
def init_flashinfer(self):
"""Init flashinfer attention kernel wrappers."""
if self.server_args.disable_flashinfer:
if self.server_args.attention_backend != "flashinfer":
assert (
self.sliding_window_size is None
), "turn on flashinfer to support window attention"
......@@ -491,7 +491,10 @@ class ModelRunner:
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
if (
self.server_args.disable_cuda_graph
or self.server_args.attention_backend != "flashinfer"
):
self.cuda_graph_runner = None
return
......
......@@ -425,7 +425,7 @@ def _set_envs_and_config(server_args: ServerArgs):
maybe_set_triton_cache_manager()
# Check flashinfer version
if not server_args.disable_flashinfer:
if server_args.attention_backend == "flashinfer":
assert_pkg_version(
"flashinfer",
"0.1.6",
......
......@@ -17,7 +17,6 @@ limitations under the License.
import argparse
import dataclasses
import json
import logging
import random
from typing import List, Optional, Union
......@@ -50,7 +49,6 @@ class ServerArgs:
# Memory and scheduling
mem_fraction_static: Optional[float] = None
max_running_requests: Optional[int] = None
max_num_reqs: Optional[int] = None
max_total_tokens: Optional[int] = None
chunked_prefill_size: int = 8192
max_prefill_tokens: int = 16384
......@@ -85,6 +83,9 @@ class ServerArgs:
json_model_override_args: str = "{}"
# Optimization/debug options
attention_backend: str = "flashinfer"
sampling_backend: str = "flashinfer"
disable_flashinfer: bool = False
disable_flashinfer_sampling: bool = False
disable_radix_cache: bool = False
......@@ -101,6 +102,7 @@ class ServerArgs:
triton_attention_reduce_in_fp32: bool = False
def __post_init__(self):
# Set missing default values
if self.tokenizer_path is None:
self.tokenizer_path = self.model_path
......@@ -111,6 +113,7 @@ class ServerArgs:
# Disable chunked prefill
self.chunked_prefill_size = None
# Mem fraction depends on the tensor parallelism size
if self.mem_fraction_static is None:
if self.tp_size >= 16:
self.mem_fraction_static = 0.79
......@@ -131,6 +134,29 @@ class ServerArgs:
if self.random_seed is None:
self.random_seed = random.randint(0, 1 << 30)
# Deprecation warnings
if self.disable_flashinfer:
logger.warning(
"The option '--disable-flashinfer' will be deprecated in the next release. "
"Please use '--attention-backend triton' instead."
)
if self.disable_flashinfer_sampling:
logger.warning(
"The option '--disable-flashinfer-sampling' will be deprecated in the next release. "
"Please use '--sampling-backend pytorch' instead. "
)
# Model-specific patches
if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
logger.info(
"Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True"
)
self.trust_remote_code = False
if "gemma-2" in self.model_path.lower():
logger.info("When using sliding window in gemma-2, turn on flashinfer.")
self.attention_backend = "flashinfer"
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument(
......@@ -214,11 +240,6 @@ class ServerArgs:
action="store_true",
help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
)
parser.add_argument(
"--is-embedding",
action="store_true",
help="Whether to use a CausalLM as an embedding model.",
)
parser.add_argument(
"--context-length",
type=int,
......@@ -253,6 +274,11 @@ class ServerArgs:
default=ServerArgs.chat_template,
help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
)
parser.add_argument(
"--is-embedding",
action="store_true",
help="Whether to use a CausalLM as an embedding model.",
)
parser.add_argument(
"--mem-fraction-static",
type=float,
......@@ -265,17 +291,12 @@ class ServerArgs:
default=ServerArgs.max_running_requests,
help="The maximum number of running requests.",
)
parser.add_argument(
"--max-num-reqs",
type=int,
default=ServerArgs.max_num_reqs,
help="The maximum number of requests to serve in the memory pool. If the model have a large context length, you may need to decrease this value to avoid out-of-memory errors.",
)
parser.add_argument(
"--max-total-tokens",
type=int,
default=ServerArgs.max_total_tokens,
help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. This option is typically used for development and debugging purposes.",
help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. "
"This option is typically used for development and debugging purposes.",
)
parser.add_argument(
"--chunked-prefill-size",
......@@ -395,15 +416,29 @@ class ServerArgs:
)
# Optimization/debug options
parser.add_argument(
"--attention-backend",
type=str,
choices=["flashinfer", "triton"],
default=ServerArgs.attention_backend,
help="Choose the kernels for attention layers.",
)
parser.add_argument(
"--sampling-backend",
type=str,
choices=["flashinfer", "pytorch"],
default=ServerArgs.sampling_backend,
help="Choose the kernels for sampling layers.",
)
parser.add_argument(
"--disable-flashinfer",
action="store_true",
help="Disable flashinfer attention kernels.",
help="Disable flashinfer attention kernels. This option will be deprecated in the next release. Please use '--attention-backend triton' instead.",
)
parser.add_argument(
"--disable-flashinfer-sampling",
action="store_true",
help="Disable flashinfer sampling kernels.",
help="Disable flashinfer sampling kernels. This option will be deprecated in the next release. Please use '--sampling-backend pytorch' instead.",
)
parser.add_argument(
"--disable-radix-cache",
......@@ -491,14 +526,6 @@ class ServerArgs:
assert not (
self.dp_size > 1 and self.node_rank is not None
), "multi-node data parallel is not supported"
if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
logger.info(
"Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True"
)
self.trust_remote_code = False
if "gemma-2" in self.model_path.lower():
logger.info("When using sliding window in gemma-2, turn on flashinfer.")
self.disable_flashinfer = False
def prepare_server_args(argv: List[str]) -> ServerArgs:
......
......@@ -14,13 +14,12 @@ from sglang.test.test_utils import (
class TestServingThroughput(unittest.TestCase):
def run_test(self, disable_radix_cache, disable_flashinfer, chunked_prefill_size):
def run_test(self, disable_radix_cache, attention_backend, chunked_prefill_size):
# Launch the server
other_args = []
if disable_radix_cache:
other_args.append("--disable-radix-cache")
if disable_flashinfer:
other_args.append("--disable-flashinfer")
other_args.extend(["--attention-backend", attention_backend])
other_args.extend(["--chunked-prefill-size", str(chunked_prefill_size)])
other_args.extend(["--tensor-parallel-size", "2"])
......@@ -70,7 +69,7 @@ class TestServingThroughput(unittest.TestCase):
def test_default(self):
res = self.run_test(
disable_radix_cache=ServerArgs.disable_radix_cache,
disable_flashinfer=ServerArgs.disable_flashinfer,
attention_backend=ServerArgs.attention_backend,
chunked_prefill_size=ServerArgs.chunked_prefill_size,
)
......@@ -80,7 +79,7 @@ class TestServingThroughput(unittest.TestCase):
def test_default_without_radix_cache(self):
res = self.run_test(
disable_radix_cache=True,
disable_flashinfer=ServerArgs.disable_flashinfer,
attention_backend=ServerArgs.attention_backend,
chunked_prefill_size=ServerArgs.chunked_prefill_size,
)
......
......@@ -14,13 +14,12 @@ from sglang.test.test_utils import (
class TestServingThroughput(unittest.TestCase):
def run_test(self, disable_radix_cache, disable_flashinfer, chunked_prefill_size):
def run_test(self, disable_radix_cache, attention_backend, chunked_prefill_size):
# Launch the server
other_args = []
if disable_radix_cache:
other_args.append("--disable-radix-cache")
if disable_flashinfer:
other_args.append("--disable-flashinfer")
other_args.extend(["--attention-backend", attention_backend])
other_args.extend(["--chunked-prefill-size", str(chunked_prefill_size)])
model = DEFAULT_MODEL_NAME_FOR_TEST
......@@ -69,7 +68,7 @@ class TestServingThroughput(unittest.TestCase):
def test_default(self):
res = self.run_test(
disable_radix_cache=ServerArgs.disable_radix_cache,
disable_flashinfer=ServerArgs.disable_flashinfer,
attention_backend=ServerArgs.attention_backend,
chunked_prefill_size=ServerArgs.chunked_prefill_size,
)
......@@ -79,7 +78,7 @@ class TestServingThroughput(unittest.TestCase):
def test_default_without_radix_cache(self):
res = self.run_test(
disable_radix_cache=True,
disable_flashinfer=ServerArgs.disable_flashinfer,
attention_backend=ServerArgs.attention_backend,
chunked_prefill_size=ServerArgs.chunked_prefill_size,
)
......@@ -89,7 +88,7 @@ class TestServingThroughput(unittest.TestCase):
def test_default_without_chunked_prefill(self):
res = self.run_test(
disable_radix_cache=ServerArgs.disable_radix_cache,
disable_flashinfer=ServerArgs.disable_flashinfer,
attention_backend=ServerArgs.attention_backend,
chunked_prefill_size=-1,
)
......
......@@ -20,7 +20,7 @@ class TestTritonAttnBackend(unittest.TestCase):
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=["--disable-flashinfer"],
other_args=["--attention-backend", "triton"],
)
@classmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment