Unverified Commit 08c4d764 authored by Lianmin Zheng, committed by GitHub

lazy import attn backends (#4200)

parent 96d0e37f
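For readers unfamiliar with the pattern this commit applies: moving an import from module scope into the branch that needs it defers the (often heavy) module load until that branch actually runs, and Python caches the result in sys.modules so repeated executions of the same import statement cost nothing extra. A minimal sketch, using the standard-library json module as a stand-in for a heavy attention backend:

    import sys

    def get_encoder():
        # Deferred import: json is loaded the first time this function runs,
        # not when the enclosing module is imported.
        import json
        return json.JSONEncoder()

    print("json" in sys.modules)   # typically False in a fresh interpreter
    get_encoder()
    print("json" in sys.modules)   # True: cached, so later calls are cheap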
@@ -6,9 +6,7 @@ import torch
 import triton
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
-from sglang.srt.layers.attention.flashinfer_backend import (
-    create_flashinfer_kv_indices_triton,
-)
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
...
@@ -302,7 +302,7 @@ class CudaGraphRunner:
             self.stream = graph_capture_context.stream
             # Reverse the order to enable better memory sharing across cuda graphs.
             capture_range = (
-                tqdm.tqdm(reversed(self.capture_bs))
+                tqdm.tqdm(list(reversed(self.capture_bs)))
                 if get_tensor_model_parallel_rank() == 0
                 else reversed(self.capture_bs)
             )
...
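Side note on the CudaGraphRunner change: reversed() returns an iterator with no __len__, so tqdm cannot infer a total and renders a bar without a percentage or ETA; materializing it with list() restores the count. A small illustration of the difference (the capture_bs values are made up):

    import tqdm

    capture_bs = [1, 2, 4, 8, 16]

    # Iterator has no __len__: the bar shows a raw counter, no total or ETA.
    for bs in tqdm.tqdm(reversed(capture_bs)):
        pass

    # list() restores len(), so tqdm shows percentage, total, and ETA.
    for bs in tqdm.tqdm(list(reversed(capture_bs))):
        pass

    # Equivalent alternative: keep the iterator and pass the total explicitly.
    for bs in tqdm.tqdm(reversed(capture_bs), total=len(capture_bs)):
        pass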
@@ -35,11 +35,6 @@ from sglang.srt.distributed import (
     set_custom_all_reduce,
 )
 from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
-from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
-from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
-from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
-from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
-from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_group,
     get_attention_tp_size,
@@ -77,7 +72,6 @@ from sglang.srt.utils import (
     set_cpu_offload_max_bytes,
     set_cuda_arch,
 )
-from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
@@ -779,6 +773,10 @@ class ModelRunner:
     def init_attention_backend(self):
         """Init attention kernel backend."""
         if self.server_args.attention_backend == "flashinfer":
+            from sglang.srt.layers.attention.flashinfer_backend import (
+                FlashInferAttnBackend,
+            )
+
             # Init streams
             if self.server_args.speculative_algorithm == "EAGLE":
                 self.plan_stream_for_flashinfer = torch.cuda.Stream()
@@ -794,12 +792,26 @@ class ModelRunner:
                 "Please use `--attention-backend flashinfer`."
             )
             if self.server_args.enable_double_sparsity:
+                from sglang.srt.layers.attention.double_sparsity_backend import (
+                    DoubleSparseAttnBackend,
+                )
+
                 self.attn_backend = DoubleSparseAttnBackend(self)
             else:
+                from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+
                 self.attn_backend = TritonAttnBackend(self)
         elif self.server_args.attention_backend == "torch_native":
+            from sglang.srt.layers.attention.torch_native_backend import (
+                TorchNativeAttnBackend,
+            )
+
            self.attn_backend = TorchNativeAttnBackend(self)
         elif self.server_args.attention_backend == "flashinfer_mla":
+            from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                FlashInferMLAAttnBackend,
+            )
+
             self.attn_backend = FlashInferMLAAttnBackend(self)
         else:
             raise ValueError(
...
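Why deferring the import into init_attention_backend is safe to do on every call: the first import pays the full load cost, but sys.modules caches the module, so re-executing the same import statement later is effectively free. A quick way to observe this (the standard-library decimal module stands in for a backend module):

    import importlib
    import sys
    import time

    def timed_import(name):
        # Returns the wall-clock cost of importing `name` right now.
        start = time.perf_counter()
        importlib.import_module(name)
        return time.perf_counter() - start

    print(f"first load: {timed_import('decimal'):.4f}s")
    print(f"cached:     {timed_import('decimal'):.6f}s")  # near zero: sys.modules hit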
@@ -108,7 +108,7 @@ class TestEAGLEEngine(unittest.TestCase):
     def _test_eos_token(self, engine):
         prompt = "[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\nToday is a sunny day and I like [/INST]"
         params = {
-            "temperature": 0,
+            "temperature": 0.1,
             "max_new_tokens": 1024,
             "skip_special_tokens": False,
         }
...
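For context on the test change (explanation not from the commit): temperature divides the logits before softmax, so 0 is special-cased as greedy argmax decoding, while a small value such as 0.1 sharpens the distribution toward the argmax but keeps the sampling path exercised. A worked example with toy logits:

    import math

    def apply_temperature(logits, t):
        # Softmax over logits / t; smaller t concentrates mass on the argmax.
        scaled = [x / t for x in logits]
        m = max(scaled)
        exps = [math.exp(x - m) for x in scaled]
        z = sum(exps)
        return [e / z for e in exps]

    print(apply_temperature([2.0, 1.0, 0.5], 1.0))  # ~[0.63, 0.23, 0.14]
    print(apply_temperature([2.0, 1.0, 0.5], 0.1))  # nearly all mass on the first token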