Unverified commit 28103384 authored by Qiaolin Yu, committed by GitHub

[feat] Support different attention backends for prefill and decode (#6338)


Co-authored-by: tianqilin.99 <tianqilin.99@bytedance.com>
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
parent fe6a445d
@@ -188,6 +188,8 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | Arguments | Description | Defaults |
 |-----------|-------------|----------|
 | `--attention-backend` | Choose the kernels for attention layers. | None |
+| `--decode-attention-backend` | (Experimental) The attention backend used for decode. Takes priority over `--attention-backend`. | None |
+| `--prefill-attention-backend` | (Experimental) The attention backend used for prefill. Takes priority over `--attention-backend`. | None |
 | `--sampling-backend` | Choose the kernels for sampling layers. | None |
 | `--grammar-backend` | Choose the backend for grammar-guided decoding. | None |
 | `--mm-attention-backend` | Set multimodal attention backend. | None |
......
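For a quick sanity check of the two new flags, a minimal offline example can look like the sketch below (assuming the `sglang.Engine` entrypoint, which forwards its keyword arguments into `ServerArgs`; the model path is a placeholder):

```python
# Hedged sketch: run prefill with FlashAttention-3 and decode with FlashInfer.
# Assumes sgl.Engine forwards keyword arguments into ServerArgs; the model path is a placeholder.
import sglang as sgl

llm = sgl.Engine(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    prefill_attention_backend="fa3",
    decode_attention_backend="flashinfer",
)
print(llm.generate("The capital of France is", {"temperature": 0, "max_new_tokens": 8}))
llm.shutdown()
```

The equivalent server flags are `--prefill-attention-backend fa3 --decode-attention-backend flashinfer`, which is exactly what the new test added at the bottom of this diff passes to `popen_launch_server`.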
from typing import TYPE_CHECKING, Optional, Union

import torch

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput


class HybridAttnBackend(AttentionBackend):
    """Support different backends for prefill and decode."""

    def __init__(
        self, prefill_backend: AttentionBackend, decode_backend: AttentionBackend
    ):
        self.prefill_backend = prefill_backend
        self.decode_backend = decode_backend

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        if forward_batch.forward_mode.is_decode():
            self.decode_backend.init_forward_metadata(forward_batch)
        else:
            self.prefill_backend.init_forward_metadata(forward_batch)

    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        # CUDA graphs are only captured for decode batches, so all CUDA-graph-related
        # calls are delegated to the decode backend.
        self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens)

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        self.decode_backend.init_forward_metadata_capture_cuda_graph(
            bs,
            num_tokens,
            req_pool_indices,
            seq_lens,
            encoder_lens,
            forward_mode,
            spec_info,
        )

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        self.decode_backend.init_forward_metadata_replay_cuda_graph(
            bs,
            req_pool_indices,
            seq_lens,
            seq_lens_sum,
            encoder_lens,
            forward_mode,
            spec_info,
            seq_lens_cpu,
        )

    def get_cuda_graph_seq_len_fill_value(self):
        return self.decode_backend.get_cuda_graph_seq_len_fill_value()

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        **kwargs,
    ):
        return self.decode_backend.forward_decode(
            q, k, v, layer, forward_batch, save_kv_cache, **kwargs
        )

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        **kwargs,
    ):
        return self.prefill_backend.forward_extend(
            q, k, v, layer, forward_batch, save_kv_cache, **kwargs
        )
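As a small illustration of the dispatch behavior, the sketch below wires the wrapper to two stand-in backends (the stub class and backend names are hypothetical, and the stubs do not subclass `AttentionBackend`; in the server the real backends are built by `ModelRunner._get_attention_backend_from_str`, shown further down in this diff):

```python
# Hypothetical stubs standing in for real AttentionBackend implementations,
# used only to show how HybridAttnBackend routes calls by forward mode.
class _StubBackend:
    def __init__(self, name: str):
        self.name = name

    def forward_extend(self, *args, **kwargs):
        return f"{self.name}: extend"

    def forward_decode(self, *args, **kwargs):
        return f"{self.name}: decode"


hybrid = HybridAttnBackend(
    prefill_backend=_StubBackend("fa3"),
    decode_backend=_StubBackend("flashinfer"),
)
# Extend (prefill) calls go to the prefill backend, decode calls to the decode backend.
assert hybrid.forward_extend(None, None, None, None, None) == "fa3: extend"
assert hybrid.forward_decode(None, None, None, None, None) == "flashinfer: decode"
```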
@@ -1690,16 +1690,20 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         extend_prefix_lens = self.prefix_lens
         extend_logprob_start_lens = self.extend_logprob_start_lens

+        if self.forward_mode.is_decode_or_idle():
+            attention_backend_str = global_server_args_dict["decode_attention_backend"]
+        else:
+            attention_backend_str = global_server_args_dict["prefill_attention_backend"]
         # Create seq_lens_cpu when needed
         if (
-            global_server_args_dict["attention_backend"] == "fa3"
+            attention_backend_str == "fa3"
             or (
                 global_server_args_dict["use_mla_backend"]
-                and global_server_args_dict["attention_backend"] == "flashinfer"
+                and attention_backend_str == "flashinfer"
             )
-            or global_server_args_dict["attention_backend"] == "flashmla"
-            or global_server_args_dict["attention_backend"] == "cutlass_mla"
-            or global_server_args_dict["attention_backend"] == "ascend"
+            or attention_backend_str == "flashmla"
+            or attention_backend_str == "cutlass_mla"
+            or attention_backend_str == "ascend"
             or global_server_args_dict["enable_two_batch_overlap"]
         ):
             seq_lens_cpu = (
......
@@ -1308,9 +1308,58 @@ class ModelRunner:
         else:
             self.attn_backend = self._get_attention_backend()

-    # TODO unify with 6338
     def _get_attention_backend(self):
-        if self.server_args.attention_backend == "flashinfer":
+        """Init attention kernel backend."""
+        self.decode_attention_backend_str = (
+            self.server_args.decode_attention_backend
+            if self.server_args.decode_attention_backend
+            else self.server_args.attention_backend
+        )
+        self.prefill_attention_backend_str = (
+            self.server_args.prefill_attention_backend
+            if self.server_args.prefill_attention_backend
+            else self.server_args.attention_backend
+        )
+        if self.decode_attention_backend_str != self.prefill_attention_backend_str:
+            assert (
+                self.server_args.speculative_algorithm is None
+            ), "Currently HybridAttnBackend does not support speculative decoding."
+            from sglang.srt.layers.attention.hybrid_attn_backend import (
+                HybridAttnBackend,
+            )
+
+            attn_backend = HybridAttnBackend(
+                decode_backend=self._get_attention_backend_from_str(
+                    self.decode_attention_backend_str
+                ),
+                prefill_backend=self._get_attention_backend_from_str(
+                    self.prefill_attention_backend_str
+                ),
+            )
+            logger.info(
+                f"Using hybrid attention backend for decode and prefill: "
+                f"decode_backend={self.decode_attention_backend_str}, "
+                f"prefill_backend={self.prefill_attention_backend_str}."
+            )
+            logger.warning(
+                "The backend specified by --attention-backend (or the default backend) may be overridden. "
+                "The hybrid attention backend is experimental and unstable. Please raise an issue if you encounter any problem."
+            )
+        else:
+            attn_backend = self._get_attention_backend_from_str(
+                self.server_args.attention_backend
+            )
+        global_server_args_dict.update(
+            {
+                "decode_attention_backend": self.decode_attention_backend_str,
+                "prefill_attention_backend": self.prefill_attention_backend_str,
+            }
+        )
+        return attn_backend
+
+    def _get_attention_backend_from_str(self, backend_str: str):
+        if backend_str == "flashinfer":
             if not self.use_mla_backend:
                 from sglang.srt.layers.attention.flashinfer_backend import (
                     FlashInferAttnBackend,
@@ -1318,7 +1367,11 @@

                 # Init streams
                 if self.server_args.speculative_algorithm == "EAGLE":
-                    self.plan_stream_for_flashinfer = torch.cuda.Stream()
+                    if (
+                        not hasattr(self, "plan_stream_for_flashinfer")
+                        or not self.plan_stream_for_flashinfer
+                    ):
+                        self.plan_stream_for_flashinfer = torch.cuda.Stream()
                 return FlashInferAttnBackend(self)
             else:
                 from sglang.srt.layers.attention.flashinfer_mla_backend import (
@@ -1326,15 +1379,15 @@
                 )

                return FlashInferMLAAttnBackend(self)
-        elif self.server_args.attention_backend == "aiter":
+        elif backend_str == "aiter":
            from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend

            return AiterAttnBackend(self)
-        elif self.server_args.attention_backend == "ascend":
+        elif backend_str == "ascend":
            from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend

            return AscendAttnBackend(self)
-        elif self.server_args.attention_backend == "triton":
+        elif backend_str == "triton":
            assert not self.model_config.is_encoder_decoder, (
                "Cross attention is not supported in the triton attention backend. "
                "Please use `--attention-backend flashinfer`."
@@ -1349,17 +1402,17 @@
             from sglang.srt.layers.attention.triton_backend import TritonAttnBackend

             return TritonAttnBackend(self)
-        elif self.server_args.attention_backend == "torch_native":
+        elif backend_str == "torch_native":
             from sglang.srt.layers.attention.torch_native_backend import (
                 TorchNativeAttnBackend,
             )

             return TorchNativeAttnBackend(self)
-        elif self.server_args.attention_backend == "flashmla":
+        elif backend_str == "flashmla":
             from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend

             return FlashMLABackend(self)
-        elif self.server_args.attention_backend == "fa3":
+        elif backend_str == "fa3":
             assert (
                 torch.cuda.get_device_capability()[0] == 8 and not self.use_mla_backend
             ) or torch.cuda.get_device_capability()[0] == 9, (
@@ -1371,7 +1424,7 @@
             )

             return FlashAttentionBackend(self)
-        elif self.server_args.attention_backend == "cutlass_mla":
+        elif backend_str == "cutlass_mla":
             from sglang.srt.layers.attention.cutlass_mla_backend import (
                 CutlassMLABackend,
             )
@@ -1385,9 +1438,7 @@
             logger.info(f"Intel AMX attention backend is enabled.")
             return IntelAMXAttnBackend(self)
         else:
-            raise ValueError(
-                f"Invalid attention backend: {self.server_args.attention_backend}"
-            )
+            raise ValueError(f"Invalid attention backend: {backend_str}")

     def init_double_sparsity_channel_config(self, selected_channel):
         selected_channel = "." + selected_channel + "_proj"
@@ -1475,7 +1526,10 @@
         if self.support_pp:
             kwargs["pp_proxy_tensors"] = pp_proxy_tensors
         return self.model.forward(
-            forward_batch.input_ids, forward_batch.positions, forward_batch, **kwargs
+            forward_batch.input_ids,
+            forward_batch.positions,
+            forward_batch,
+            **kwargs,
         )

     def forward_extend(
......
@@ -925,7 +925,10 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.disable_chunked_prefix_cache = global_server_args_dict[
             "disable_chunked_prefix_cache"
         ]
-        self.attention_backend = global_server_args_dict["attention_backend"]
+        self.current_attention_backend = (
+            None  # Attention backend used by current forward batch
+        )
         self.rocm_fused_decode_mla = get_bool_env_var(
             "SGLANG_ROCM_FUSED_DECODE_MLA", "false"
         )
@@ -1009,9 +1012,16 @@ class DeepseekV2AttentionMLA(nn.Module):
             else:
                 return AttnForwardMethod.MLA

-        if self.attention_backend == "ascend":
+        # Determine attention backend used by current forward batch
+        if forward_batch.forward_mode.is_decode_or_idle():
+            attention_backend = global_server_args_dict["decode_attention_backend"]
+        else:
+            attention_backend = global_server_args_dict["prefill_attention_backend"]
+        self.current_attention_backend = attention_backend
+
+        if attention_backend == "ascend":
             return AttnForwardMethod.MLA
-        elif self.attention_backend == "flashinfer":
+        elif attention_backend == "flashinfer":
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
             if (
                 not self.flashinfer_mla_disable_ragged
@@ -1023,7 +1033,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 return AttnForwardMethod.MHA
             else:
                 return _dispatch_mla_subtype()
-        elif self.attention_backend == "fa3":
+        elif attention_backend == "fa3":
             # Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
             if forward_batch.extend_prefix_lens_cpu is not None:
                 sum_extend_prefix_lens = sum(forward_batch.extend_prefix_lens_cpu)
@@ -1040,7 +1050,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 return AttnForwardMethod.MHA_CHUNKED_KV
             else:
                 return _dispatch_mla_subtype()
-        elif self.attention_backend == "aiter":
+        elif attention_backend == "aiter":
             if (
                 forward_batch.forward_mode.is_extend()
                 and not forward_batch.forward_mode.is_target_verify()
@@ -1288,9 +1298,9 @@ class DeepseekV2AttentionMLA(nn.Module):
         self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
     ):
         if (
-            self.attention_backend == "fa3"
-            or self.attention_backend == "flashinfer"
-            or self.attention_backend == "cutlass_mla"
+            self.current_attention_backend == "fa3"
+            or self.current_attention_backend == "flashinfer"
+            or self.current_attention_backend == "cutlass_mla"
         ):
             attn_output = self.attn_mqa(
                 q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
......
@@ -151,6 +151,8 @@ class ServerArgs:

     # Kernel backend
     attention_backend: Optional[str] = None
+    decode_attention_backend: Optional[str] = None
+    prefill_attention_backend: Optional[str] = None
     sampling_backend: Optional[str] = None
     grammar_backend: Optional[str] = None
     mm_attention_backend: Optional[str] = None
@@ -387,13 +389,19 @@
             )
             self.page_size = 128

-        if self.attention_backend == "flashmla":
+        if (
+            self.attention_backend == "flashmla"
+            or self.decode_attention_backend == "flashmla"
+        ):
             logger.warning(
                 "FlashMLA only supports a page_size of 64, change page_size to 64."
             )
             self.page_size = 64

-        if self.attention_backend == "cutlass_mla":
+        if (
+            self.attention_backend == "cutlass_mla"
+            or self.decode_attention_backend == "cutlass_mla"
+        ):
             logger.warning(
                 "Cutlass MLA only supports a page_size of 128, change page_size to 128."
             )
@@ -1213,6 +1221,35 @@
             default=ServerArgs.attention_backend,
             help="Choose the kernels for attention layers.",
         )
+        parser.add_argument(
+            "--decode-attention-backend",
+            type=str,
+            choices=[
+                "flashinfer",
+                "triton",
+                "torch_native",
+                "fa3",
+                "flashmla",
+                "cutlass_mla",
+            ],
+            default=ServerArgs.decode_attention_backend,
+            help="Choose the kernels for decode attention layers (takes priority over --attention-backend).",
+        )
+        parser.add_argument(
+            "--prefill-attention-backend",
+            type=str,
+            choices=[
+                "flashinfer",
+                "triton",
+                "torch_native",
+                "fa3",
+                "flashmla",
+                "cutlass_mla",
+            ],
+            default=ServerArgs.prefill_attention_backend,
+            help="Choose the kernels for prefill attention layers (takes priority over --attention-backend).",
+        )
         parser.add_argument(
             "--sampling-backend",
             type=str,
......
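The page-size hunk above now also honors the decode backend. A standalone restatement of that rule, as a sketch (the helper name is hypothetical; the real logic lives in the `ServerArgs` post-processing shown above):

```python
# Hypothetical helper restating the override added above: flashmla and cutlass_mla
# impose fixed KV-cache page sizes, so choosing either one as the decode backend
# now triggers the same page_size override as --attention-backend does.
def resolve_page_size(attention_backend, decode_attention_backend, page_size):
    if "flashmla" in (attention_backend, decode_attention_backend):
        return 64  # FlashMLA only supports a page size of 64
    if "cutlass_mla" in (attention_backend, decode_attention_backend):
        return 128  # Cutlass MLA only supports a page size of 128
    return page_size


assert resolve_page_size(None, "flashmla", 1) == 64
assert resolve_page_size("fa3", "cutlass_mla", 1) == 128
```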
@@ -491,6 +491,8 @@ class SRTRunner:
         lora_paths: List[str] = None,
         max_loras_per_batch: int = 4,
         attention_backend: Optional[str] = None,
+        prefill_attention_backend: Optional[str] = None,
+        decode_attention_backend: Optional[str] = None,
         lora_backend: str = "triton",
         disable_cuda_graph: bool = False,
         disable_radix_cache: bool = False,
@@ -540,6 +542,8 @@
             max_loras_per_batch=max_loras_per_batch,
             lora_backend=lora_backend,
             attention_backend=attention_backend,
+            prefill_attention_backend=prefill_attention_backend,
+            decode_attention_backend=decode_attention_backend,
             disable_cuda_graph=disable_cuda_graph,
             disable_radix_cache=disable_radix_cache,
             chunked_prefill_size=chunked_prefill_size,
......
@@ -109,6 +109,7 @@ suites = {
         TestFile("test_vision_openai_server_b.py", 620),
         TestFile("test_w8a8_quantization.py", 46),
         TestFile("test_reasoning_parser.py", 5),
+        TestFile("test_hybrid_attn_backend.py", 100),
     ],
     "per-commit-amd": [
         TestFile("models/lora/test_lora_backend.py", 99),
......
import os
import unittest
from types import SimpleNamespace

import requests

from sglang.srt.utils import get_device_sm, kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_MODEL_NAME_FOR_TEST_MLA,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)

GSM_DATASET_PATH = None

# Default server arguments shared across all tests
DEFAULT_SERVER_ARGS = [
    "--trust-remote-code",
    "--cuda-graph-max-bs",
    "8",
    "--prefill-attention-backend",
    "fa3",
    "--decode-attention-backend",
    "flashinfer",
]


@unittest.skipIf(get_device_sm() < 90, "Test requires CUDA SM 90 or higher")
class TestHybridAttnBackendBase(CustomTestCase):
    model = DEFAULT_MODEL_NAME_FOR_TEST
    base_url = DEFAULT_URL_FOR_TEST
    accuracy_threshold = 0.65  # derived tests may override this
    speculative_decode = False
    spec_decode_threshold = 1.0  # derived speculative decoding tests may override this

    @classmethod
    def get_server_args(cls):
        """Return the arguments for the server launch. Override in subclasses."""
        return DEFAULT_SERVER_ARGS

    @classmethod
    def setUpClass(cls):
        # Disable DeepGEMM JIT precompilation to speed up the server launch for tests.
        # Do not do this in production if you care about inference performance.
        os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "false"
        os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=cls.get_server_args(),
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        requests.get(self.base_url + "/flush_cache")

        args = SimpleNamespace(
            num_shots=4,
            num_questions=100,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
            data_path=GSM_DATASET_PATH,
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(f"{metrics=}")

        # Use the appropriate metric key based on the test class
        metric_key = "accuracy"
        self.assertGreater(metrics[metric_key], self.accuracy_threshold)

        if self.speculative_decode:
            server_info = requests.get(self.base_url + "/get_server_info")
            avg_spec_accept_length = server_info.json()["internal_states"][0][
                "avg_spec_accept_length"
            ]
            print(f"{avg_spec_accept_length=}")
            self.assertGreater(avg_spec_accept_length, self.spec_decode_threshold)


class TestHybridAttnBackendMLA(TestHybridAttnBackendBase):
    accuracy_threshold = 0.60
    model = DEFAULT_MODEL_NAME_FOR_TEST_MLA

    @classmethod
    def get_server_args(cls):
        return DEFAULT_SERVER_ARGS


class TestHybridAttnBackendTorchCompile(TestHybridAttnBackendBase):
    accuracy_threshold = 0.65

    @classmethod
    def get_server_args(cls):
        return DEFAULT_SERVER_ARGS + ["--enable-torch-compile"]


if __name__ == "__main__":
    unittest.main()
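Covering another prefill/decode combination only requires a small subclass that overrides `get_server_args`; a hypothetical example (the fa3-prefill/triton-decode pairing is an illustration, not a combination added by this PR):

```python
# Hypothetical extra test case: keep fa3 for prefill but decode with the triton backend,
# added alongside the other subclasses above.
class TestHybridAttnBackendTritonDecode(TestHybridAttnBackendBase):
    accuracy_threshold = 0.65

    @classmethod
    def get_server_args(cls):
        # Same defaults as the base class, but decode with the triton backend instead.
        return [
            "--trust-remote-code",
            "--cuda-graph-max-bs",
            "8",
            "--prefill-attention-backend",
            "fa3",
            "--decode-attention-backend",
            "triton",
        ]
```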