Unverified commit fc82f5a7, authored by Lianmin Zheng, committed by GitHub

[Fix] Fix cuda graph padding for triton attention backend (#1782)

parent 0089c4bc
@@ -41,10 +41,6 @@ class AttentionBackend(ABC):
         """Init the metadata for a forward pass for replying a cuda graph."""
         raise NotImplementedError()
 
-    def get_cuda_graph_seq_len_fill_value(self):
-        """Get the fill value for padded seq lens. Typically, it is 0 or 1."""
-        raise NotImplementedError()
-
     def forward(
         self,
         q: torch.Tensor,

@@ -161,9 +161,6 @@ class DoubleSparseAttnBackend(AttentionBackend):
         self.cuda_graph_start_loc.zero_()
         self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
 
-    def get_cuda_graph_seq_len_fill_value(self):
-        return 1
-
     def forward_extend(
         self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
     ):

@@ -210,9 +210,6 @@ class FlashInferAttnBackend(AttentionBackend):
             encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
         )
 
-    def get_cuda_graph_seq_len_fill_value(self):
-        return 0
-
     def forward_extend(
         self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
     ):

@@ -108,9 +108,6 @@ class TritonAttnBackend(AttentionBackend):
         self.cuda_graph_start_loc.zero_()
         self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
 
-    def get_cuda_graph_seq_len_fill_value(self):
-        return 1
-
     def forward_extend(
         self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
     ):

@@ -38,7 +38,7 @@ class ReqToTokenPool:
         self.size = size
         self.max_context_len = max_context_len
         self.device = device
-        self.req_to_token = torch.empty(
+        self.req_to_token = torch.zeros(
            (size, max_context_len), dtype=torch.int32, device=device
         )
         self.free_slots = list(range(size))

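Why the switch from torch.empty to torch.zeros matters: once a cuda graph replays with a padded batch, the kernels also read the req_to_token rows of the padded (unused) slots, and a zero-filled map makes every such lookup hit KV slot 0 instead of whatever garbage torch.empty left behind. The following is a minimal illustrative sketch, not the sglang implementation; the toy kv_cache and padded_row names are assumptions for demonstration only.

import torch

# Toy dimensions, purely for illustration.
size, max_context_len = 4, 8

# Zero-initialized request-to-token map: rows belonging to padded cuda-graph
# slots contain index 0, a valid KV slot, rather than uninitialized integers.
req_to_token = torch.zeros((size, max_context_len), dtype=torch.int32)

kv_cache = torch.randn(16, 2)            # toy KV pool with 16 valid slots
padded_row = req_to_token[3]             # a row no real request ever wrote
tokens = kv_cache[padded_row.long()]     # safe gather: every index is 0
print(tokens.shape)                      # torch.Size([8, 2])
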
@@ -133,11 +133,8 @@ class CudaGraphRunner:
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
-        self.seq_len_fill_value = (
-            self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
-        )
-
-        # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
+        self.seq_len_fill_value = 1
+
         self.encoder_len_fill_value = 0
 
         if self.use_torch_compile:

@@ -290,7 +287,7 @@ class CudaGraphRunner:
         index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
-            self.seq_lens.fill_(1)
+            self.seq_lens.fill_(self.seq_len_fill_value)
             self.out_cache_loc.zero_()
 
         # Common inputs

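Putting the CudaGraphRunner hunks together: at replay time the runner rounds the real batch size up to the nearest captured size, resets seq_lens to the fill value (now hardcoded to 1) and zeroes out_cache_loc before the real inputs are copied back in under the "# Common inputs" block, so a padded slot contributes one dummy token that points at cache location 0. Below is a hedged, self-contained sketch of that padding step; pad_for_replay is a hypothetical helper written for this note, not sglang API, and the buffers are toy stand-ins.

import bisect

import torch

def pad_for_replay(capture_bs, raw_bs, seq_lens, out_cache_loc, seq_len_fill_value=1):
    # Round the real batch size up to the nearest captured batch size.
    index = bisect.bisect_left(capture_bs, raw_bs)
    bs = capture_bs[index]
    if bs != raw_bs:
        # Reset the whole seq_lens buffer to the fill value; the real sequence
        # lengths are copied back afterwards, leaving the padded tail at 1.
        seq_lens.fill_(seq_len_fill_value)
        # Padded slots write their single dummy token to cache location 0.
        out_cache_loc.zero_()
    return bs

capture_bs = [1, 2, 4, 8]
seq_lens = torch.empty(4, dtype=torch.int32)
out_cache_loc = torch.empty(4, dtype=torch.int64)
bs = pad_for_replay(capture_bs, raw_bs=3, seq_lens=seq_lens, out_cache_loc=out_cache_loc)
print(bs, seq_lens.tolist())  # 4 [1, 1, 1, 1]
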