Unverified Commit 6218034d authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Model Runner V2] Support FlashInfer backend & Fix CUDA Graph bug [1/2] (#32348)


Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
parent 77c16df3
...@@ -195,15 +195,19 @@ def get_cudagraph_size( ...@@ -195,15 +195,19 @@ def get_cudagraph_size(
cudagraph_sizes: dict[int, int], cudagraph_sizes: dict[int, int],
cudagraph_mode: CUDAGraphMode, cudagraph_mode: CUDAGraphMode,
) -> int | None: ) -> int | None:
if not cudagraph_mode.has_full_cudagraphs():
# No full CUDA graph is used.
return None
size = cudagraph_sizes.get(num_tokens_after_dp_padding) size = cudagraph_sizes.get(num_tokens_after_dp_padding)
if size is None: if size is None:
# No CUDA graph for this size. # No CUDA graph for this size.
return None return None
if cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
all_decode = all(x == 1 for x in num_tokens_per_request) is_mixed = any(x > 1 for x in num_tokens_per_request)
if not all_decode: if is_mixed and cudagraph_mode.mixed_mode() != CUDAGraphMode.FULL:
# Prefill is included. # Prefill is included, and this mode doesn't use CUDA graph for it.
return None return None
return size return size
......
...@@ -230,8 +230,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -230,8 +230,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
) )
# TODO(woosuk): Support other backends. # TODO(woosuk): Support other backends.
if not all(b.get_name() == "FLASH_ATTN" for b in self.attn_backends.values()): supported_backends = ("FLASH_ATTN", "FLASHINFER")
raise NotImplementedError("Only FLASH_ATTN backend is supported currently.") for backend in self.attn_backends.values():
backend_name = backend.get_name()
if backend_name not in supported_backends:
raise NotImplementedError(
f"The {backend_name} attention backend is not supported yet. "
f"Supported backends are: {supported_backends}."
)
self.kv_caches: list[torch.Tensor] = [] self.kv_caches: list[torch.Tensor] = []
init_kv_cache( init_kv_cache(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment