Unverified commit 697e4ff3, authored by Jiangyun Zhu, committed by GitHub
Browse files

[GDN] add a config for gdn kernel selection (#36647)


Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
parent a3e2e250
...@@ -614,6 +614,7 @@ class EngineArgs: ...@@ -614,6 +614,7 @@ class EngineArgs:
) )
fail_on_environ_validation: bool = False fail_on_environ_validation: bool = False
gdn_prefill_backend: Literal["flashinfer", "triton"] | None = None
def __post_init__(self): def __post_init__(self):
# support `EngineArgs(compilation_config={...})` # support `EngineArgs(compilation_config={...})`
...@@ -1318,6 +1319,13 @@ class EngineArgs: ...@@ -1318,6 +1319,13 @@ class EngineArgs:
help="Shutdown timeout in seconds. 0 = abort, >0 = wait.", help="Shutdown timeout in seconds. 0 = abort, >0 = wait.",
) )
parser.add_argument(
"--gdn-prefill-backend",
dest="gdn_prefill_backend",
choices=["flashinfer", "triton"],
default=None,
help="Select GDN prefill backend.",
)
return parser return parser
@classmethod @classmethod
...@@ -1903,6 +1911,9 @@ class EngineArgs: ...@@ -1903,6 +1911,9 @@ class EngineArgs:
), ),
) )
if self.gdn_prefill_backend is not None:
self.additional_config["gdn_prefill_backend"] = self.gdn_prefill_backend
config = VllmConfig( config = VllmConfig(
model_config=model_config, model_config=model_config,
cache_config=cache_config, cache_config=cache_config,
......
...@@ -161,13 +161,45 @@ def fi_chunk_gated_delta_rule( ...@@ -161,13 +161,45 @@ def fi_chunk_gated_delta_rule(
class ChunkGatedDeltaRule(CustomOp): class ChunkGatedDeltaRule(CustomOp):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
if current_platform.is_cuda() and current_platform.is_device_capability(90): backend = (
str(
get_current_vllm_config().additional_config.get(
"gdn_prefill_backend", "auto"
)
)
.strip()
.lower()
)
supports_flashinfer = (
current_platform.is_cuda() and current_platform.is_device_capability(90)
)
if backend == "flashinfer":
use_flashinfer = supports_flashinfer
if not use_flashinfer:
logger.warning_once(
"GDN prefill backend 'flashinfer' is selected but "
"cannot use this kernel on the current platform. "
"Falling back to Triton/FLA."
)
elif backend == "triton":
use_flashinfer = False
else:
use_flashinfer = supports_flashinfer
if use_flashinfer:
logger.info_once("Using FlashInfer GDN prefill kernel")
logger.info_once( logger.info_once(
"Using FlashInfer GDN prefill kernel on CUDA compute capability 90" "FlashInfer GDN prefill kernel is JIT-compiled; first run may "
"take a while to compile. Set `--gdn-prefill-backend triton` to "
"avoid JIT compile time."
) )
self._forward_method = self.forward_cuda
else: else:
self._forward_method = self.forward_native logger.info_once("Using Triton/FLA GDN prefill kernel")
self._forward_method = (
self.forward_cuda if use_flashinfer else self.forward_native
)
def forward_cuda( def forward_cuda(
self, self,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment