Unverified Commit 08c4d764 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

lazy import attn backends (#4200)

parent 96d0e37f
......@@ -6,9 +6,7 @@ import torch
import triton
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.flashinfer_backend import (
create_flashinfer_kv_indices_triton,
)
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
......
......@@ -302,7 +302,7 @@ class CudaGraphRunner:
self.stream = graph_capture_context.stream
# Reverse the order to enable better memory sharing across cuda graphs.
capture_range = (
tqdm.tqdm(reversed(self.capture_bs))
tqdm.tqdm(list(reversed(self.capture_bs)))
if get_tensor_model_parallel_rank() == 0
else reversed(self.capture_bs)
)
......
......@@ -35,11 +35,6 @@ from sglang.srt.distributed import (
set_custom_all_reduce,
)
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
from sglang.srt.layers.dp_attention import (
get_attention_tp_group,
get_attention_tp_size,
......@@ -77,7 +72,6 @@ from sglang.srt.utils import (
set_cpu_offload_max_bytes,
set_cuda_arch,
)
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
......@@ -779,6 +773,10 @@ class ModelRunner:
def init_attention_backend(self):
"""Init attention kernel backend."""
if self.server_args.attention_backend == "flashinfer":
from sglang.srt.layers.attention.flashinfer_backend import (
FlashInferAttnBackend,
)
# Init streams
if self.server_args.speculative_algorithm == "EAGLE":
self.plan_stream_for_flashinfer = torch.cuda.Stream()
......@@ -794,12 +792,26 @@ class ModelRunner:
"Please use `--attention-backend flashinfer`."
)
if self.server_args.enable_double_sparsity:
from sglang.srt.layers.attention.double_sparsity_backend import (
DoubleSparseAttnBackend,
)
self.attn_backend = DoubleSparseAttnBackend(self)
else:
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
self.attn_backend = TritonAttnBackend(self)
elif self.server_args.attention_backend == "torch_native":
from sglang.srt.layers.attention.torch_native_backend import (
TorchNativeAttnBackend,
)
self.attn_backend = TorchNativeAttnBackend(self)
elif self.server_args.attention_backend == "flashinfer_mla":
from sglang.srt.layers.attention.flashinfer_mla_backend import (
FlashInferMLAAttnBackend,
)
self.attn_backend = FlashInferMLAAttnBackend(self)
else:
raise ValueError(
......
......@@ -108,7 +108,7 @@ class TestEAGLEEngine(unittest.TestCase):
def _test_eos_token(self, engine):
prompt = "[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\nToday is a sunny day and I like [/INST]"
params = {
"temperature": 0,
"temperature": 0.1,
"max_new_tokens": 1024,
"skip_special_tokens": False,
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment