Unverified Commit 11f881d1 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Deprecate --disable-flashinfer and --disable-flashinfer-sampling (#2065)

parent 38625e21
...@@ -116,8 +116,6 @@ class ServerArgs: ...@@ -116,8 +116,6 @@ class ServerArgs:
grammar_backend: Optional[str] = "outlines" grammar_backend: Optional[str] = "outlines"
# Optimization/debug options # Optimization/debug options
disable_flashinfer: bool = False
disable_flashinfer_sampling: bool = False
disable_radix_cache: bool = False disable_radix_cache: bool = False
disable_jump_forward: bool = False disable_jump_forward: bool = False
disable_cuda_graph: bool = False disable_cuda_graph: bool = False
...@@ -179,20 +177,6 @@ class ServerArgs: ...@@ -179,20 +177,6 @@ class ServerArgs:
self.chunked_prefill_size //= 4 # make it 2048 self.chunked_prefill_size //= 4 # make it 2048
self.cuda_graph_max_bs = 4 self.cuda_graph_max_bs = 4
# Deprecation warnings
if self.disable_flashinfer:
logger.warning(
"The option '--disable-flashinfer' will be deprecated in the next release. "
"Please use '--attention-backend triton' instead."
)
self.attention_backend = "triton"
if self.disable_flashinfer_sampling:
logger.warning(
"The option '--disable-flashinfer-sampling' will be deprecated in the next release. "
"Please use '--sampling-backend pytorch' instead. "
)
self.sampling_backend = "pytorch"
if not is_flashinfer_available(): if not is_flashinfer_available():
self.attention_backend = "triton" self.attention_backend = "triton"
self.sampling_backend = "pytorch" self.sampling_backend = "pytorch"
...@@ -615,16 +599,6 @@ class ServerArgs: ...@@ -615,16 +599,6 @@ class ServerArgs:
) )
# Optimization/debug options # Optimization/debug options
parser.add_argument(
"--disable-flashinfer",
action="store_true",
help="Disable flashinfer attention kernels. This option will be deprecated in the next release. Please use '--attention-backend triton' instead.",
)
parser.add_argument(
"--disable-flashinfer-sampling",
action="store_true",
help="Disable flashinfer sampling kernels. This option will be deprecated in the next release. Please use '--sampling-backend pytorch' instead.",
)
parser.add_argument( parser.add_argument(
"--disable-radix-cache", "--disable-radix-cache",
action="store_true", action="store_true",
...@@ -733,6 +707,18 @@ class ServerArgs: ...@@ -733,6 +707,18 @@ class ServerArgs:
help="Delete the model checkpoint after loading the model.", help="Delete the model checkpoint after loading the model.",
) )
# Deprecated arguments
parser.add_argument(
"--disable-flashinfer",
action=DeprecatedAction,
help="'--disable-flashinfer' is deprecated. Please use '--attention-backend triton' instead.",
)
parser.add_argument(
"--disable-flashinfer-sampling",
action=DeprecatedAction,
help="'--disable-flashinfer-sampling' is deprecated. Please use '--sampling-backend pytroch' instead.",
)
@classmethod @classmethod
def from_cli_args(cls, args: argparse.Namespace): def from_cli_args(cls, args: argparse.Namespace):
args.tp_size = args.tensor_parallel_size args.tp_size = args.tensor_parallel_size
...@@ -826,3 +812,13 @@ class LoRAPathAction(argparse.Action): ...@@ -826,3 +812,13 @@ class LoRAPathAction(argparse.Action):
getattr(namespace, self.dest)[name] = path getattr(namespace, self.dest)[name] = path
else: else:
getattr(namespace, self.dest)[lora_path] = lora_path getattr(namespace, self.dest)[lora_path] = lora_path
class DeprecatedAction(argparse.Action):
def __init__(self, option_strings, dest, nargs=0, **kwargs):
super(DeprecatedAction, self).__init__(
option_strings, dest, nargs=nargs, **kwargs
)
def __call__(self, parser, namespace, values, option_string=None):
raise ValueError(self.help)
...@@ -71,6 +71,8 @@ def is_flashinfer_available(): ...@@ -71,6 +71,8 @@ def is_flashinfer_available():
Check whether flashinfer is available. Check whether flashinfer is available.
As of Oct. 6, 2024, it is only available on NVIDIA GPUs. As of Oct. 6, 2024, it is only available on NVIDIA GPUs.
""" """
if os.environ.get("SGLANG_IS_FLASHINFER_AVAILABLE", "true") == "false":
return False
return torch.cuda.is_available() and not is_hip() return torch.cuda.is_available() and not is_hip()
......
...@@ -65,8 +65,7 @@ class TestTorchCompile(unittest.TestCase): ...@@ -65,8 +65,7 @@ class TestTorchCompile(unittest.TestCase):
tok = time.time() tok = time.time()
print(f"{res=}") print(f"{res=}")
throughput = max_tokens / (tok - tic) throughput = max_tokens / (tok - tic)
print(f"Throughput: {throughput} tokens/s") self.assertGreaterEqual(throughput, 285)
self.assertGreaterEqual(throughput, 290)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment