"vscode:/vscode.git/clone" did not exist on "92ea5baca2815ecd51f96bedb0fb766b313196f8"
Unverified Commit 2d611323 authored by Ke Bao's avatar Ke Bao Committed by GitHub
Browse files

Support Eagle2 for Triton backend (#3466)

parent cddb1cdf
...@@ -3,6 +3,7 @@ from __future__ import annotations ...@@ -3,6 +3,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
import torch import torch
import triton
from sglang.srt.layers.attention import AttentionBackend from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.layers.attention.flashinfer_backend import ( from sglang.srt.layers.attention.flashinfer_backend import (
...@@ -18,7 +19,12 @@ if TYPE_CHECKING: ...@@ -18,7 +19,12 @@ if TYPE_CHECKING:
class TritonAttnBackend(AttentionBackend): class TritonAttnBackend(AttentionBackend):
def __init__(self, model_runner: ModelRunner): def __init__(
self,
model_runner: ModelRunner,
skip_prefill: bool = False,
kv_indptr_buf: Optional[torch.Tensor] = None,
):
# Lazy import to avoid the initialization of cuda context # Lazy import to avoid the initialization of cuda context
from sglang.srt.layers.attention.triton_ops.decode_attention import ( from sglang.srt.layers.attention.triton_ops.decode_attention import (
decode_attention_fwd, decode_attention_fwd,
...@@ -33,14 +39,25 @@ class TritonAttnBackend(AttentionBackend): ...@@ -33,14 +39,25 @@ class TritonAttnBackend(AttentionBackend):
self.extend_attention_fwd = extend_attention_fwd self.extend_attention_fwd = extend_attention_fwd
max_bs = model_runner.req_to_token_pool.size max_bs = model_runner.req_to_token_pool.size
self.kv_indptr = torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device if kv_indptr_buf is None:
) self.kv_indptr = torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
else:
self.kv_indptr = kv_indptr_buf
self.req_to_token = model_runner.req_to_token_pool.req_to_token self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.qo_indptr = torch.zeros( self.qo_indptr = torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device (max_bs + 1,), dtype=torch.int32, device=model_runner.device
) )
self.mask_indptr = torch.zeros(
(max_bs + 1,), dtype=torch.int64, device=model_runner.device
)
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
self.num_head = ( self.num_head = (
model_runner.model_config.num_attention_heads // get_attention_tp_size() model_runner.model_config.num_attention_heads // get_attention_tp_size()
) )
...@@ -50,7 +67,7 @@ class TritonAttnBackend(AttentionBackend): ...@@ -50,7 +67,7 @@ class TritonAttnBackend(AttentionBackend):
self.forward_metadata = None self.forward_metadata = None
self.cuda_graph_max_seq_len = model_runner.model_config.context_len self.max_context_len = model_runner.model_config.context_len
self.device = model_runner.device self.device = model_runner.device
...@@ -59,11 +76,31 @@ class TritonAttnBackend(AttentionBackend): ...@@ -59,11 +76,31 @@ class TritonAttnBackend(AttentionBackend):
bs = forward_batch.batch_size bs = forward_batch.batch_size
kv_indptr = self.kv_indptr kv_indptr = self.kv_indptr
spec_info = forward_batch.spec_info
if forward_batch.forward_mode.is_decode():
attn_logits = torch.empty( if forward_batch.forward_mode.is_decode_or_idle():
if spec_info is None:
kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.zeros(
forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
)
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
kv_indptr,
None,
kv_indices,
self.req_to_token.stride(0),
)
else:
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
bs = kv_indptr.shape[0] - 1
attn_logits = torch.zeros(
( (
forward_batch.batch_size, bs,
self.num_head, self.num_head,
self.num_kv_splits, self.num_kv_splits,
self.v_head_dim + 1, self.v_head_dim + 1,
...@@ -72,12 +109,24 @@ class TritonAttnBackend(AttentionBackend): ...@@ -72,12 +109,24 @@ class TritonAttnBackend(AttentionBackend):
device=self.device, device=self.device,
) )
qo_indptr = None
custom_mask = None
mask_indptr = None
max_extend_len = None max_extend_len = None
elif forward_batch.forward_mode.is_target_verify():
bs = len(forward_batch.req_pool_indices)
qo_indptr = torch.arange(
0,
(1 + bs) * self.num_draft_tokens,
step=self.num_draft_tokens,
dtype=torch.int32,
device=self.device,
)
# Different with flashinfer kv_indptr and kv_indices construction
kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0) kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1] kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.empty( kv_indices = torch.zeros(
forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device kv_indptr[-1], dtype=torch.int32, device=self.device
) )
create_flashinfer_kv_indices_triton[(bs,)]( create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token, self.req_to_token,
...@@ -89,15 +138,32 @@ class TritonAttnBackend(AttentionBackend): ...@@ -89,15 +138,32 @@ class TritonAttnBackend(AttentionBackend):
self.req_to_token.stride(0), self.req_to_token.stride(0),
) )
qo_indptr = None custom_mask = spec_info.custom_mask
custom_mask = None seq_mask_len = self.num_draft_tokens * (
mask_offsets = None forward_batch.seq_lens + self.num_draft_tokens
)
mask_indptr = self.mask_indptr
mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0)
mask_indptr = mask_indptr[: bs + 1]
max_extend_len = self.num_draft_tokens
attn_logits = None
elif forward_batch.forward_mode.is_draft_extend():
kv_indices, kv_indptr, qo_indptr, custom_mask = (
spec_info.generate_attn_arg_prefill(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
self.req_to_token,
)
)
mask_indptr = None
max_extend_len = torch.max(spec_info.accept_length).item()
attn_logits = None
else: else:
kv_indptr[1 : bs + 1] = torch.cumsum( kv_indptr[1 : bs + 1] = torch.cumsum(
forward_batch.extend_prefix_lens, dim=0 forward_batch.extend_prefix_lens, dim=0
) )
kv_indptr = kv_indptr[: bs + 1] kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.empty( kv_indices = torch.zeros(
forward_batch.extend_prefix_lens.sum().item(), forward_batch.extend_prefix_lens.sum().item(),
dtype=torch.int32, dtype=torch.int32,
device=self.device, device=self.device,
...@@ -116,8 +182,7 @@ class TritonAttnBackend(AttentionBackend): ...@@ -116,8 +182,7 @@ class TritonAttnBackend(AttentionBackend):
qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0) qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
qo_indptr = qo_indptr[: bs + 1] qo_indptr = qo_indptr[: bs + 1]
custom_mask = None custom_mask = None
mask_offsets = None mask_indptr = None
attn_logits = None attn_logits = None
max_extend_len = torch.max(forward_batch.extend_seq_lens).item() max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
...@@ -128,22 +193,22 @@ class TritonAttnBackend(AttentionBackend): ...@@ -128,22 +193,22 @@ class TritonAttnBackend(AttentionBackend):
kv_indices, kv_indices,
qo_indptr, qo_indptr,
custom_mask, custom_mask,
mask_offsets, mask_indptr,
) )
def init_cuda_graph_state(self, max_bs: int): def init_cuda_graph_state(self, max_bs: int):
self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len self.cuda_graph_max_total_num_tokens = max_bs * self.max_context_len
self.cuda_graph_start_loc = torch.zeros( self.cuda_graph_start_loc = torch.zeros(
(max_bs,), dtype=torch.int32, device=self.device (max_bs,), dtype=torch.int32, device=self.device
) )
self.cuda_graph_attn_logits = torch.empty( self.cuda_graph_attn_logits = torch.zeros(
(max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1), (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
dtype=torch.float32, dtype=torch.float32,
device=self.device, device=self.device,
) )
self.cuda_graph_kv_indices = torch.zeros( self.cuda_graph_kv_indices = torch.zeros(
(max_bs * self.cuda_graph_max_seq_len), (max_bs * self.max_context_len),
dtype=torch.int32, dtype=torch.int32,
device=self.device, device=self.device,
) )
...@@ -244,8 +309,9 @@ class TritonAttnBackend(AttentionBackend): ...@@ -244,8 +309,9 @@ class TritonAttnBackend(AttentionBackend):
kv_indices, kv_indices,
qo_indptr, qo_indptr,
custom_mask, custom_mask,
mask_offsets, mask_indptr,
) = self.forward_metadata ) = self.forward_metadata
self.extend_attention_fwd( self.extend_attention_fwd(
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
k.contiguous(), k.contiguous(),
...@@ -257,7 +323,7 @@ class TritonAttnBackend(AttentionBackend): ...@@ -257,7 +323,7 @@ class TritonAttnBackend(AttentionBackend):
kv_indptr, kv_indptr,
kv_indices, kv_indices,
custom_mask, custom_mask,
mask_offsets, mask_indptr,
max_extend_len, max_extend_len,
layer.scaling, layer.scaling,
layer.logit_cap, layer.logit_cap,
...@@ -303,3 +369,136 @@ class TritonAttnBackend(AttentionBackend): ...@@ -303,3 +369,136 @@ class TritonAttnBackend(AttentionBackend):
layer.logit_cap, layer.logit_cap,
) )
return o return o
class TritonMultiStepDraftBackend:
"""
Wrap multiple triton attention backends as one for multiple consecutive
draft decoding steps.
"""
def __init__(
self,
model_runner: ModelRunner,
topk: int,
speculative_num_steps: int,
):
from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
self.topk = topk
self.speculative_num_steps = speculative_num_steps
self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
max_bs = model_runner.req_to_token_pool.size
self.kv_indptr = torch.zeros(
(
self.speculative_num_steps,
max_bs + 1,
),
dtype=torch.int32,
device=model_runner.device,
)
self.attn_backends = []
for i in range(self.speculative_num_steps):
self.attn_backends.append(
TritonAttnBackend(
model_runner,
skip_prefill=True,
kv_indptr_buf=self.kv_indptr[i],
)
)
self.max_context_len = self.attn_backends[0].max_context_len
# Cached variables for generate_draft_decode_kv_indices
self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
def common_template(
self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
):
num_seqs = forward_batch.batch_size
bs = self.topk * num_seqs
seq_lens_sum = forward_batch.seq_lens_sum
self.generate_draft_decode_kv_indices[
(self.speculative_num_steps, num_seqs, self.topk)
](
forward_batch.req_pool_indices,
forward_batch.req_to_token_pool.req_to_token,
forward_batch.seq_lens,
kv_indices_buffer,
self.kv_indptr,
forward_batch.positions,
num_seqs,
self.topk,
self.pool_len,
kv_indices_buffer.shape[1],
self.kv_indptr.shape[1],
triton.next_power_of_2(num_seqs),
triton.next_power_of_2(self.speculative_num_steps),
triton.next_power_of_2(bs),
)
for i in range(self.speculative_num_steps):
forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
: seq_lens_sum * self.topk + bs * (i + 1)
]
call_fn(i, forward_batch)
def init_forward_metadata(self, forward_batch: ForwardBatch):
kv_indices = torch.zeros(
(
self.speculative_num_steps,
forward_batch.batch_size * self.topk * self.max_context_len,
),
dtype=torch.int32,
device="cuda",
)
def call_fn(i, forward_batch):
forward_batch.spec_info.kv_indptr = (
forward_batch.spec_info.kv_indptr.clone()
)
forward_batch.spec_info.kv_indices = (
forward_batch.spec_info.kv_indices.clone()
)
self.attn_backends[i].init_forward_metadata(forward_batch)
self.common_template(forward_batch, kv_indices, call_fn)
def init_cuda_graph_state(self, max_bs: int):
self.cuda_graph_kv_indices = torch.zeros(
(self.speculative_num_steps, max_bs * self.max_context_len),
dtype=torch.int32,
device="cuda",
)
for i in range(self.speculative_num_steps):
self.attn_backends[i].init_cuda_graph_state(
max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
)
def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
def call_fn(i, forward_batch):
self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
forward_batch.batch_size,
forward_batch.batch_size * self.topk,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
encoder_lens=None,
forward_mode=ForwardMode.DECODE,
spec_info=forward_batch.spec_info,
)
self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
def init_forward_metadata_replay_cuda_graph(self, forward_batch):
def call_fn(i, forward_batch):
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
forward_batch.batch_size,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
seq_lens_sum=-1,
encoder_lens=None,
forward_mode=ForwardMode.DECODE,
spec_info=forward_batch.spec_info,
)
self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
...@@ -50,7 +50,7 @@ def _fwd_kernel( ...@@ -50,7 +50,7 @@ def _fwd_kernel(
kv_indptr, kv_indptr,
kv_indices, kv_indices,
mask_ptr, mask_ptr,
mask_offsets, mask_indptr,
sm_scale, sm_scale,
kv_group_num, kv_group_num,
stride_qbs, stride_qbs,
...@@ -87,7 +87,7 @@ def _fwd_kernel( ...@@ -87,7 +87,7 @@ def _fwd_kernel(
cur_seq_len = cur_seq_len_prefix + cur_seq_len_extend cur_seq_len = cur_seq_len_prefix + cur_seq_len_extend
if USE_CUSTOM_MASK: if USE_CUSTOM_MASK:
cur_seq_mask_start_idx = tl.load(mask_offsets + cur_seq) cur_seq_mask_start_idx = tl.load(mask_indptr + cur_seq)
offs_d = tl.arange(0, BLOCK_DMODEL) offs_d = tl.arange(0, BLOCK_DMODEL)
offs_dv = tl.arange(0, BLOCK_DV) offs_dv = tl.arange(0, BLOCK_DV)
...@@ -288,7 +288,7 @@ def extend_attention_fwd( ...@@ -288,7 +288,7 @@ def extend_attention_fwd(
kv_indptr, kv_indptr,
kv_indices, kv_indices,
custom_mask, custom_mask,
mask_offsets, mask_indptr,
max_len_extend, max_len_extend,
sm_scale=None, sm_scale=None,
logit_cap=0.0, logit_cap=0.0,
...@@ -364,7 +364,7 @@ def extend_attention_fwd( ...@@ -364,7 +364,7 @@ def extend_attention_fwd(
kv_indptr, kv_indptr,
kv_indices, kv_indices,
custom_mask, custom_mask,
mask_offsets, mask_indptr,
sm_scale, sm_scale,
kv_group_num, kv_group_num,
q_extend.stride(0), q_extend.stride(0),
......
...@@ -65,15 +65,31 @@ class EAGLEWorker(TpModelWorker): ...@@ -65,15 +65,31 @@ class EAGLEWorker(TpModelWorker):
self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
# Create multi-step attn backends and cuda graph runners # Create multi-step attn backends and cuda graph runners
from sglang.srt.layers.attention.flashinfer_backend import ( if server_args.attention_backend == "flashinfer":
FlashInferMultiStepDraftBackend, from sglang.srt.layers.attention.flashinfer_backend import (
) FlashInferMultiStepDraftBackend,
)
self.draft_attn_backend = FlashInferMultiStepDraftBackend(
self.model_runner,
self.topk,
self.speculative_num_steps,
)
elif server_args.attention_backend == "triton":
from sglang.srt.layers.attention.triton_backend import (
TritonMultiStepDraftBackend,
)
self.draft_attn_backend = TritonMultiStepDraftBackend(
self.model_runner,
self.topk,
self.speculative_num_steps,
)
else:
raise ValueError(
f"EAGLE is not supportted in attention backend {server_args.attention_backend}"
)
self.draft_attn_backend = FlashInferMultiStepDraftBackend(
self.model_runner,
self.topk,
self.speculative_num_steps,
)
self.model_runner.draft_attn_backend = self.draft_attn_backend self.model_runner.draft_attn_backend = self.draft_attn_backend
self.init_cuda_graphs() self.init_cuda_graphs()
......
...@@ -193,5 +193,34 @@ class TestEAGLEServer(unittest.TestCase): ...@@ -193,5 +193,34 @@ class TestEAGLEServer(unittest.TestCase):
self.assertGreater(metrics["accuracy"], 0.20) self.assertGreater(metrics["accuracy"], 0.20)
class TestEAGLEServerTriton(TestEAGLEServer):
@classmethod
def setUpClass(cls):
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--speculative-algorithm",
"EAGLE",
"--speculative-draft-model-path",
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"--speculative-num-steps",
"5",
"--speculative-eagle-topk",
"8",
"--speculative-num-draft-tokens",
"64",
"--mem-fraction-static",
"0.7",
"--attention-backend",
"triton",
# TODO: Support cuda graph
"--disable-cuda-graph",
],
)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -102,7 +102,7 @@ class TestTritonAttention(unittest.TestCase): ...@@ -102,7 +102,7 @@ class TestTritonAttention(unittest.TestCase):
qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0) qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)
custom_mask = None custom_mask = None
mask_offsets = None mask_indptr = None
extend_attention_fwd( extend_attention_fwd(
q_extend, q_extend,
...@@ -115,7 +115,7 @@ class TestTritonAttention(unittest.TestCase): ...@@ -115,7 +115,7 @@ class TestTritonAttention(unittest.TestCase):
kv_indptr, kv_indptr,
kv_indices, kv_indices,
custom_mask, custom_mask,
mask_offsets, mask_indptr,
max_len_extend, max_len_extend,
) )
...@@ -123,8 +123,8 @@ class TestTritonAttention(unittest.TestCase): ...@@ -123,8 +123,8 @@ class TestTritonAttention(unittest.TestCase):
custom_mask = torch.ones( custom_mask = torch.ones(
(b_seq_mask_len.sum().item(),), dtype=torch.bool, device="cuda" (b_seq_mask_len.sum().item(),), dtype=torch.bool, device="cuda"
) )
mask_offsets = torch.zeros((B + 1,), dtype=torch.int64, device="cuda") mask_indptr = torch.zeros((B + 1,), dtype=torch.int64, device="cuda")
mask_offsets[1 : B + 1] = torch.cumsum(b_seq_mask_len[:B], dim=0) mask_indptr[1 : B + 1] = torch.cumsum(b_seq_mask_len[:B], dim=0)
for i in range(B): for i in range(B):
causal_mask = ( causal_mask = (
torch.tril( torch.tril(
...@@ -136,7 +136,7 @@ class TestTritonAttention(unittest.TestCase): ...@@ -136,7 +136,7 @@ class TestTritonAttention(unittest.TestCase):
b_seq_len_extend[i], b_seq_len_prefix[i], dtype=torch.bool b_seq_len_extend[i], b_seq_len_prefix[i], dtype=torch.bool
) )
mask_flatten = torch.cat([prefix_mask, causal_mask], dim=1).flatten() mask_flatten = torch.cat([prefix_mask, causal_mask], dim=1).flatten()
custom_mask[mask_offsets[i] : mask_offsets[i + 1]] = mask_flatten custom_mask[mask_indptr[i] : mask_indptr[i + 1]] = mask_flatten
extend_attention_fwd( extend_attention_fwd(
q_extend, q_extend,
...@@ -149,7 +149,7 @@ class TestTritonAttention(unittest.TestCase): ...@@ -149,7 +149,7 @@ class TestTritonAttention(unittest.TestCase):
kv_indptr, kv_indptr,
kv_indices, kv_indices,
custom_mask, custom_mask,
mask_offsets, mask_indptr,
max_len_extend, max_len_extend,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment