Unverified Commit 53f7874a authored by valarLip's avatar valarLip Committed by GitHub
Browse files

refine aiter_backend for mtp (#7279)


Co-authored-by: default avatarHAI <hixiao@gmail.com>
parent 61a46804
......@@ -32,7 +32,7 @@ try:
mha_batch_prefill_func,
paged_attention_ragged,
)
from aiter.mla import mla_decode_fwd
from aiter.mla import mla_decode_fwd, mla_prefill_fwd
except ImportError:
print(
"aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
......@@ -52,10 +52,8 @@ class ForwardMetadata:
kv_indices: torch.Tensor
qo_indptr: torch.Tensor
kv_last_page_len: torch.Tensor
max_extend_len: int
max_prefix_extend_len: int
max_q_len: int
max_kv_len: int
max_kv_len: Optional[int]
global_workspace_buffer = None
......@@ -71,10 +69,17 @@ class AiterAttnBackend(AttentionBackend):
kv_indptr_buf: Optional[torch.Tensor] = None,
):
super().__init__()
# Lazy import to avoid the initialization of cuda context
from sglang.srt.layers.attention.triton_ops.extend_attention import (
extend_attention_fwd,
)
self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)
self.device = model_runner.device
self.is_multimodal = model_runner.model_config.is_multimodal
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
self.num_head = (
model_runner.model_config.num_attention_heads // get_attention_tp_size()
)
......@@ -157,13 +162,13 @@ class AiterAttnBackend(AttentionBackend):
spec_info = forward_batch.spec_info
qo_indptr = None
kv_last_page_len = None
max_extend_len = None
max_q_len = None
if forward_batch.forward_mode.is_decode_or_idle():
if spec_info is None:
kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.zeros(
kv_indices = torch.empty(
forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
)
create_flashinfer_kv_indices_triton[(bs,)](
......@@ -183,39 +188,35 @@ class AiterAttnBackend(AttentionBackend):
qo_indptr = self.qo_indptr_[: bs + 1]
qo_indptr[1 : bs + 1] = torch.cumsum(self.kv_last_page_len[:bs], dim=0)
kv_last_page_len = self.kv_last_page_len[:bs]
max_extend_len = 1
max_q_len = 1
self.forward_metadata = ForwardMetadata(
kv_indptr,
kv_indices,
qo_indptr,
kv_last_page_len,
max_extend_len,
None,
None,
max_q_len,
None,
)
elif forward_batch.forward_mode.is_draft_extend():
if self.use_mla:
prefix_lens = forward_batch.extend_prefix_lens
self.mla_indices_updater_prefill.update(
forward_batch.req_pool_indices,
prefix_lens,
prefix_lens.sum().item(),
forward_batch.extend_seq_lens,
encoder_lens=forward_batch.encoder_lens,
spec_info=None,
kv_indices, kv_indptr, qo_indptr, custom_mask = (
spec_info.generate_attn_arg_prefill(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
self.req_to_token,
)
)
self.forward_metadata = ForwardMetadata(
self.mla_indices_updater_prefill.kv_indptr,
self.mla_indices_updater_prefill.kv_indices,
self.mla_indices_updater_prefill.qo_indptr,
self.mla_indices_updater_prefill.kv_last_page_len,
self.mla_indices_updater_prefill.max_extend_len,
self.mla_indices_updater_prefill.max_prefix_extend_len,
None,
None,
kv_indptr,
kv_indices,
qo_indptr,
# self.mla_indices_updater_prefill.kv_last_page_len,
self.kv_last_page_len[:bs],
max(forward_batch.extend_seq_lens_cpu),
forward_batch.seq_lens_cpu.max().item(),
)
else:
self.indices_updater_prefill.update(
......@@ -231,30 +232,47 @@ class AiterAttnBackend(AttentionBackend):
self.indices_updater_prefill.kv_indices,
None,
None,
None,
None,
self.indices_updater_prefill.max_q_len,
self.indices_updater_prefill.max_kv_len,
)
elif forward_batch.forward_mode.is_target_verify():
if self.use_mla:
prefix_lens = forward_batch.extend_prefix_lens
self.mla_indices_updater_prefill.update(
draft_num = spec_info.draft_token_num
kv_lens = forward_batch.seq_lens + draft_num
kv_lens_sum = forward_batch.seq_lens_sum + draft_num * bs
device = forward_batch.seq_lens.device
qo_indptr = torch.arange(
0,
(1 + bs) * draft_num,
step=draft_num,
dtype=torch.int32,
device=device,
)
kv_indptr = self.kv_indptr
kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.empty(
kv_lens_sum,
dtype=torch.int32,
device=device,
)
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
prefix_lens,
prefix_lens.sum().item(),
forward_batch.extend_seq_lens,
encoder_lens=forward_batch.encoder_lens,
spec_info=None,
kv_lens,
kv_indptr,
None,
kv_indices,
self.req_to_token.stride(0),
)
self.forward_metadata = ForwardMetadata(
self.mla_indices_updater_prefill.kv_indptr,
self.mla_indices_updater_prefill.kv_indices,
self.mla_indices_updater_prefill.qo_indptr,
self.mla_indices_updater_prefill.kv_last_page_len,
self.mla_indices_updater_prefill.max_extend_len,
self.mla_indices_updater_prefill.max_prefix_extend_len,
None,
kv_indptr,
kv_indices,
qo_indptr,
# self.mla_indices_updater_prefill.kv_last_page_len,
self.kv_last_page_len[:bs],
draft_num,
None,
)
else:
......@@ -271,8 +289,6 @@ class AiterAttnBackend(AttentionBackend):
self.indices_updater_prefill.kv_indices,
None,
None,
None,
None,
self.indices_updater_prefill.max_q_len,
self.indices_updater_prefill.max_kv_len,
)
......@@ -283,25 +299,26 @@ class AiterAttnBackend(AttentionBackend):
extend_no_prefix = False
else:
extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
if self.use_mla:
self.mla_indices_updater_prefill.update(
forward_batch.req_pool_indices,
prefix_lens,
prefix_lens.sum().item(),
forward_batch.extend_prefix_lens,
sum(forward_batch.extend_prefix_lens_cpu),
forward_batch.extend_seq_lens,
encoder_lens=forward_batch.encoder_lens,
max(forward_batch.extend_seq_lens_cpu),
forward_batch.seq_lens_cpu.max().item(),
spec_info=None,
)
self.mla_indices_updater_prefill.kv_indptr += (
self.mla_indices_updater_prefill.qo_indptr
)
self.forward_metadata = ForwardMetadata(
self.mla_indices_updater_prefill.kv_indptr,
self.mla_indices_updater_prefill.kv_indices,
self.mla_indices_updater_prefill.qo_indptr,
self.mla_indices_updater_prefill.kv_last_page_len,
self.mla_indices_updater_prefill.max_extend_len,
self.mla_indices_updater_prefill.max_prefix_extend_len,
None,
None,
self.kv_last_page_len[:bs],
self.mla_indices_updater_prefill.max_q_len,
self.mla_indices_updater_prefill.max_kv_len,
)
else:
self.indices_updater_prefill.update(
......@@ -317,8 +334,6 @@ class AiterAttnBackend(AttentionBackend):
self.indices_updater_prefill.kv_indices,
None,
None,
None,
None,
self.indices_updater_prefill.max_q_len,
self.indices_updater_prefill.max_kv_len,
)
......@@ -359,7 +374,7 @@ class AiterAttnBackend(AttentionBackend):
if forward_mode.is_decode_or_idle():
qo_indptr = None
kv_last_page_len = None
max_extend_len = None
max_q_len = None
if spec_info is None:
kv_indptr = self.kv_indptr
......@@ -383,17 +398,15 @@ class AiterAttnBackend(AttentionBackend):
qo_indptr[1 : bs + 1] = torch.cumsum(
self.cuda_graph_kv_last_page_len[:bs], dim=0
)
max_extend_len = 1
kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
max_q_len = 1
self.forward_metadata = ForwardMetadata(
kv_indptr,
kv_indices,
qo_indptr,
kv_last_page_len,
max_extend_len,
None,
None,
max_q_len,
None,
)
......@@ -419,18 +432,15 @@ class AiterAttnBackend(AttentionBackend):
kv_indices,
self.req_to_token.stride(0),
)
max_extend_len = self.num_draft_tokens
kv_last_page_len = None
kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
max_q_len = self.num_draft_tokens
self.forward_metadata = ForwardMetadata(
kv_indptr,
kv_indices,
qo_indptr,
kv_last_page_len,
max_extend_len,
None,
None,
max_q_len,
None,
)
else:
......@@ -448,12 +458,41 @@ class AiterAttnBackend(AttentionBackend):
self.indices_updater_prefill.kv_indices,
None,
None,
None,
None,
self.indices_updater_prefill.max_q_len,
self.indices_updater_prefill.max_kv_len,
)
elif forward_mode.is_draft_extend():
num_tokens_per_bs = self.speculative_num_steps + 1
qo_indptr = self.qo_indptr[: bs + 1]
qo_indptr[: bs + 1] = torch.arange(
0,
bs * num_tokens_per_bs + 1,
step=num_tokens_per_bs,
dtype=torch.int32,
device=self.device,
)
kv_indptr = self.kv_indptr[: bs + 1]
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
kv_indices = self.cuda_graph_kv_indices
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
seq_lens,
kv_indptr,
None,
kv_indices,
self.req_to_token.stride(0),
)
kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
max_q_len = num_tokens_per_bs
self.forward_metadata = ForwardMetadata(
kv_indptr,
kv_indices,
qo_indptr,
kv_last_page_len,
max_q_len,
None,
)
else:
raise ValueError(f"Invalid mode: {forward_mode=}")
......@@ -488,13 +527,44 @@ class AiterAttnBackend(AttentionBackend):
kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
elif forward_mode.is_target_verify():
self.indices_updater_prefill.update(
req_pool_indices[:bs],
seq_lens[:bs],
seq_lens_sum,
prefix_lens=None,
encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
spec_info=spec_info,
bs = len(req_pool_indices)
qo_indptr = self.qo_indptr[: bs + 1]
qo_indptr[: bs + 1] = torch.arange(
0,
(1 + bs) * self.num_draft_tokens,
step=self.num_draft_tokens,
dtype=torch.int32,
device=self.device,
)
kv_lens = seq_lens + self.num_draft_tokens
kv_indptr = self.kv_indptr[: bs + 1]
kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0)
kv_indices = self.cuda_graph_kv_indices
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
kv_lens,
kv_indptr,
None,
kv_indices,
self.req_to_token.stride(0),
)
elif forward_mode.is_draft_extend():
seq_lens = seq_lens[:bs]
accept_lens = spec_info.accept_length[:bs]
qo_indptr = self.qo_indptr[: bs + 1]
qo_indptr[1 : bs + 1] = torch.cumsum(accept_lens, dim=0)
kv_indptr = self.kv_indptr[: bs + 1]
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
kv_indices = self.cuda_graph_kv_indices
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
seq_lens,
kv_indptr,
None,
kv_indices,
self.req_to_token.stride(0),
)
else:
raise ValueError("Invalid forward mode")
......@@ -530,11 +600,10 @@ class AiterAttnBackend(AttentionBackend):
)
if self.use_mla:
max_extend_len = self.forward_metadata.max_extend_len
max_prefix_extend_len = self.forward_metadata.max_prefix_extend_len
max_q_len = self.forward_metadata.max_q_len
max_kv_len = self.forward_metadata.max_kv_len
kv_indptr = self.forward_metadata.kv_indptr
kv_indices = self.forward_metadata.kv_indices
kv_last_page_lens = self.forward_metadata.kv_last_page_len
qo_indptr = self.forward_metadata.qo_indptr
K_Buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
V_Buffer = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
......@@ -552,8 +621,8 @@ class AiterAttnBackend(AttentionBackend):
v,
qo_indptr,
qo_indptr,
max_extend_len,
max_extend_len,
max_q_len,
max_q_len,
softmax_scale=layer.scaling,
causal=True,
)
......@@ -599,12 +668,71 @@ class AiterAttnBackend(AttentionBackend):
v,
qo_indptr,
kv_indptr,
max_extend_len,
max_prefix_extend_len,
max_q_len,
max_kv_len,
softmax_scale=layer.scaling,
causal=True,
)
return o
elif forward_batch.forward_mode.is_target_verify():
o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim))
mla_decode_fwd(
q,
K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
o,
self.forward_metadata.qo_indptr,
self.forward_metadata.kv_indptr,
self.forward_metadata.kv_indices,
self.forward_metadata.kv_last_page_len,
self.forward_metadata.max_q_len,
layer.scaling,
layer.logit_cap,
)
K_Buffer = K_Buffer.view(-1, 1, layer.qk_head_dim)
return o
elif forward_batch.forward_mode.is_draft_extend():
o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim))
causal = True
sliding_window_size = -1
kv_indptr = self.forward_metadata.kv_indptr
kv_indices = self.forward_metadata.kv_indices
mla_prefill_fwd(
q,
K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
o,
self.forward_metadata.qo_indptr,
self.forward_metadata.kv_indptr,
self.forward_metadata.kv_indices,
self.forward_metadata.kv_last_page_len,
self.forward_metadata.max_q_len,
layer.scaling,
layer.logit_cap,
)
K_Buffer = K_Buffer.view(-1, 1, layer.qk_head_dim)
return o
# self.extend_attention_fwd(
# q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
# k.contiguous(),
# v.contiguous(),
# o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
# forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
# forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
# self.forward_metadata.qo_indptr,
# kv_indptr,
# kv_indices,
# None,
# causal,
# None,
# self.forward_metadata.max_q_len,
# layer.scaling,
# layer.logit_cap,
# sliding_window_size,
# )
# return o
else:
raise ValueError(
f"Invalid forward mode for MLA prefill: {forward_batch.forward_mode=}"
)
else:
k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
layer.layer_id
......@@ -662,7 +790,7 @@ class AiterAttnBackend(AttentionBackend):
self.forward_metadata.kv_indptr,
self.forward_metadata.kv_indices,
self.forward_metadata.kv_last_page_len,
self.forward_metadata.max_extend_len,
self.forward_metadata.max_q_len,
layer.scaling,
layer.logit_cap,
)
......@@ -816,16 +944,17 @@ class AiterMlaIndicesUpdaterPrefill:
self.kv_indices = None
self.qo_indptr = None
self.kv_last_page_len = None
self.max_extend_len = 0
self.max_prefix_extend_len = 0
self.max_q_len = 0
self.max_kv_len = 0
def update(
self,
req_pool_indices: torch.Tensor,
prefix_lens: torch.Tensor,
prefix_lens_sum: int,
kv_lens: torch.Tensor,
kv_lens_sum: int,
extend_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
max_q_len: int,
max_kv_len: int,
spec_info: Optional[SpecInfo],
):
# Keep the signature for type checking. It will be assigned during runtime.
......@@ -834,33 +963,30 @@ class AiterMlaIndicesUpdaterPrefill:
def update_single_wrapper(
self,
req_pool_indices: torch.Tensor,
prefix_lens: torch.Tensor,
prefix_lens_sum: int,
kv_lens: torch.Tensor,
kv_lens_sum: int,
extend_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
max_q_len: int,
max_kv_len: int,
spec_info: Optional[SpecInfo],
):
paged_kernel_lens = prefix_lens
paged_kernel_lens_sum = prefix_lens_sum
bs = len(req_pool_indices)
kv_indptr = self.attn_backend.kv_indptr
if spec_info is None:
# Normal extend
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.empty(
paged_kernel_lens_sum,
kv_lens_sum,
dtype=torch.int32,
device=req_pool_indices.device,
)
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
paged_kernel_lens,
kv_lens,
kv_indptr,
None,
kv_indices,
......@@ -870,16 +996,12 @@ class AiterMlaIndicesUpdaterPrefill:
qo_indptr = self.attn_backend.qo_indptr
qo_indptr[1 : bs + 1] = torch.cumsum(extend_lens, dim=0)
qo_indptr = qo_indptr[: bs + 1]
max_extend_len = torch.max(extend_lens).item()
max_prefix_extend_len = torch.max(extend_lens + paged_kernel_lens).item()
kv_indptr += qo_indptr
else:
kv_indices, kv_indptr, qo_indptr, custom_mask = (
spec_info.generate_attn_arg_prefill(
req_pool_indices,
paged_kernel_lens,
paged_kernel_lens_sum,
kv_lens,
kv_lens_sum,
self.req_to_token,
)
)
......@@ -887,5 +1009,146 @@ class AiterMlaIndicesUpdaterPrefill:
self.kv_indptr = kv_indptr
self.kv_indices = kv_indices
self.qo_indptr = qo_indptr
self.max_extend_len = max_extend_len
self.max_prefix_extend_len = max_prefix_extend_len
self.max_q_len = max_q_len
self.max_kv_len = max_kv_len
class AiterMultiStepDraftBackend:
"""
Wrap multiple triton attention backends as one for multiple consecutive
draft decoding steps.
"""
def __init__(
self,
model_runner: ModelRunner,
topk: int,
speculative_num_steps: int,
):
from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
self.topk = topk
self.speculative_num_steps = speculative_num_steps
self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
max_bs = model_runner.req_to_token_pool.size * self.topk
self.kv_indptr = torch.zeros(
(
self.speculative_num_steps,
max_bs + 1,
),
dtype=torch.int32,
device=model_runner.device,
)
self.attn_backends = []
for i in range(self.speculative_num_steps):
self.attn_backends.append(
AiterAttnBackend(
model_runner,
skip_prefill=True,
kv_indptr_buf=self.kv_indptr[i],
)
)
self.max_context_len = self.attn_backends[0].max_context_len
self.num_head = (
model_runner.model_config.num_attention_heads // get_attention_tp_size()
)
self.device = model_runner.device
# Cached variables for generate_draft_decode_kv_indices
self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
self.page_size = model_runner.server_args.page_size
assert self.page_size == 1, "Page size must be 1"
def common_template(
self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
):
num_seqs = forward_batch.batch_size
bs = self.topk * num_seqs
seq_lens_sum = forward_batch.seq_lens_sum
self.generate_draft_decode_kv_indices[
(self.speculative_num_steps, num_seqs, self.topk)
](
forward_batch.req_pool_indices,
forward_batch.req_to_token_pool.req_to_token,
forward_batch.seq_lens,
kv_indices_buffer,
self.kv_indptr,
forward_batch.positions,
self.pool_len,
kv_indices_buffer.shape[1],
self.kv_indptr.shape[1],
triton.next_power_of_2(num_seqs),
triton.next_power_of_2(self.speculative_num_steps),
triton.next_power_of_2(bs),
self.page_size,
)
for i in range(self.speculative_num_steps):
forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
: seq_lens_sum * self.topk + bs * (i + 1)
]
call_fn(i, forward_batch)
def init_forward_metadata(self, forward_batch: ForwardBatch):
kv_indices = torch.empty(
(
self.speculative_num_steps,
forward_batch.batch_size * self.topk * self.max_context_len,
),
dtype=torch.int32,
device=self.device,
)
def call_fn(i, forward_batch):
forward_batch.spec_info.kv_indptr = (
forward_batch.spec_info.kv_indptr.clone()
)
forward_batch.spec_info.kv_indices = (
forward_batch.spec_info.kv_indices.clone()
)
self.attn_backends[i].init_forward_metadata(forward_batch)
self.common_template(forward_batch, kv_indices, call_fn)
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
self.cuda_graph_kv_indices = torch.zeros(
(self.speculative_num_steps, max_num_tokens * self.max_context_len),
dtype=torch.int32,
device=self.device,
)
for i in range(self.speculative_num_steps):
self.attn_backends[i].init_cuda_graph_state(
max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
)
def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
def call_fn(i, forward_batch):
self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
forward_batch.batch_size,
forward_batch.batch_size * self.topk,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
encoder_lens=None,
forward_mode=ForwardMode.DECODE,
spec_info=forward_batch.spec_info,
)
self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
def init_forward_metadata_replay_cuda_graph(
self, forward_batch: ForwardBatch, bs: int
):
def call_fn(i, forward_batch):
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
bs,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
seq_lens_sum=-1,
encoder_lens=None,
forward_mode=ForwardMode.DECODE,
spec_info=forward_batch.spec_info,
seq_lens_cpu=None,
)
self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
......@@ -1722,6 +1722,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
or attention_backend_str == "cutlass_mla"
or attention_backend_str == "ascend"
or attention_backend_str == "trtllm_mha"
or attention_backend_str == "aiter"
or global_server_args_dict["enable_two_batch_overlap"]
):
seq_lens_cpu = (
......
......@@ -226,6 +226,22 @@ class EAGLEWorker(TpModelWorker):
self.draft_model_runner,
skip_prefill=False,
)
elif self.server_args.attention_backend == "aiter":
from sglang.srt.layers.attention.aiter_backend import (
AiterAttnBackend,
AiterMultiStepDraftBackend,
)
self.draft_attn_backend = AiterMultiStepDraftBackend(
self.draft_model_runner,
self.topk,
self.speculative_num_steps,
)
self.draft_extend_attn_backend = AiterAttnBackend(
self.draft_model_runner,
skip_prefill=False,
)
self.has_prefill_wrapper_verify = False
elif self.server_args.attention_backend == "fa3":
from sglang.srt.layers.attention.flashattention_backend import (
FlashAttentionBackend,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment