Unverified Commit 061e5463 authored by Shuo Yang's avatar Shuo Yang Committed by GitHub
Browse files

Support double sparsity (#1459)

parent 0c1e8796
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
import torch.nn as nn
from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
class DoubleSparseAttnBackend(AttentionBackend):
def __init__(self, model_runner: ModelRunner):
# Lazy import to avoid the initialization of cuda context
from sglang.srt.layers.attention.triton_ops.double_sparsity_attention import (
flash_decode_attention_fwd,
flash_decode_sparse_attention_fwd,
)
from sglang.srt.layers.attention.triton_ops.extend_attention import (
extend_attention_fwd,
)
super().__init__()
self.decode_attention_fwd = flash_decode_attention_fwd
self.decode_sparse_attention_fwd = flash_decode_sparse_attention_fwd
self.extend_attention_fwd = extend_attention_fwd
self.num_head = model_runner.model_config.num_attention_heads
self.head_dim = model_runner.model_config.hidden_size // self.num_head
self.heavy_token_num = model_runner.server_args.ds_heavy_token_num
self.sorted_channels = model_runner.sorted_channels
self.sparse_decode_thresold = (
model_runner.server_args.ds_sparse_decode_threshold
)
self.att_out_approx: torch.Tensor = None
self.mid_out: torch.Tensor = None
self.mid_o_logexpsum: torch.Tensor = None
# TODO: Change the hard-coded block_seq_num
self.BLOCK_SEQ = 128
if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
self.reduce_dtype = torch.float32
else:
self.reduce_dtype = torch.float16
self.forward_metadata = None
self.cuda_graph_max_seq_len = model_runner.model_config.context_len
def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Init auxiliary variables for triton attention backend."""
if forward_batch.forward_mode.is_decode():
start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)
total_num_tokens = torch.sum(forward_batch.seq_lens).item()
attn_logits = torch.empty(
(self.num_head, total_num_tokens),
dtype=self.reduce_dtype,
device="cuda",
)
max_seq_len = torch.max(forward_batch.seq_lens).item()
min_seq_len = torch.min(forward_batch.seq_lens).item()
max_extend_len = None
# NOTE: Align sequence order with req_to_token order
ds_req_to_token = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices
]
bsz = forward_batch.seq_lens.shape[0]
att_out_approx = torch.empty(
[self.num_head, bsz, max_seq_len],
dtype=self.reduce_dtype,
device="cuda",
)
block_seq_num = (
self.heavy_token_num + self.BLOCK_SEQ - 1
) // self.BLOCK_SEQ
mid_out = torch.empty(
[bsz, self.num_head, block_seq_num, self.head_dim],
dtype=torch.float32,
device="cuda",
)
mid_o_logexpsum = torch.empty(
[bsz, self.num_head, block_seq_num], dtype=torch.float32, device="cuda"
)
self.att_out_approx = att_out_approx
self.mid_out = mid_out
self.mid_o_logexpsum = mid_o_logexpsum
else:
start_loc = attn_logits = max_seq_len = min_seq_len = None
prefix_lens = forward_batch.extend_prefix_lens
max_extend_len = torch.max(forward_batch.seq_lens - prefix_lens).item()
ds_req_to_token = None
self.forward_metadata = (
start_loc,
attn_logits,
max_seq_len,
min_seq_len,
max_extend_len,
ds_req_to_token,
)
def init_cuda_graph_state(self, max_bs: int):
# TODO(Andy): Support CUDA graph for double sparse attention
raise ValueError(
"Double sparse attention does not support CUDA graph for now. Please --disable-cuda-graph"
)
self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
self.cuda_graph_start_loc = torch.zeros(
(max_bs,), dtype=torch.int32, device="cuda"
)
self.cuda_graph_attn_logits = torch.empty(
(
self.num_head,
self.cuda_graph_max_total_num_tokens,
),
dtype=self.reduce_dtype,
device="cuda",
)
def init_forward_metadata_capture_cuda_graph(
self, bs: int, req_pool_indices, seq_lens
):
self.forward_metadata = (
self.cuda_graph_start_loc,
self.cuda_graph_attn_logits,
self.cuda_graph_max_seq_len,
None,
)
def init_forward_metadata_replay_cuda_graph(
self, bs: int, req_pool_indices, seq_lens
):
self.cuda_graph_start_loc.zero_()
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
def get_cuda_graph_seq_len_fill_value(self):
return 1
def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
# TODO: reuse the buffer across layers
if layer.qk_head_dim != layer.v_head_dim:
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
else:
o = torch.empty_like(q)
k_label = torch.gather(
k,
2,
self.sorted_channels[layer.layer_id]
.unsqueeze(0)
.expand(k.shape[0], -1, -1),
)
forward_batch.token_to_kv_pool.set_kv_buffer(
layer.layer_id, forward_batch.out_cache_loc, k, v, k_label
)
(
start_loc,
attn_logits,
max_seq_len,
min_seq_len,
max_extend_len,
ds_req_to_token,
) = self.forward_metadata
self.extend_attention_fwd(
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
k.contiguous(),
v.contiguous(),
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
forward_batch.req_to_token_pool.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.extend_seq_lens,
forward_batch.extend_start_loc,
max_extend_len,
layer.scaling,
layer.logit_cap,
)
return o
def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
# During torch.compile, there is a bug in rotary_emb that causes the
# output value to have a 3D tensor shape. This reshapes the output correctly.
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
# TODO: reuse the buffer across layers
if layer.qk_head_dim != layer.v_head_dim:
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
else:
o = torch.empty_like(q)
# TODO: Add min seqlen
(
start_loc,
attn_logits,
max_seq_len,
min_seq_len,
max_extend_len,
ds_req_to_token,
) = self.forward_metadata
k_label = torch.gather(
k,
2,
self.sorted_channels[layer.layer_id]
.unsqueeze(0)
.expand(k.shape[0], -1, -1),
)
forward_batch.token_to_kv_pool.set_kv_buffer(
layer.layer_id, forward_batch.out_cache_loc, k, v, k_label
)
# NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num
# and set a minimum value for sparse_decode
if (
min_seq_len < self.heavy_token_num
or max_seq_len < self.sparse_decode_thresold
):
self.decode_attention_fwd(
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
forward_batch.req_to_token_pool.req_to_token,
forward_batch.req_pool_indices,
start_loc,
forward_batch.seq_lens,
attn_logits,
max_seq_len,
layer.scaling,
layer.logit_cap,
)
else:
# TODO(Andy): indexing with torch.gather or torch.index_select or customized kernel
q_label = torch.gather(
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
2,
self.sorted_channels[layer.layer_id]
.unsqueeze(0)
.expand(q.shape[0], -1, -1),
)
self.decode_sparse_attention_fwd(
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
q_label,
forward_batch.token_to_kv_pool.get_label_buffer(layer.layer_id),
ds_req_to_token,
forward_batch.seq_lens,
max_seq_len,
layer.scaling,
layer.logit_cap,
self.heavy_token_num,
self.att_out_approx,
self.mid_out,
self.mid_o_logexpsum,
self.BLOCK_SEQ,
)
return o
import torch
import triton
import triton.language as tl
from sglang.srt.managers.schedule_batch import global_server_args_dict
if global_server_args_dict.get("attention_reduce_in_fp32", False):
REDUCE_TRITON_TYPE = tl.float32
REDUCE_TORCH_TYPE = torch.float32
else:
REDUCE_TRITON_TYPE = tl.float16
REDUCE_TORCH_TYPE = torch.float16
@triton.jit
def tanh(x):
# Tanh is just a scaled sigmoid
return 2 * tl.sigmoid(2 * x) - 1
@triton.jit
def _fwd_kernel_flash_decode_stage1(
Q,
K,
V,
sm_scale,
Req_to_tokens,
B_req_idx,
B_Seqlen,
Mid_O, # [batch, head, seq_block_num, head_dim]
Mid_O_LogExpSum, # [batch, head, seq_block_num]
stride_req_to_tokens_b,
stride_req_to_tokens_s,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
stride_mid_od,
stride_mid_o_eb,
stride_mid_o_eh,
stride_mid_o_es,
gqa_group_size,
BLOCK_SEQ: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
seq_start_block = tl.program_id(2)
cur_kv_head = cur_head // gqa_group_size
offs_d = tl.arange(0, BLOCK_DMODEL)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
cur_batch_start_index = seq_start_block * BLOCK_SEQ
cur_batch_end_index = tl.minimum(
cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ
)
off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
block_n_size = (
tl.where(
cur_batch_end_index - cur_batch_start_index <= 0,
0,
cur_batch_end_index - cur_batch_start_index + BLOCK_N - 1,
)
// BLOCK_N
)
offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N)
q = tl.load(Q + off_q)
sum_exp = 0.0
max_logic = -float("inf")
acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)
for start_n in range(0, block_n_size, 1):
offs_n_new = start_n * BLOCK_N + offs_n
k_loc = tl.load(
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
mask=offs_n_new < cur_batch_end_index,
other=0,
)
off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :]
k = tl.load(
K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0
)
att_value = tl.sum(q[None, :] * k, 1)
att_value *= sm_scale
att_value = tl.where(offs_n_new < cur_batch_end_index, att_value, float("-inf"))
v = tl.load(
V + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0
)
cur_max_logic = tl.max(att_value, axis=0)
new_max_logic = tl.maximum(cur_max_logic, max_logic)
exp_logic = tl.exp(att_value - new_max_logic)
logic_scale = tl.exp(max_logic - new_max_logic)
acc *= logic_scale
acc += tl.sum(exp_logic[:, None] * v, axis=0)
sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=0)
max_logic = new_max_logic
need_store = tl.where(block_n_size == 0, 0, 1)
for _ in range(0, need_store, 1):
off_mid_o = (
cur_batch * stride_mid_ob
+ cur_head * stride_mid_oh
+ seq_start_block * stride_mid_os
+ offs_d
)
off_mid_o_logexpsum = (
cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + seq_start_block
)
tl.store(Mid_O + off_mid_o, acc / sum_exp)
tl.store(Mid_O_LogExpSum + off_mid_o_logexpsum, max_logic + tl.log(sum_exp))
return
@triton.jit
def _fwd_kernel_flash_decode_stage2(
B_Seqlen,
Mid_O, # [batch, head, seq_block_num, head_dim]
Mid_O_LogExpSum, # [batch, head, seq_block_num]
O, # [batch, head, head_dim]
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
stride_mid_od,
stride_mid_o_eb,
stride_mid_o_eh,
stride_mid_o_es,
stride_obs,
stride_oh,
stride_od,
BLOCK_SEQ: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
offs_d = tl.arange(0, BLOCK_DMODEL)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
block_n_size = (
tl.where(cur_batch_seq_len <= 0, 0, cur_batch_seq_len + BLOCK_SEQ - 1)
// BLOCK_SEQ
)
sum_exp = 0.0
max_logic = -float("inf")
acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)
offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
offs_logic = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh
for block_seq_n in range(0, block_n_size, 1):
tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os)
tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_seq_n)
new_max_logic = tl.maximum(tlogic, max_logic)
old_scale = tl.exp(max_logic - new_max_logic)
acc *= old_scale
exp_logic = tl.exp(tlogic - new_max_logic)
acc += exp_logic * tv
sum_exp = sum_exp * old_scale + exp_logic
max_logic = new_max_logic
tl.store(O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / sum_exp)
return
@torch.no_grad()
def flash_decode_stage1(
q,
k,
v,
Req_to_tokens,
B_req_idx,
B_Seqlen,
max_len_in_batch,
mid_out,
mid_out_logsumexp,
block_seq,
):
BLOCK_SEQ = block_seq
BLOCK_N = 16
assert BLOCK_SEQ % BLOCK_N == 0
# shape constraints
Lq, Lk = q.shape[-1], k.shape[-1]
assert Lq == Lk
assert Lk in {16, 32, 64, 128}
sm_scale = 1.0 / (Lk**0.5)
batch, head_num = B_req_idx.shape[0], q.shape[1]
grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ))
gqa_group_size = q.shape[1] // k.shape[1]
_fwd_kernel_flash_decode_stage1[grid](
q,
k,
v,
sm_scale,
Req_to_tokens,
B_req_idx,
B_Seqlen,
mid_out,
mid_out_logsumexp,
Req_to_tokens.stride(0),
Req_to_tokens.stride(1),
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
mid_out.stride(0),
mid_out.stride(1),
mid_out.stride(2),
mid_out.stride(3),
mid_out_logsumexp.stride(0),
mid_out_logsumexp.stride(1),
mid_out_logsumexp.stride(2),
gqa_group_size,
BLOCK_SEQ=BLOCK_SEQ,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK_N,
num_warps=1,
num_stages=2,
)
return
@torch.no_grad()
def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, O, block_seq):
Lk = mid_out.shape[-1]
assert Lk in {16, 32, 64, 128}
batch, head_num = mid_out.shape[0], mid_out.shape[1]
grid = (batch, head_num)
_fwd_kernel_flash_decode_stage2[grid](
B_Seqlen,
mid_out,
mid_out_logexpsum,
O,
mid_out.stride(0),
mid_out.stride(1),
mid_out.stride(2),
mid_out.stride(3),
mid_out_logexpsum.stride(0),
mid_out_logexpsum.stride(1),
mid_out_logexpsum.stride(2),
O.stride(0),
O.stride(1),
O.stride(2),
BLOCK_SEQ=block_seq,
BLOCK_DMODEL=Lk,
num_warps=4,
num_stages=2,
)
return
import torch
def flash_decode_attention_fwd(
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_req_idx,
b_start_loc,
b_seq_len,
attn_logits,
max_len_in_batch,
sm_scale,
logit_cap=0.0,
):
BLOCK_SEQ = 256
kv_group_num = q.shape[1] // v_buffer.shape[1]
# batch_size = q.shape[0]
block_seq_num = (max_len_in_batch + BLOCK_SEQ - 1) // BLOCK_SEQ
mid_o = torch.empty(
[q.shape[0], q.shape[1], block_seq_num, q.shape[-1]],
dtype=torch.float32,
device="cuda",
)
mid_o_logexpsum = torch.empty(
[q.shape[0], q.shape[1], block_seq_num], dtype=torch.float32, device="cuda"
)
flash_decode_stage1(
q,
k_buffer,
v_buffer,
req_to_token,
b_req_idx,
b_seq_len,
max_len_in_batch,
mid_o,
mid_o_logexpsum,
BLOCK_SEQ,
)
flash_decode_stage2(mid_o, mid_o_logexpsum, b_seq_len, o, BLOCK_SEQ)
@triton.jit
def _sparse_fwd_kernel_flash_decode_stage1( # Double Sparsity's approximate attention
Q_Label,
K_Label_Buffer,
sm_scale,
Req_to_tokens, # shape: [B, S]
B_Seqlen,
Att_Out, # shape: [H, B, S] easier for topk
stride_req_to_tokens_b,
stride_qbs,
stride_qh,
stride_buf_kbs,
stride_buf_kh,
att_stride_h,
att_stride_b,
kv_group_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
logit_cap: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_n = tl.program_id(2)
cur_kv_head = cur_head // kv_group_num
offs_d = tl.arange(0, BLOCK_DMODEL)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_start_index = 0
cur_batch_end_index = cur_batch_seq_len
min_val = -float("inf")
att_value = tl.full([BLOCK_N], min_val, dtype=tl.float32)
off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
block_index = start_n * BLOCK_N
block_mask = tl.where(block_index < cur_batch_seq_len, 1, 0)
for start_mark in range(0, block_mask, 1):
q = tl.load(Q_Label + off_q + start_mark).to(REDUCE_TRITON_TYPE)
offs_n_new = cur_batch_start_index + offs_n
k_loc = tl.load(
Req_to_tokens + stride_req_to_tokens_b * cur_batch + offs_n_new,
mask=offs_n_new < cur_batch_end_index,
other=0,
)
offs_buf_k = (
k_loc[:, None] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_d[None, :]
)
k = tl.load(
K_Label_Buffer + offs_buf_k,
mask=offs_n_new[:, None] < cur_batch_end_index,
other=0.0,
).to(REDUCE_TRITON_TYPE)
att_value = tl.sum(q[None, :] * k, 1)
att_value *= sm_scale
if logit_cap > 0:
att_value = logit_cap * tanh(att_value / logit_cap)
att_value = tl.where(offs_n < cur_batch_end_index, att_value, min_val)
off_o = cur_head * att_stride_h + (cur_batch * att_stride_b + offs_n)
tl.store(Att_Out + off_o, att_value)
@triton.jit
def _sparse_fwd_kernel_flash_decode_stage2(
Q,
K,
V,
sm_scale,
Req_to_tokens, # shape: [B, S]
Topk_token_indices, # shape: [H, B, k]
Mid_O, # [batch, head, seq_block_num, head_dim]
Mid_O_LogExpSum, # [batch, head, seq_block_num]
Heavy_token_num, # NOTE: This can be used as constexpr but we may support dynamic heavy token number in the future
stride_req_to_tokens_b,
stride_topk_token_indices_h,
stride_topk_token_indices_b,
stride_qbs,
stride_qh,
stride_kbs,
stride_kh,
stride_vbs,
stride_vh,
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
stride_mid_o_eb,
stride_mid_o_eh,
gqa_group_size,
BLOCK_SEQ: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
seq_start_block = tl.program_id(2)
cur_kv_head = cur_head // gqa_group_size
offs_d = tl.arange(0, BLOCK_DMODEL)
cur_batch_start_index = seq_start_block * BLOCK_SEQ
cur_batch_end_index = tl.minimum(Heavy_token_num, cur_batch_start_index + BLOCK_SEQ)
off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
block_n_size = (
tl.where(
cur_batch_end_index - cur_batch_start_index <= 0,
0,
cur_batch_end_index - cur_batch_start_index + BLOCK_N - 1,
)
// BLOCK_N
)
# offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N)
offs_n = tl.arange(0, BLOCK_N)
q = tl.load(Q + off_q)
sum_exp = 0.0
max_logic = -float("inf")
acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)
for start_n in range(cur_batch_start_index, cur_batch_end_index, BLOCK_N):
# for start_n in range(0, block_n_size, 1):
# offs_n_new = start_n * BLOCK_N + offs_n
offs_n_new = start_n + offs_n
# offs_n_new = cur_batch_start_index + start_n * BLOCK_N + offs_n
topk_token_indices = tl.load(
Topk_token_indices
+ stride_topk_token_indices_h * cur_head
+ stride_topk_token_indices_b * cur_batch
+ offs_n_new,
mask=offs_n_new < cur_batch_end_index,
other=0,
)
k_loc = tl.load(
Req_to_tokens + stride_req_to_tokens_b * cur_batch + topk_token_indices,
mask=offs_n_new < cur_batch_end_index,
other=0,
)
off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :]
k = tl.load(
K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0
)
att_value = tl.sum(q[None, :] * k, 1)
att_value *= sm_scale
att_value = tl.where(offs_n_new < cur_batch_end_index, att_value, float("-inf"))
v = tl.load(
V + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0
)
cur_max_logic = tl.max(att_value, axis=0)
new_max_logic = tl.maximum(cur_max_logic, max_logic)
exp_logic = tl.exp(att_value - new_max_logic)
logic_scale = tl.exp(max_logic - new_max_logic)
acc *= logic_scale
acc += tl.sum(exp_logic[:, None] * v, axis=0)
sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=0)
max_logic = new_max_logic
# need_store = tl.where(block_n_size == 0, 0, 1)
need_store = 1
for _ in range(0, need_store, 1):
off_mid_o = (
cur_batch * stride_mid_ob
+ cur_head * stride_mid_oh
+ seq_start_block * stride_mid_os
+ offs_d
)
off_mid_o_logexpsum = (
cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + seq_start_block
)
tl.store(Mid_O + off_mid_o, acc / sum_exp)
tl.store(Mid_O_LogExpSum + off_mid_o_logexpsum, max_logic + tl.log(sum_exp))
return
@triton.jit
def _sparse_fwd_kernel_flash_decode_stage3(
Mid_O, # [batch, head, seq_block_num, head_dim]
Mid_O_LogExpSum, # [batch, head, seq_block_num]
O, # [batch, head, head_dim]
seq_len, # NOTE: This can be used as constexpr but we may support dynamic heavy token number in the future
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
stride_mid_o_eb,
stride_mid_o_eh,
stride_obs,
stride_oh,
BLOCK_SEQ: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
offs_d = tl.arange(0, BLOCK_DMODEL)
block_n_size = tl.where(seq_len <= 0, 0, seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ
sum_exp = 0.0
max_logic = -float("inf")
acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)
offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
offs_logic = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh
for block_seq_n in range(0, block_n_size, 1):
tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os)
tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_seq_n)
new_max_logic = tl.maximum(tlogic, max_logic)
old_scale = tl.exp(max_logic - new_max_logic)
acc *= old_scale
exp_logic = tl.exp(tlogic - new_max_logic)
acc += exp_logic * tv
sum_exp = sum_exp * old_scale + exp_logic
max_logic = new_max_logic
tl.store(O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / sum_exp)
return
def sparse_flash_decode_stage1(
q_label,
k_label_buffer,
att_out,
Req_to_tokens,
B_Seqlen,
max_len_in_batch,
sm_scale,
logit_cap,
):
BLOCK = 32
# shape constraints
Lq, Lk = q_label.shape[-1], k_label_buffer.shape[-1]
assert Lq == Lk
assert Lk in {16, 32, 64, 128, 256, 576}
BLOCK_DMODEL = Lk
batch, head_num = q_label.shape[0], q_label.shape[1]
grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK))
kv_group_num = q_label.shape[1] // k_label_buffer.shape[1]
if kv_group_num == 1:
num_warps = 4
else:
num_warps = 2
_sparse_fwd_kernel_flash_decode_stage1[grid](
q_label,
k_label_buffer,
sm_scale,
Req_to_tokens,
B_Seqlen,
att_out,
Req_to_tokens.stride(0),
q_label.stride(0),
q_label.stride(1),
k_label_buffer.stride(0),
k_label_buffer.stride(1),
att_out.stride(0),
att_out.stride(1),
kv_group_num,
BLOCK_DMODEL,
BLOCK,
logit_cap,
num_warps=num_warps,
num_stages=1,
)
@torch.no_grad()
def sparse_flash_decode_stage2(
q,
k,
v,
Req_to_tokens,
Topk_token_indices,
heavy_token_num,
mid_out,
mid_out_logsumexp,
block_seq,
sm_scale,
):
BLOCK_SEQ = block_seq
BLOCK_N = 16
assert BLOCK_SEQ % BLOCK_N == 0
# shape constraints
Lq, Lk = q.shape[-1], k.shape[-1]
assert Lq == Lk
assert Lk in {16, 32, 64, 128}
assert heavy_token_num == Topk_token_indices.shape[-1]
# sm_scale = 1.0 / (Lk ** 0.5)
batch, head_num = q.shape[0], q.shape[1]
grid = (batch, head_num, triton.cdiv(heavy_token_num, BLOCK_SEQ))
gqa_group_size = q.shape[1] // k.shape[1]
_sparse_fwd_kernel_flash_decode_stage2[grid](
q,
k,
v,
sm_scale,
Req_to_tokens,
Topk_token_indices,
mid_out,
mid_out_logsumexp,
heavy_token_num,
Req_to_tokens.stride(0),
Topk_token_indices.stride(0),
Topk_token_indices.stride(1),
q.stride(0),
q.stride(1),
k.stride(0),
k.stride(1),
v.stride(0),
v.stride(1),
mid_out.stride(0),
mid_out.stride(1),
mid_out.stride(2),
mid_out_logsumexp.stride(0),
mid_out_logsumexp.stride(1),
gqa_group_size,
BLOCK_SEQ=BLOCK_SEQ,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK_N,
num_warps=1,
num_stages=2,
)
return
@torch.no_grad()
def sparse_flash_decode_stage3(Seqlen, mid_out, mid_out_logexpsum, O, block_seq):
Lk = mid_out.shape[-1]
assert Lk in {16, 32, 64, 128}
batch, head_num = mid_out.shape[0], mid_out.shape[1]
grid = (batch, head_num)
_sparse_fwd_kernel_flash_decode_stage3[grid](
mid_out,
mid_out_logexpsum,
O,
Seqlen,
mid_out.stride(0),
mid_out.stride(1),
mid_out.stride(2),
mid_out_logexpsum.stride(0),
mid_out_logexpsum.stride(1),
O.stride(0),
O.stride(1),
BLOCK_SEQ=block_seq,
BLOCK_DMODEL=Lk,
num_warps=4,
num_stages=2,
)
return
def flash_decode_sparse_attention_fwd(
q,
k_buffer,
v_buffer,
o,
q_label,
k_label_buffer,
req_to_token,
b_seq_len,
max_len_in_batch,
sm_scale,
logit_cap,
heavy_token_num=32,
att_out_approx=None,
mid_out=None,
mid_o_logexpsum=None,
BLOCK_SEQ=256,
):
# TODO(Andy): Tune BLOCK_SEQ & BLOCK_D
kv_group_num = q.shape[1] // v_buffer.shape[1]
# batch_size = q.shape[0]
# Step 1: BGEMV approximate attention (page implementation)
if att_out_approx is None:
att_out_approx = torch.empty(
[q.shape[1], q.shape[0], max_len_in_batch],
dtype=REDUCE_TORCH_TYPE,
device=q.device,
)
if mid_out is None:
block_seq_num = (heavy_token_num + BLOCK_SEQ - 1) // BLOCK_SEQ
mid_out = torch.empty(
[q.shape[0], q.shape[1], block_seq_num, q.shape[-1]],
dtype=torch.float32,
device=q.device,
)
mid_o_logexpsum = torch.empty(
[q.shape[0], q.shape[1], block_seq_num],
dtype=torch.float32,
device=q.device,
)
sparse_flash_decode_stage1(
q_label,
k_label_buffer,
att_out_approx,
req_to_token,
b_seq_len,
max_len_in_batch,
sm_scale,
logit_cap,
)
# Step 2: TopK token selection
# NOTE(Andy): Apply sparse decoding when min > heavy_token_num and max > sparse decoding threshold
# TODO(Andy): Change a faster topk implementation
topk_token_indices = torch.topk(att_out_approx, heavy_token_num, dim=-1).indices
# topk_token_indices: [H, B, k], Req_to_tokens: [B, S]
# topk_token_indices = torch.arange(0, heavy_token_num, device=q.device).unsqueeze(0).unsqueeze(0).expand(q.shape[1], q.shape[0], -1)
sparse_flash_decode_stage2(
q,
k_buffer,
v_buffer,
req_to_token,
topk_token_indices,
heavy_token_num,
mid_out,
mid_o_logexpsum,
BLOCK_SEQ,
sm_scale,
)
sparse_flash_decode_stage3(heavy_token_num, mid_out, mid_o_logexpsum, o, BLOCK_SEQ)
......@@ -231,3 +231,61 @@ class MLATokenToKVPool(BaseTokenToKVPool):
self.kv_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
else:
self.kv_buffer[layer_id][loc] = cache_k
class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
def __init__(
self,
size: int,
dtype: torch.dtype,
head_num: int,
head_dim: int,
layer_num: int,
device: str,
heavy_channel_num: int,
):
super().__init__(size, dtype, device)
# [size, head_num, head_dim] for each layer
self.k_buffer = [
torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
for _ in range(layer_num)
]
self.v_buffer = [
torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
for _ in range(layer_num)
]
# [size, head_num, heavy_channel_num] for each layer
self.label_buffer = [
torch.empty(
(size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
)
for _ in range(layer_num)
]
def get_key_buffer(self, layer_id: int):
return self.k_buffer[layer_id]
def get_value_buffer(self, layer_id: int):
return self.v_buffer[layer_id]
def get_label_buffer(self, layer_id: int):
return self.label_buffer[layer_id]
def get_kv_buffer(self, layer_id: int):
return self.k_buffer[layer_id], self.v_buffer[layer_id]
def set_kv_buffer(
self,
layer_id: int,
loc: torch.Tensor,
cache_k: torch.Tensor,
cache_v: torch.Tensor,
cache_label: torch.Tensor,
):
# NOTE(Andy): ignore the dtype check
self.k_buffer[layer_id][loc] = cache_k
self.v_buffer[layer_id][loc] = cache_v
self.label_buffer[layer_id][loc] = cache_label
......@@ -18,6 +18,7 @@ limitations under the License.
import gc
import importlib
import importlib.resources
import json
import logging
import pkgutil
from functools import lru_cache
......@@ -39,6 +40,7 @@ from vllm.model_executor.models import ModelRegistry
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.constrained import disable_cache
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
......@@ -46,6 +48,7 @@ from sglang.srt.layers.sampler import Sampler
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import (
DoubleSparseTokenToKVPool,
MHATokenToKVPool,
MLATokenToKVPool,
ReqToTokenPool,
......@@ -99,6 +102,20 @@ class ModelRunner:
logger.info("MLA optimization is turned on. Use triton backend.")
self.server_args.attention_backend = "triton"
if self.server_args.enable_double_sparsity:
logger.info(
"Double sparsity optimization is turned on. Use triton backend without CUDA graph."
)
self.server_args.attention_backend = "triton"
self.server_args.disable_cuda_graph = True
if self.server_args.ds_heavy_channel_type is None:
raise ValueError(
"Please specify the heavy channel type for double sparsity optimization."
)
self.init_double_sparsity_channel_config(
self.server_args.ds_heavy_channel_type
)
if self.is_multimodal_model:
logger.info(
"Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
......@@ -439,6 +456,16 @@ class ModelRunner:
layer_num=self.model_config.num_hidden_layers,
device=self.device,
)
elif self.server_args.enable_double_sparsity:
self.token_to_kv_pool = DoubleSparseTokenToKVPool(
self.max_total_num_tokens,
dtype=self.kv_cache_dtype,
head_num=self.model_config.get_num_kv_heads(self.tp_size),
head_dim=self.model_config.head_dim,
layer_num=self.model_config.num_hidden_layers,
device=self.device,
heavy_channel_num=self.server_args.ds_heavy_channel_num,
)
else:
self.token_to_kv_pool = MHATokenToKVPool(
self.max_total_num_tokens,
......@@ -475,12 +502,33 @@ class ModelRunner:
"Cross attention is not supported in the triton attention backend. "
"Please use `--attention-backend flashinfer`."
)
self.attn_backend = TritonAttnBackend(self)
if self.server_args.enable_double_sparsity:
self.attn_backend = DoubleSparseAttnBackend(self)
else:
self.attn_backend = TritonAttnBackend(self)
else:
raise ValueError(
f"Invalid attention backend: {self.server_args.attention_backend}"
)
def init_double_sparsity_channel_config(self, selected_channel):
selected_channel = "." + selected_channel + "_proj"
self.sorted_channels = []
# load channel config
with open(self.server_args.ds_channel_config_path, "r") as f:
channel_config = json.load(f)
for i in range(self.model_config.num_hidden_layers):
key = "model.layers." + str(i) + ".self_attn" + selected_channel
self.sorted_channels.append(
torch.tensor(channel_config[key])[
:, : self.server_args.ds_heavy_channel_num
]
.contiguous()
.cuda()
)
def init_cuda_graphs(self):
"""Capture cuda graphs."""
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
......
......@@ -86,6 +86,14 @@ class ServerArgs:
# Model override args in JSON
json_model_override_args: str = "{}"
# Double Sparsity
enable_double_sparsity: bool = False
ds_channel_config_path: str = None
ds_heavy_channel_num: int = 32
ds_heavy_token_num: int = 256
ds_heavy_channel_type: str = "qk"
ds_sparse_decode_threshold: int = 4096
# LoRA
lora_paths: Optional[List[str]] = None
max_loras_per_batch: int = 8
......@@ -443,6 +451,43 @@ class ServerArgs:
default=ServerArgs.json_model_override_args,
)
# Double Sparsity
parser.add_argument(
"--enable-double-sparsity",
action="store_true",
help="Enable double sparsity attention",
)
parser.add_argument(
"--ds-channel-config-path",
type=str,
default=ServerArgs.ds_channel_config_path,
help="The path of the double sparsity channel config",
)
parser.add_argument(
"--ds-heavy-channel-num",
type=int,
default=ServerArgs.ds_heavy_channel_num,
help="The number of heavy channels in double sparsity attention",
)
parser.add_argument(
"--ds-heavy-token-num",
type=int,
default=ServerArgs.ds_heavy_token_num,
help="The number of heavy tokens in double sparsity attention",
)
parser.add_argument(
"--ds-heavy-channel-type",
type=str,
default=ServerArgs.ds_heavy_channel_type,
help="The type of heavy channels in double sparsity attention",
)
parser.add_argument(
"--ds-sparse-decode-threshold",
type=int,
default=ServerArgs.ds_sparse_decode_threshold,
help="The type of heavy channels in double sparsity attention",
)
# LoRA
parser.add_argument(
"--lora-paths",
......
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -11,6 +11,7 @@ suites = {
"models/test_reward_models.py",
"sampling/penaltylib",
"test_chunked_prefill.py",
"test_double_sparsity.py",
"test_embedding_openai_server.py",
"test_eval_accuracy_mini.py",
"test_json_constrained.py",
......
import os
import unittest
from types import SimpleNamespace
from sglang.srt.utils import kill_child_process
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
)
class TestDoubleSparsity(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
dirpath = os.path.dirname(__file__)
config_file = os.path.join(dirpath, "Llama-3.1-8B-Instruct.json")
# NOTE: Generate the config file by running https://github.com/andy-yang-1/DoubleSparse/blob/main/evaluation/group_channel_config.py
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--enable-double-sparsity",
"--ds-channel-config-path",
config_file,
"--ds-heavy-channel-num",
"32",
"--ds-heavy-channel-type",
"k",
"--ds-heavy-token-num",
"512",
"--ds-sparse-decode-threshold",
"0",
"--max-total-tokens",
"200000",
],
)
@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
def test_mmlu(self):
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="mmlu",
num_examples=64,
num_threads=32,
)
metrics = run_eval(args)
assert metrics["score"] >= 0.65
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment