"torchvision/transforms/_functional_tensor.py" did not exist on "b56f17ae1ae8a5d08067c7f7444af21fb3b59ca6"
Unverified commit a1816187, authored by weiliang, committed by GitHub

Fix Flashinfer Backend for SM120 Usage (#12325)

Select the FlashInfer CUTLASS FMHA backend with is_sm100_supported() instead of is_blackwell_supported(), so that SM120 (consumer Blackwell) GPUs fall back to the "auto" backend; the CUTLASS path is now taken only on SM100.

parent e39628fd
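For context on the two gates this commit swaps, here is a minimal hypothetical sketch of what they check, assuming torch is available. The real helpers live in sglang.srt.utils and their exact bodies may differ; only the compute-capability split (SM100 = 10.x datacenter Blackwell, SM120 = 12.x consumer Blackwell) is the point.

# Hypothetical re-implementations of the two gates, for illustration only;
# the real helpers live in sglang.srt.utils and may differ in detail.
import torch

def is_sm100_supported() -> bool:
    # SM100-class parts (B200/GB200) report compute capability 10.x.
    major, _ = torch.cuda.get_device_capability()
    return major == 10

def is_blackwell_supported() -> bool:
    # Any Blackwell part, including SM120 consumer GPUs (capability 12.x),
    # which is why this broader gate let SM120 onto the CUTLASS path.
    major, _ = torch.cuda.get_device_capability()
    return major in (10, 12)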
@@ -26,8 +26,8 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import (
     get_int_env_var,
-    is_blackwell_supported,
     is_flashinfer_available,
+    is_sm100_supported,
     next_power_of_2,
 )
@@ -229,7 +229,7 @@ class FlashInferAttnBackend(AttentionBackend):
         ]
         fmha_backend = "auto"
-        if is_blackwell_supported():
+        if is_sm100_supported():
             # Disable CUTLASS backend when piecewise cuda graph is enabled
             # due to TMA descriptor initialization issues on B200
             if model_runner.server_args.enable_piecewise_cuda_graph:
...
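Taken together with the in-diff comment, the rule this commit installs is: CUTLASS FMHA only on SM100, and only when piecewise CUDA graphs are off. A self-contained sketch of that rule as a standalone helper; choose_fmha_backend is a made-up name for illustration, not sglang code.

# Illustrative helper (made-up name), summarizing the selection rule that
# this commit installs in FlashInferAttnBackend; not code from the repo.
def choose_fmha_backend(sm100: bool, piecewise_cuda_graph: bool) -> str:
    if sm100 and not piecewise_cuda_graph:
        # Per the in-diff comment, CUTLASS is disabled under piecewise
        # CUDA graphs (TMA descriptor initialization issues on B200).
        return "cutlass"
    return "auto"

# SM100 without piecewise graphs gets CUTLASS; SM120 now falls back to auto.
assert choose_fmha_backend(sm100=True, piecewise_cuda_graph=False) == "cutlass"
assert choose_fmha_backend(sm100=False, piecewise_cuda_graph=False) == "auto"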
@@ -25,8 +25,8 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.server_args import get_global_server_args
 from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import (
-    is_blackwell_supported,
     is_flashinfer_available,
+    is_sm100_supported,
     next_power_of_2,
 )
@@ -242,9 +242,11 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         else:
             self.q_indptr_decode = q_indptr_decode_buf
-        self.fmha_backend = "auto"
-        if is_blackwell_supported():
+
+        if is_sm100_supported():
             self.fmha_backend = "cutlass"
+        else:
+            self.fmha_backend = "auto"
 
         self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
             self.workspace_buffer, "NHD", backend=self.fmha_backend
         )
...
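As the context lines above show, the selected string is passed straight to FlashInfer's ragged-KV prefill wrapper. A hedged usage sketch follows, with the workspace size and device as illustrative placeholders rather than sglang's actual configuration.

# Sketch of how the selected backend reaches FlashInfer; the buffer size
# and device are illustrative placeholders, not sglang's configuration.
import torch
from flashinfer import BatchPrefillWithRaggedKVCacheWrapper
from sglang.srt.utils import is_sm100_supported

# Illustrative 128 MB workspace; sglang sizes this differently in practice.
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
# The MLA path's rule after this commit: CUTLASS only on SM100, else "auto".
fmha_backend = "cutlass" if is_sm100_supported() else "auto"
prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
    workspace_buffer, "NHD", backend=fmha_backend
)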