Unverified Commit a07364cc authored by Ke Bao's avatar Ke Bao Committed by GitHub
Browse files

Update Triton decode backend interface (#3292)

parent 2c1a695f
...@@ -5,6 +5,9 @@ from typing import TYPE_CHECKING, Optional ...@@ -5,6 +5,9 @@ from typing import TYPE_CHECKING, Optional
import torch import torch
from sglang.srt.layers.attention import AttentionBackend from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.layers.attention.flashinfer_backend import (
create_flashinfer_kv_indices_triton,
)
from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
...@@ -29,6 +32,12 @@ class TritonAttnBackend(AttentionBackend): ...@@ -29,6 +32,12 @@ class TritonAttnBackend(AttentionBackend):
self.decode_attention_fwd = decode_attention_fwd self.decode_attention_fwd = decode_attention_fwd
self.extend_attention_fwd = extend_attention_fwd self.extend_attention_fwd = extend_attention_fwd
max_bs = model_runner.req_to_token_pool.size
self.kv_indptr = torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.num_head = ( self.num_head = (
model_runner.model_config.num_attention_heads // get_attention_tp_size() model_runner.model_config.num_attention_heads // get_attention_tp_size()
) )
...@@ -58,11 +67,32 @@ class TritonAttnBackend(AttentionBackend): ...@@ -58,11 +67,32 @@ class TritonAttnBackend(AttentionBackend):
) )
max_extend_len = None max_extend_len = None
kv_indptr = self.kv_indptr
bs = len(forward_batch.req_pool_indices)
kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.empty(
forward_batch.seq_lens_sum, dtype=torch.int32, device="cuda"
)
create_flashinfer_kv_indices_triton[(bs,)](
forward_batch.req_to_token_pool.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
kv_indptr,
None,
kv_indices,
forward_batch.req_to_token_pool.req_to_token.stride(0),
)
else: else:
attn_logits = None attn_logits = None
max_extend_len = torch.max(forward_batch.extend_seq_lens).item() max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
self.forward_metadata = attn_logits, max_extend_len kv_indptr = None
kv_indices = None
self.forward_metadata = attn_logits, max_extend_len, kv_indptr, kv_indices
def init_cuda_graph_state(self, max_bs: int): def init_cuda_graph_state(self, max_bs: int):
self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
...@@ -73,7 +103,12 @@ class TritonAttnBackend(AttentionBackend): ...@@ -73,7 +103,12 @@ class TritonAttnBackend(AttentionBackend):
self.cuda_graph_attn_logits = torch.empty( self.cuda_graph_attn_logits = torch.empty(
(max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1), (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
dtype=torch.float32, dtype=torch.float32,
device="cuda", device=self.device,
)
self.cuda_graph_kv_indices = torch.zeros(
(max_bs * self.cuda_graph_max_seq_len),
dtype=torch.int32,
device=self.device,
) )
def init_forward_metadata_capture_cuda_graph( def init_forward_metadata_capture_cuda_graph(
...@@ -90,9 +125,25 @@ class TritonAttnBackend(AttentionBackend): ...@@ -90,9 +125,25 @@ class TritonAttnBackend(AttentionBackend):
assert forward_mode.is_decode(), "Not supported" assert forward_mode.is_decode(), "Not supported"
assert spec_info is None, "Not supported" assert spec_info is None, "Not supported"
kv_indptr = self.kv_indptr
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = self.cuda_graph_kv_indices
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
seq_lens,
kv_indptr,
None,
kv_indices,
self.req_to_token.stride(0),
)
self.forward_metadata = ( self.forward_metadata = (
self.cuda_graph_attn_logits, self.cuda_graph_attn_logits,
None, None,
kv_indptr,
kv_indices,
) )
def init_forward_metadata_replay_cuda_graph( def init_forward_metadata_replay_cuda_graph(
...@@ -109,6 +160,20 @@ class TritonAttnBackend(AttentionBackend): ...@@ -109,6 +160,20 @@ class TritonAttnBackend(AttentionBackend):
self.cuda_graph_start_loc.zero_() self.cuda_graph_start_loc.zero_()
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0) self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
kv_indptr = self.kv_indptr
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = self.cuda_graph_kv_indices
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices[:bs],
seq_lens[:bs],
kv_indptr,
None,
kv_indices,
self.req_to_token.stride(0),
)
def get_cuda_graph_seq_len_fill_value(self): def get_cuda_graph_seq_len_fill_value(self):
return 1 return 1
...@@ -132,7 +197,7 @@ class TritonAttnBackend(AttentionBackend): ...@@ -132,7 +197,7 @@ class TritonAttnBackend(AttentionBackend):
layer, forward_batch.out_cache_loc, k, v layer, forward_batch.out_cache_loc, k, v
) )
_, max_extend_len = self.forward_metadata _, max_extend_len, _, _ = self.forward_metadata
self.extend_attention_fwd( self.extend_attention_fwd(
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
k.contiguous(), k.contiguous(),
...@@ -170,7 +235,7 @@ class TritonAttnBackend(AttentionBackend): ...@@ -170,7 +235,7 @@ class TritonAttnBackend(AttentionBackend):
else: else:
o = torch.empty_like(q) o = torch.empty_like(q)
attn_logits, _ = self.forward_metadata attn_logits, _, kv_indptr, kv_indices = self.forward_metadata
if save_kv_cache: if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer( forward_batch.token_to_kv_pool.set_kv_buffer(
...@@ -182,9 +247,8 @@ class TritonAttnBackend(AttentionBackend): ...@@ -182,9 +247,8 @@ class TritonAttnBackend(AttentionBackend):
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id), forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id), forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
o.view(-1, layer.tp_q_head_num, layer.v_head_dim), o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
forward_batch.req_to_token_pool.req_to_token, kv_indptr,
forward_batch.req_pool_indices, kv_indices,
forward_batch.seq_lens,
attn_logits, attn_logits,
self.num_kv_splits, self.num_kv_splits,
layer.scaling, layer.scaling,
......
...@@ -49,11 +49,9 @@ def _fwd_kernel_stage1( ...@@ -49,11 +49,9 @@ def _fwd_kernel_stage1(
K_Buffer, K_Buffer,
V_Buffer, V_Buffer,
sm_scale, sm_scale,
Req_to_tokens, kv_indptr,
B_req_idx, kv_indices,
B_Seqlen,
Att_Out, Att_Out,
stride_req_to_tokens_b,
stride_qbs, stride_qbs,
stride_qh, stride_qh,
stride_buf_kbs, stride_buf_kbs,
...@@ -82,8 +80,9 @@ def _fwd_kernel_stage1( ...@@ -82,8 +80,9 @@ def _fwd_kernel_stage1(
offs_dv = tl.arange(0, BLOCK_DV) offs_dv = tl.arange(0, BLOCK_DV)
mask_d = offs_d < Lk mask_d = offs_d < Lk
mask_dv = offs_dv < Lv mask_dv = offs_dv < Lv
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_req_idx = tl.load(B_req_idx + cur_batch) cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)
cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
q = tl.load(Q + off_q, mask=mask_d, other=0.0) q = tl.load(Q + off_q, mask=mask_d, other=0.0)
...@@ -100,7 +99,7 @@ def _fwd_kernel_stage1( ...@@ -100,7 +99,7 @@ def _fwd_kernel_stage1(
for start_n in range(split_kv_start, split_kv_end, BLOCK_N): for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
offs_n = start_n + tl.arange(0, BLOCK_N) offs_n = start_n + tl.arange(0, BLOCK_N)
kv_loc = tl.load( kv_loc = tl.load(
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n, kv_indices + cur_batch_kv_start_idx + offs_n,
mask=offs_n < split_kv_end, mask=offs_n < split_kv_end,
other=0, other=0,
) )
...@@ -173,9 +172,8 @@ def _decode_att_m_fwd( ...@@ -173,9 +172,8 @@ def _decode_att_m_fwd(
k_buffer, k_buffer,
v_buffer, v_buffer,
att_out, att_out,
Req_to_tokens, kv_indptr,
B_req_idx, kv_indices,
B_Seqlen,
num_kv_splits, num_kv_splits,
sm_scale, sm_scale,
logit_cap, logit_cap,
...@@ -188,7 +186,7 @@ def _decode_att_m_fwd( ...@@ -188,7 +186,7 @@ def _decode_att_m_fwd(
Lk = k_buffer.shape[-1] Lk = k_buffer.shape[-1]
Lv = v_buffer.shape[-1] Lv = v_buffer.shape[-1]
batch, head_num = B_req_idx.shape[0], q.shape[1] batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
grid = (batch, head_num, NUM_KV_SPLITS) grid = (batch, head_num, NUM_KV_SPLITS)
kv_group_num = q.shape[1] // k_buffer.shape[1] kv_group_num = q.shape[1] // k_buffer.shape[1]
...@@ -208,11 +206,9 @@ def _decode_att_m_fwd( ...@@ -208,11 +206,9 @@ def _decode_att_m_fwd(
k_buffer, k_buffer,
v_buffer, v_buffer,
sm_scale, sm_scale,
Req_to_tokens, kv_indptr,
B_req_idx, kv_indices,
B_Seqlen,
att_out, att_out,
Req_to_tokens.stride(0),
q.stride(0), q.stride(0),
q.stride(1), q.stride(1),
k_buffer.stride(0), k_buffer.stride(0),
...@@ -241,11 +237,9 @@ def _fwd_grouped_kernel_stage1( ...@@ -241,11 +237,9 @@ def _fwd_grouped_kernel_stage1(
K_Buffer, K_Buffer,
V_Buffer, V_Buffer,
sm_scale, sm_scale,
Req_to_tokens, kv_indptr,
B_req_idx, kv_indices,
B_Seqlen,
Att_Out, Att_Out,
stride_req_to_tokens_b,
stride_qbs, stride_qbs,
stride_qh, stride_qh,
stride_buf_kbs, stride_buf_kbs,
...@@ -284,8 +278,9 @@ def _fwd_grouped_kernel_stage1( ...@@ -284,8 +278,9 @@ def _fwd_grouped_kernel_stage1(
offs_dv = tl.arange(0, BLOCK_DV) offs_dv = tl.arange(0, BLOCK_DV)
mask_d = offs_d < Lk mask_d = offs_d < Lk
mask_dv = offs_dv < Lv mask_dv = offs_dv < Lv
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_req_idx = tl.load(B_req_idx + cur_batch) cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)
cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0) q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
...@@ -312,7 +307,7 @@ def _fwd_grouped_kernel_stage1( ...@@ -312,7 +307,7 @@ def _fwd_grouped_kernel_stage1(
for start_n in range(split_kv_start, split_kv_end, BLOCK_N): for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
offs_n = start_n + tl.arange(0, BLOCK_N) offs_n = start_n + tl.arange(0, BLOCK_N)
kv_loc = tl.load( kv_loc = tl.load(
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n, kv_indices + cur_batch_kv_start_idx + offs_n,
mask=offs_n < split_kv_end, mask=offs_n < split_kv_end,
other=0, other=0,
) )
...@@ -400,9 +395,8 @@ def _decode_grouped_att_m_fwd( ...@@ -400,9 +395,8 @@ def _decode_grouped_att_m_fwd(
k_buffer, k_buffer,
v_buffer, v_buffer,
att_out, att_out,
Req_to_tokens, kv_indptr,
B_req_idx, kv_indices,
B_Seqlen,
num_kv_splits, num_kv_splits,
sm_scale, sm_scale,
logit_cap, logit_cap,
...@@ -426,7 +420,7 @@ def _decode_grouped_att_m_fwd( ...@@ -426,7 +420,7 @@ def _decode_grouped_att_m_fwd(
BLOCK_DPE = 0 BLOCK_DPE = 0
BLOCK_DV = triton.next_power_of_2(Lv) BLOCK_DV = triton.next_power_of_2(Lv)
batch, head_num = B_req_idx.shape[0], q.shape[1] batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
kv_group_num = q.shape[1] // k_buffer.shape[1] kv_group_num = q.shape[1] // k_buffer.shape[1]
BLOCK_H = 16 BLOCK_H = 16
...@@ -450,11 +444,9 @@ def _decode_grouped_att_m_fwd( ...@@ -450,11 +444,9 @@ def _decode_grouped_att_m_fwd(
k_buffer, k_buffer,
v_buffer, v_buffer,
sm_scale, sm_scale,
Req_to_tokens, kv_indptr,
B_req_idx, kv_indices,
B_Seqlen,
att_out, att_out,
Req_to_tokens.stride(0),
q.stride(0), q.stride(0),
q.stride(1), q.stride(1),
k_buffer.stride(0), k_buffer.stride(0),
...@@ -485,7 +477,7 @@ def _decode_grouped_att_m_fwd( ...@@ -485,7 +477,7 @@ def _decode_grouped_att_m_fwd(
def _fwd_kernel_stage2( def _fwd_kernel_stage2(
Mid_O, Mid_O,
O, O,
B_Seqlen, kv_indptr,
stride_mid_ob, stride_mid_ob,
stride_mid_oh, stride_mid_oh,
stride_mid_os, stride_mid_os,
...@@ -498,7 +490,9 @@ def _fwd_kernel_stage2( ...@@ -498,7 +490,9 @@ def _fwd_kernel_stage2(
cur_batch = tl.program_id(0) cur_batch = tl.program_id(0)
cur_head = tl.program_id(1) cur_head = tl.program_id(1)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - tl.load(
kv_indptr + cur_batch
)
offs_d = tl.arange(0, BLOCK_DV) offs_d = tl.arange(0, BLOCK_DV)
mask_d = offs_d < Lv mask_d = offs_d < Lv
...@@ -542,7 +536,7 @@ def _decode_softmax_reducev_fwd( ...@@ -542,7 +536,7 @@ def _decode_softmax_reducev_fwd(
q, q,
o, o,
v_buffer, v_buffer,
b_seq_len, kv_indptr,
num_kv_splits, num_kv_splits,
): ):
batch, head_num = q.shape[0], q.shape[1] batch, head_num = q.shape[0], q.shape[1]
...@@ -561,7 +555,7 @@ def _decode_softmax_reducev_fwd( ...@@ -561,7 +555,7 @@ def _decode_softmax_reducev_fwd(
_fwd_kernel_stage2[grid]( _fwd_kernel_stage2[grid](
logits, logits,
o, o,
b_seq_len, kv_indptr,
logits.stride(0), logits.stride(0),
logits.stride(1), logits.stride(1),
logits.stride(2), logits.stride(2),
...@@ -581,9 +575,8 @@ def decode_attention_fwd_normal( ...@@ -581,9 +575,8 @@ def decode_attention_fwd_normal(
k_buffer, k_buffer,
v_buffer, v_buffer,
o, o,
req_to_token, kv_indptr,
b_req_idx, kv_indices,
b_seq_len,
attn_logits, attn_logits,
num_kv_splits, num_kv_splits,
sm_scale, sm_scale,
...@@ -594,14 +587,13 @@ def decode_attention_fwd_normal( ...@@ -594,14 +587,13 @@ def decode_attention_fwd_normal(
k_buffer, k_buffer,
v_buffer, v_buffer,
attn_logits, attn_logits,
req_to_token, kv_indptr,
b_req_idx, kv_indices,
b_seq_len,
num_kv_splits, num_kv_splits,
sm_scale, sm_scale,
logit_cap, logit_cap,
) )
_decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits) _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits)
def decode_attention_fwd_grouped( def decode_attention_fwd_grouped(
...@@ -609,9 +601,8 @@ def decode_attention_fwd_grouped( ...@@ -609,9 +601,8 @@ def decode_attention_fwd_grouped(
k_buffer, k_buffer,
v_buffer, v_buffer,
o, o,
req_to_token, kv_indptr,
b_req_idx, kv_indices,
b_seq_len,
attn_logits, attn_logits,
num_kv_splits, num_kv_splits,
sm_scale, sm_scale,
...@@ -622,14 +613,13 @@ def decode_attention_fwd_grouped( ...@@ -622,14 +613,13 @@ def decode_attention_fwd_grouped(
k_buffer, k_buffer,
v_buffer, v_buffer,
attn_logits, attn_logits,
req_to_token, kv_indptr,
b_req_idx, kv_indices,
b_seq_len,
num_kv_splits, num_kv_splits,
sm_scale, sm_scale,
logit_cap, logit_cap,
) )
_decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits) _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits)
def decode_attention_fwd( def decode_attention_fwd(
...@@ -637,9 +627,8 @@ def decode_attention_fwd( ...@@ -637,9 +627,8 @@ def decode_attention_fwd(
k_buffer, k_buffer,
v_buffer, v_buffer,
o, o,
req_to_token, kv_indptr,
b_req_idx, kv_indices,
b_seq_len,
attn_logits, attn_logits,
num_kv_splits, num_kv_splits,
sm_scale, sm_scale,
...@@ -655,9 +644,8 @@ def decode_attention_fwd( ...@@ -655,9 +644,8 @@ def decode_attention_fwd(
k_buffer, k_buffer,
v_buffer, v_buffer,
o, o,
req_to_token, kv_indptr,
b_req_idx, kv_indices,
b_seq_len,
attn_logits, attn_logits,
num_kv_splits, num_kv_splits,
sm_scale, sm_scale,
...@@ -670,9 +658,8 @@ def decode_attention_fwd( ...@@ -670,9 +658,8 @@ def decode_attention_fwd(
k_buffer, k_buffer,
v_buffer, v_buffer,
o, o,
req_to_token, kv_indptr,
b_req_idx, kv_indices,
b_seq_len,
attn_logits, attn_logits,
num_kv_splits, num_kv_splits,
sm_scale, sm_scale,
......
...@@ -194,10 +194,12 @@ class TestTritonAttention(unittest.TestCase): ...@@ -194,10 +194,12 @@ class TestTritonAttention(unittest.TestCase):
# o will have the same shape as q # o will have the same shape as q
o = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda") o = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda")
req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
b_req_idx = torch.arange(B, device="cuda")
b_seq_len = torch.full((B,), seq_len, device="cuda") b_seq_len = torch.full((B,), seq_len, device="cuda")
kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len[:B], dim=0)
kv_indices = torch.arange(total_tokens, device="cuda")
attn_logits = torch.empty( attn_logits = torch.empty(
(B, H_Q, num_kv_splits, D + 1), (B, H_Q, num_kv_splits, D + 1),
dtype=torch.float32, dtype=torch.float32,
...@@ -209,9 +211,8 @@ class TestTritonAttention(unittest.TestCase): ...@@ -209,9 +211,8 @@ class TestTritonAttention(unittest.TestCase):
k_buffer, k_buffer,
v_buffer, v_buffer,
o, o,
req_to_token, kv_indptr,
b_req_idx, kv_indices,
b_seq_len,
attn_logits, attn_logits,
num_kv_splits, num_kv_splits,
sm_scale, sm_scale,
...@@ -250,10 +251,12 @@ class TestTritonAttention(unittest.TestCase): ...@@ -250,10 +251,12 @@ class TestTritonAttention(unittest.TestCase):
o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda") o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda") o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
b_req_idx = torch.arange(B, device="cuda")
b_seq_len = torch.full((B,), seq_len, device="cuda") b_seq_len = torch.full((B,), seq_len, device="cuda")
kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len[:B], dim=0)
kv_indices = torch.arange(total_tokens, device="cuda")
attn_logits = torch.empty( attn_logits = torch.empty(
(B, H_Q, num_kv_splits, D_V + 1), (B, H_Q, num_kv_splits, D_V + 1),
dtype=torch.float32, dtype=torch.float32,
...@@ -265,9 +268,8 @@ class TestTritonAttention(unittest.TestCase): ...@@ -265,9 +268,8 @@ class TestTritonAttention(unittest.TestCase):
k_buffer, k_buffer,
v_buffer, v_buffer,
o, o,
req_to_token, kv_indptr,
b_req_idx, kv_indices,
b_seq_len,
attn_logits, attn_logits,
num_kv_splits, num_kv_splits,
sm_scale, sm_scale,
...@@ -284,9 +286,8 @@ class TestTritonAttention(unittest.TestCase): ...@@ -284,9 +286,8 @@ class TestTritonAttention(unittest.TestCase):
k_buffer, k_buffer,
v_buffer, v_buffer,
o_grouped, o_grouped,
req_to_token, kv_indptr,
b_req_idx, kv_indices,
b_seq_len,
attn_logits1, attn_logits1,
num_kv_splits, num_kv_splits,
sm_scale, sm_scale,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment