Unverified Commit 384d85ba authored by Lianmin Zheng, committed by GitHub

Re-introduce `get_cuda_graph_seq_len_fill_value` (#1783)

parent 60597219
@@ -41,6 +41,10 @@ class AttentionBackend(ABC):
         """Init the metadata for a forward pass for replaying a cuda graph."""
         raise NotImplementedError()
 
+    def get_cuda_graph_seq_len_fill_value(self):
+        """Get the fill value for padded seq lens. Typically, it is 0 or 1."""
+        raise NotImplementedError()
+
     def forward(
         self,
         q: torch.Tensor,
...
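For readers skimming the diff: the hook's contract is simply that each attention backend reports the value used to pad `seq_lens` when a graph captured at a larger batch size is replayed with a smaller batch. A minimal, self-contained sketch of that contract; the `AttentionBackend` stub is simplified from the class in this hunk, and `DummyDecodeBackend` is hypothetical, not from the repo:

```python
from abc import ABC


class AttentionBackend(ABC):
    """Simplified stub of the interface this hunk extends."""

    def get_cuda_graph_seq_len_fill_value(self):
        """Get the fill value for padded seq lens. Typically, it is 0 or 1."""
        raise NotImplementedError()


class DummyDecodeBackend(AttentionBackend):
    """Hypothetical backend whose decode kernel walks each request's KV
    range, so padded slots must look like valid one-token sequences."""

    def get_cuda_graph_seq_len_fill_value(self):
        return 1
```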
@@ -161,6 +161,9 @@ class DoubleSparseAttnBackend(AttentionBackend):
         self.cuda_graph_start_loc.zero_()
         self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
 
+    def get_cuda_graph_seq_len_fill_value(self):
+        return 1
+
     def forward_extend(
         self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
     ):
...
@@ -210,6 +210,9 @@ class FlashInferAttnBackend(AttentionBackend):
             encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
         )
 
+    def get_cuda_graph_seq_len_fill_value(self):
+        return 0
+
     def forward_extend(
         self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
     ):
...
@@ -108,6 +108,9 @@ class TritonAttnBackend(AttentionBackend):
         self.cuda_graph_start_loc.zero_()
         self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
 
+    def get_cuda_graph_seq_len_fill_value(self):
+        return 1
+
     def forward_extend(
         self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
     ):
...
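A note on the split above (my reading; the commit itself doesn't explain it): FlashInfer builds its replay metadata from the padded `seq_lens` and tolerates zero-length entries, so it fills with 0, whereas the Triton and double-sparsity decode kernels index each request's KV range directly and use 1 so that padded entries still describe a valid one-token sequence.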
@@ -134,7 +134,11 @@ class CudaGraphRunner:
         self.max_bs = max(self.capture_bs)
 
         self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
-        self.seq_len_fill_value = 1
+        self.seq_len_fill_value = (
+            self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
+        )
+
+        # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
         self.encoder_len_fill_value = 0
 
         if self.use_torch_compile:
@@ -287,7 +291,7 @@ class CudaGraphRunner:
         index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
-            self.seq_lens.fill_(1)
+            self.seq_lens.fill_(self.seq_len_fill_value)
             self.out_cache_loc.zero_()
 
         # Common inputs
...
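To make the `CudaGraphRunner` change concrete, here is a hedged, self-contained sketch of the replay-time padding path the second hunk touches. The `capture_bs` values, tensor dtype, and function name are illustrative, not taken from the repo:

```python
import bisect

import torch

capture_bs = [1, 2, 4, 8]  # batch sizes with captured graphs (illustrative)
seq_len_fill_value = 1     # value returned by the backend hook
seq_lens = torch.full((max(capture_bs),), seq_len_fill_value, dtype=torch.int64)


def pad_for_replay(raw_seq_lens: torch.Tensor) -> int:
    """Pad a raw batch up to the nearest captured batch size."""
    raw_bs = raw_seq_lens.numel()
    # Smallest captured batch size that fits the raw batch.
    bs = capture_bs[bisect.bisect_left(capture_bs, raw_bs)]
    if bs != raw_bs:
        # Reset padded slots to the backend-specific value; this is the
        # line the commit changes from a hardcoded fill_(1).
        seq_lens.fill_(seq_len_fill_value)
    seq_lens[:raw_bs].copy_(raw_seq_lens)
    return bs


bs = pad_for_replay(torch.tensor([5, 7, 3]))
print(bs, seq_lens)  # 4 tensor([5, 7, 3, 1])
```

Replay then copies the real inputs into the first `raw_bs` slots and runs the graph captured for `bs`, so only the padded tail ever sees the fill value.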