Unverified commit 46094e0c, authored by Lianmin Zheng and committed by GitHub

Deprecate --disable-flashinfer and introduce --attention-backend (#1380)

parent 3a6e8b6d
@@ -139,7 +139,7 @@ sky status --endpoint 30000 sglang
 ### Common Notes
-- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) is currently one of the dependencies that must be installed for SGLang. It only supports sm75 and above. If you encounter any FlashInfer-related issues on sm75+ devices (e.g., T4, A10, A100, L4, L40S, H100), consider using Triton's kernel by `--disable-flashinfer --disable-flashinfer-sampling` and raise an issue.
+- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) is the default attention kernel backend. It only supports sm75 and above. If you encounter any FlashInfer-related issues on sm75+ devices (e.g., T4, A10, A100, L4, L40S, H100), please disable it by adding `--disable-flashinfer --disable-flashinfer-sampling` and open an issue on GitHub.
 - If you only need to use the OpenAI backend, you can avoid installing other dependencies by using `pip install "sglang[openai]"`.

 ## Backend: SGLang Runtime (SRT)
...
@@ -92,5 +92,5 @@ sky status --endpoint 30000 sglang
 </details>

 ### Common Notes
-- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) is currently one of the dependencies that must be installed for SGLang. It only supports sm75 and above. If you encounter any FlashInfer-related issues on sm75+ devices (e.g., T4, A10, A100, L4, L40S, H100), consider using Triton's kernel by `--disable-flashinfer --disable-flashinfer-sampling` and raise an issue.
+- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) is the default attention kernel backend. It only supports sm75 and above. If you encounter any FlashInfer-related issues on sm75+ devices (e.g., T4, A10, A100, L4, L40S, H100), please disable it by adding `--disable-flashinfer --disable-flashinfer-sampling` and open an issue on GitHub.
 - If you only need to use the OpenAI backend, you can avoid installing other dependencies by using `pip install "sglang[openai]"`.
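For reference, a minimal sketch of how the new flags introduced by this commit are meant to be consumed on the Python side. This snippet assumes the post-change `sglang.srt.server_args` module shown in the diff below; the model path is a placeholder, not part of the diff:

```python
# Hedged sketch: parse the new backend flags into a ServerArgs object.
from sglang.srt.server_args import prepare_server_args

server_args = prepare_server_args([
    "--model-path", "meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model
    "--attention-backend", "triton",    # replaces --disable-flashinfer
    "--sampling-backend", "pytorch",    # replaces --disable-flashinfer-sampling
])
print(server_args.attention_backend, server_args.sampling_backend)  # triton pytorch
```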
@@ -61,14 +61,18 @@ class RadixAttention(nn.Module):
         # Choose backend
         if (
-            not global_server_args_dict.get("disable_flashinfer", False)
+            global_server_args_dict["attention_backend"] == "flashinfer"
             and self.qk_head_dim == self.v_head_dim
         ):
             self.extend_forward = self.extend_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
-        else:
+        elif global_server_args_dict["attention_backend"] == "triton":
             self.extend_forward = self.extend_forward_triton
             self.decode_forward = self.decode_forward_triton
+        else:
+            raise ValueError(
+                f"Invalid attention backend: {global_server_args_dict['attention_backend']}"
+            )

     def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         if self.qk_head_dim != self.v_head_dim:
...
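The hunk above swaps a boolean toggle for an explicit backend name and fails fast on anything unknown. A self-contained sketch of that dispatch pattern, with hypothetical names rather than the actual RadixAttention class:

```python
# Illustrative only: bind kernel implementations once at construction time,
# based on a backend string, instead of branching on a "disable_*" boolean.
from typing import Callable, Tuple

def choose_attention_fns(
    backend: str,
    qk_head_dim: int,
    v_head_dim: int,
    flashinfer_fns: Tuple[Callable, Callable],
    triton_fns: Tuple[Callable, Callable],
) -> Tuple[Callable, Callable]:
    if backend == "flashinfer" and qk_head_dim == v_head_dim:
        return flashinfer_fns        # (extend_forward, decode_forward)
    elif backend == "triton":
        return triton_fns
    else:
        # Mirrors the diff: any unrecognized (or unsupported) combination raises.
        raise ValueError(f"Invalid attention backend: {backend}")
```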
@@ -78,7 +78,7 @@ class Sampler(CustomOp):
         probs = self._get_probs(logits, sampling_info)

-        if not global_server_args_dict["disable_flashinfer_sampling"]:
+        if global_server_args_dict["sampling_backend"] == "flashinfer":
             max_top_k_round, batch_size = 32, probs.shape[0]
             uniform_samples = torch.rand(
                 (max_top_k_round, batch_size), device=probs.device
@@ -93,11 +93,15 @@ class Sampler(CustomOp):
             batch_next_token_ids, success = flashinfer_top_k_top_p(
                 probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps
             )
-        else:
+        elif global_server_args_dict["sampling_backend"] == "pytorch":
             # Here we provide a slower fallback implementation.
             batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
                 probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
             )
+        else:
+            raise ValueError(
+                f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
+            )

         return SampleOutput(success, probs, batch_next_token_ids)
...
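The "slower fallback implementation" referenced above is sglang's own torch helper. As a rough illustration of what top-k/top-p filtering over a probability tensor looks like in plain PyTorch (not the library's actual `top_k_top_p_min_p_sampling_from_probs_torch`):

```python
import torch

def top_k_top_p_sample(probs: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
    """Sample one token id per row from already-normalized probs of shape (batch, vocab)."""
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cumulative = sorted_probs.cumsum(dim=-1)
    # Drop tokens outside the top-p nucleus (the highest-probability token is always kept).
    mask = (cumulative - sorted_probs) > top_p
    # Drop tokens beyond the top-k cutoff.
    mask |= torch.arange(probs.shape[-1], device=probs.device) >= top_k
    filtered = sorted_probs.masked_fill(mask, 0.0)
    filtered = filtered / filtered.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(filtered, num_samples=1)
    return sorted_idx.gather(-1, choice).squeeze(-1)

# Example: batch of 2 rows over a toy vocabulary of 5 tokens.
probs = torch.softmax(torch.randn(2, 5), dim=-1)
print(top_k_top_p_sample(probs, top_k=3, top_p=0.9))
```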
@@ -31,6 +31,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+from sglang.srt.server_args import ServerArgs

 if TYPE_CHECKING:
     from sglang.srt.layers.sampler import SampleOutput
@@ -40,10 +41,11 @@ INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

 # Put some global args for easy access
 global_server_args_dict = {
-    "disable_flashinfer": False,
-    "disable_flashinfer_sampling": False,
-    "triton_attention_reduce_in_fp32": False,
-    "enable_mla": False,
+    "attention_backend": ServerArgs.attention_backend,
+    "sampling_backend": ServerArgs.sampling_backend,
+    "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
+    "enable_mla": ServerArgs.enable_mla,
+    "torchao_config": ServerArgs.torchao_config,
 }
...
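The rewritten `global_server_args_dict` seeds its defaults from the `ServerArgs` dataclass fields instead of hard-coding literals, so the two cannot drift apart. A tiny standalone sketch of that pattern (the class below is a hypothetical stand-in, not sglang's ServerArgs):

```python
import dataclasses

@dataclasses.dataclass
class ExampleArgs:  # hypothetical stand-in for ServerArgs
    attention_backend: str = "flashinfer"
    sampling_backend: str = "flashinfer"

# Simple field defaults are readable as class attributes, so a module-level dict
# can reuse them as its single source of truth.
global_defaults = {
    "attention_backend": ExampleArgs.attention_backend,
    "sampling_backend": ExampleArgs.sampling_backend,
}
print(global_defaults)  # {'attention_backend': 'flashinfer', 'sampling_backend': 'flashinfer'}
```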
@@ -128,7 +128,7 @@ class ModelTpServer:
                 if server_args.max_running_requests is None
                 else server_args.max_running_requests
             ),
-            self.model_runner.req_to_token_pool.size - 1,
+            self.model_runner.req_to_token_pool.size,
         )
         self.max_req_input_len = min(
             self.model_config.context_len - 1,
...
@@ -203,17 +203,17 @@ class InputMetadata:
         ret.compute_extend_infos(batch)

         fm = batch.forward_mode
-        if not fm.is_decode() or model_runner.server_args.disable_flashinfer:
+        if not fm.is_decode() or model_runner.server_args.attention_backend == "triton":
             ret.total_num_tokens = int(torch.sum(ret.seq_lens))

         if not fm.is_decode():
             ret.init_multimuldal_info(batch)

-        if model_runner.server_args.disable_flashinfer:
+        if model_runner.server_args.attention_backend == "triton":
             ret.init_triton_args(batch)

         flashinfer_use_ragged = False
-        if not model_runner.server_args.disable_flashinfer:
+        if model_runner.server_args.attention_backend == "flashinfer":
             if (
                 not fm.is_decode()
                 and int(torch.sum(ret.seq_lens)) > 4096
...
@@ -53,7 +53,7 @@ from sglang.srt.mem_cache.memory_pool import (
     MLATokenToKVPool,
     ReqToTokenPool,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
@@ -92,8 +92,8 @@ class ModelRunner:
         )
         global_server_args_dict.update(
             {
-                "disable_flashinfer": server_args.disable_flashinfer,
-                "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
+                "attention_backend": server_args.attention_backend,
+                "sampling_backend": server_args.sampling_backend,
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                 "enable_mla": server_args.enable_mla,
                 "torchao_config": server_args.torchao_config,
@@ -111,7 +111,7 @@ class ModelRunner:
         self.load_model()
         self.init_memory_pool(
             min_per_gpu_memory,
-            server_args.max_num_reqs,
+            server_args.max_running_requests,
             server_args.max_total_tokens,
         )
         self.init_cublas()
@@ -344,8 +344,8 @@ class ModelRunner:
     def init_memory_pool(
         self,
         total_gpu_memory: int,
-        max_num_reqs: int = None,
-        max_total_tokens: int = None,
+        max_num_reqs: Optional[int] = None,
+        max_total_tokens: Optional[int] = None,
     ):
         if self.server_args.kv_cache_dtype == "auto":
             self.kv_cache_dtype = self.dtype
@@ -379,7 +379,7 @@ class ModelRunner:
                ),
                2048,
            ),
-           5120,
+           4096,
        )

         self.req_to_token_pool = ReqToTokenPool(
@@ -399,7 +399,7 @@ class ModelRunner:
            )
            logger.info("using MLA Triton implementaion, flashinfer is disabled")
            # FIXME: temporarily only Triton MLA is supported
-           self.server_args.disable_flashinfer = True
+           self.server_args.attention_backend = "triton"
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
@@ -424,7 +424,7 @@ class ModelRunner:
     def init_flashinfer(self):
         """Init flashinfer attention kernel wrappers."""
-        if self.server_args.disable_flashinfer:
+        if self.server_args.attention_backend != "flashinfer":
             assert (
                 self.sliding_window_size is None
             ), "turn on flashinfer to support window attention"
@@ -491,7 +491,10 @@ class ModelRunner:
         from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

-        if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
+        if (
+            self.server_args.disable_cuda_graph
+            or self.server_args.attention_backend != "flashinfer"
+        ):
             self.cuda_graph_runner = None
             return
...
@@ -425,7 +425,7 @@ def _set_envs_and_config(server_args: ServerArgs):
         maybe_set_triton_cache_manager()

     # Check flashinfer version
-    if not server_args.disable_flashinfer:
+    if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer",
             "0.1.6",
...
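With this change, `assert_pkg_version` only runs when the FlashInfer backend is actually selected. As a rough, standalone equivalent of such a guard (using `importlib.metadata` and the third-party `packaging` module rather than sglang's helper):

```python
from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version  # assumes the 'packaging' package is installed

def check_flashinfer(min_version: str = "0.1.6") -> None:
    """Fail early if flashinfer is missing or too old for the flashinfer backend."""
    try:
        installed = version("flashinfer")
    except PackageNotFoundError as exc:
        raise RuntimeError(
            "flashinfer is required when --attention-backend flashinfer is selected"
        ) from exc
    if Version(installed) < Version(min_version):
        raise RuntimeError(f"flashinfer>={min_version} is required, found {installed}")
```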
@@ -17,7 +17,6 @@ limitations under the License.

 import argparse
 import dataclasses
-import json
 import logging
 import random
 from typing import List, Optional, Union
@@ -50,7 +49,6 @@ class ServerArgs:
     # Memory and scheduling
     mem_fraction_static: Optional[float] = None
     max_running_requests: Optional[int] = None
-    max_num_reqs: Optional[int] = None
     max_total_tokens: Optional[int] = None
     chunked_prefill_size: int = 8192
     max_prefill_tokens: int = 16384
@@ -85,6 +83,9 @@ class ServerArgs:
     json_model_override_args: str = "{}"

     # Optimization/debug options
+    attention_backend: str = "flashinfer"
+    sampling_backend: str = "flashinfer"
     disable_flashinfer: bool = False
     disable_flashinfer_sampling: bool = False
     disable_radix_cache: bool = False
@@ -101,6 +102,7 @@ class ServerArgs:
     triton_attention_reduce_in_fp32: bool = False

     def __post_init__(self):
+        # Set missing default values
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
@@ -111,6 +113,7 @@ class ServerArgs:
             # Disable chunked prefill
             self.chunked_prefill_size = None

+        # Mem fraction depends on the tensor parallelism size
         if self.mem_fraction_static is None:
             if self.tp_size >= 16:
                 self.mem_fraction_static = 0.79
@@ -131,6 +134,29 @@ class ServerArgs:
         if self.random_seed is None:
             self.random_seed = random.randint(0, 1 << 30)

+        # Deprecation warnings
+        if self.disable_flashinfer:
+            logger.warning(
+                "The option '--disable-flashinfer' will be deprecated in the next release. "
+                "Please use '--attention-backend triton' instead."
+            )
+        if self.disable_flashinfer_sampling:
+            logger.warning(
+                "The option '--disable-flashinfer-sampling' will be deprecated in the next release. "
+                "Please use '--sampling-backend pytorch' instead. "
+            )
+
+        # Model-specific patches
+        if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
+            logger.info(
+                "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True"
+            )
+            self.trust_remote_code = False
+
+        if "gemma-2" in self.model_path.lower():
+            logger.info("When using sliding window in gemma-2, turn on flashinfer.")
+            self.attention_backend = "flashinfer"
+
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         parser.add_argument(
@@ -214,11 +240,6 @@ class ServerArgs:
             action="store_true",
             help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
         )
-        parser.add_argument(
-            "--is-embedding",
-            action="store_true",
-            help="Whether to use a CausalLM as an embedding model.",
-        )
         parser.add_argument(
             "--context-length",
             type=int,
@@ -253,6 +274,11 @@ class ServerArgs:
             default=ServerArgs.chat_template,
             help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
         )
+        parser.add_argument(
+            "--is-embedding",
+            action="store_true",
+            help="Whether to use a CausalLM as an embedding model.",
+        )
         parser.add_argument(
             "--mem-fraction-static",
             type=float,
@@ -265,17 +291,12 @@ class ServerArgs:
             default=ServerArgs.max_running_requests,
             help="The maximum number of running requests.",
         )
-        parser.add_argument(
-            "--max-num-reqs",
-            type=int,
-            default=ServerArgs.max_num_reqs,
-            help="The maximum number of requests to serve in the memory pool. If the model have a large context length, you may need to decrease this value to avoid out-of-memory errors.",
-        )
         parser.add_argument(
             "--max-total-tokens",
             type=int,
             default=ServerArgs.max_total_tokens,
-            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. This option is typically used for development and debugging purposes.",
+            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. "
+            "This option is typically used for development and debugging purposes.",
         )
         parser.add_argument(
             "--chunked-prefill-size",
@@ -395,15 +416,29 @@ class ServerArgs:
         )

         # Optimization/debug options
+        parser.add_argument(
+            "--attention-backend",
+            type=str,
+            choices=["flashinfer", "triton"],
+            default=ServerArgs.attention_backend,
+            help="Choose the kernels for attention layers.",
+        )
+        parser.add_argument(
+            "--sampling-backend",
+            type=str,
+            choices=["flashinfer", "pytorch"],
+            default=ServerArgs.sampling_backend,
+            help="Choose the kernels for sampling layers.",
+        )
         parser.add_argument(
             "--disable-flashinfer",
             action="store_true",
-            help="Disable flashinfer attention kernels.",
+            help="Disable flashinfer attention kernels. This option will be deprecated in the next release. Please use '--attention-backend triton' instead.",
         )
         parser.add_argument(
             "--disable-flashinfer-sampling",
             action="store_true",
-            help="Disable flashinfer sampling kernels.",
+            help="Disable flashinfer sampling kernels. This option will be deprecated in the next release. Please use '--sampling-backend pytorch' instead.",
         )
         parser.add_argument(
             "--disable-radix-cache",
@@ -491,14 +526,6 @@ class ServerArgs:
         assert not (
             self.dp_size > 1 and self.node_rank is not None
         ), "multi-node data parallel is not supported"
-        if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
-            logger.info(
-                "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True"
-            )
-            self.trust_remote_code = False
-        if "gemma-2" in self.model_path.lower():
-            logger.info("When using sliding window in gemma-2, turn on flashinfer.")
-            self.disable_flashinfer = False


 def prepare_server_args(argv: List[str]) -> ServerArgs:
...
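To make the new CLI surface concrete, here is a tiny parser exposing only the two options added above. This is a hypothetical standalone parser, not sglang's full `add_cli_args`:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--attention-backend",
    choices=["flashinfer", "triton"],
    default="flashinfer",
    help="Choose the kernels for attention layers.",
)
parser.add_argument(
    "--sampling-backend",
    choices=["flashinfer", "pytorch"],
    default="flashinfer",
    help="Choose the kernels for sampling layers.",
)

args = parser.parse_args(["--attention-backend", "triton"])
print(args.attention_backend, args.sampling_backend)  # triton flashinfer
# An invalid value such as "--attention-backend cuda" is rejected at parse time by argparse.
```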
@@ -14,13 +14,12 @@ from sglang.test.test_utils import (

 class TestServingThroughput(unittest.TestCase):
-    def run_test(self, disable_radix_cache, disable_flashinfer, chunked_prefill_size):
+    def run_test(self, disable_radix_cache, attention_backend, chunked_prefill_size):
         # Launch the server
         other_args = []
         if disable_radix_cache:
             other_args.append("--disable-radix-cache")
-        if disable_flashinfer:
-            other_args.append("--disable-flashinfer")
+        other_args.extend(["--attention-backend", attention_backend])
         other_args.extend(["--chunked-prefill-size", str(chunked_prefill_size)])
         other_args.extend(["--tensor-parallel-size", "2"])
@@ -70,7 +69,7 @@ class TestServingThroughput(unittest.TestCase):
     def test_default(self):
         res = self.run_test(
             disable_radix_cache=ServerArgs.disable_radix_cache,
-            disable_flashinfer=ServerArgs.disable_flashinfer,
+            attention_backend=ServerArgs.attention_backend,
             chunked_prefill_size=ServerArgs.chunked_prefill_size,
         )
@@ -80,7 +79,7 @@ class TestServingThroughput(unittest.TestCase):
     def test_default_without_radix_cache(self):
         res = self.run_test(
             disable_radix_cache=True,
-            disable_flashinfer=ServerArgs.disable_flashinfer,
+            attention_backend=ServerArgs.attention_backend,
             chunked_prefill_size=ServerArgs.chunked_prefill_size,
         )
...
@@ -14,13 +14,12 @@ from sglang.test.test_utils import (

 class TestServingThroughput(unittest.TestCase):
-    def run_test(self, disable_radix_cache, disable_flashinfer, chunked_prefill_size):
+    def run_test(self, disable_radix_cache, attention_backend, chunked_prefill_size):
         # Launch the server
         other_args = []
         if disable_radix_cache:
             other_args.append("--disable-radix-cache")
-        if disable_flashinfer:
-            other_args.append("--disable-flashinfer")
+        other_args.extend(["--attention-backend", attention_backend])
         other_args.extend(["--chunked-prefill-size", str(chunked_prefill_size)])

         model = DEFAULT_MODEL_NAME_FOR_TEST
@@ -69,7 +68,7 @@ class TestServingThroughput(unittest.TestCase):
     def test_default(self):
         res = self.run_test(
             disable_radix_cache=ServerArgs.disable_radix_cache,
-            disable_flashinfer=ServerArgs.disable_flashinfer,
+            attention_backend=ServerArgs.attention_backend,
             chunked_prefill_size=ServerArgs.chunked_prefill_size,
         )
@@ -79,7 +78,7 @@ class TestServingThroughput(unittest.TestCase):
     def test_default_without_radix_cache(self):
         res = self.run_test(
             disable_radix_cache=True,
-            disable_flashinfer=ServerArgs.disable_flashinfer,
+            attention_backend=ServerArgs.attention_backend,
             chunked_prefill_size=ServerArgs.chunked_prefill_size,
         )
@@ -89,7 +88,7 @@ class TestServingThroughput(unittest.TestCase):
     def test_default_without_chunked_prefill(self):
         res = self.run_test(
             disable_radix_cache=ServerArgs.disable_radix_cache,
-            disable_flashinfer=ServerArgs.disable_flashinfer,
+            attention_backend=ServerArgs.attention_backend,
             chunked_prefill_size=-1,
         )
...
@@ -20,7 +20,7 @@ class TestTritonAttnBackend(unittest.TestCase):
             cls.model,
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=["--disable-flashinfer"],
+            other_args=["--attention-backend", "triton"],
         )

     @classmethod
...