Unverified commit 09603c6d, authored by Lianmin Zheng and committed by GitHub

Maintain seq_lens_sum to make more FlashInfer operations non-blocking (#1741)

parent cf470fea
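The change addresses a common PyTorch pitfall: reading a scalar out of a CUDA tensor (via .item(), .sum().item(), or indexing such as kv_indptr[-1]) forces the host to wait for the GPU stream. By carrying the total sequence length as a plain Python int (seq_lens_sum) next to the seq_lens tensor, the scheduler can size allocations and build FlashInfer metadata without that synchronization. A minimal sketch of the pattern, using an illustrative Batch class rather than the real sglang ScheduleBatch:

import torch

class Batch:
    # Illustrative stand-in for a scheduler batch, not the sglang ScheduleBatch.
    def __init__(self, seq_lens_cpu):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.seq_lens = torch.tensor(seq_lens_cpu, dtype=torch.int32, device=device)
        # Maintained on the host so later allocations never need a device sync.
        self.seq_lens_sum = sum(seq_lens_cpu)

    def decode_step(self):
        self.seq_lens.add_(1)                    # stays asynchronous on the GPU
        self.seq_lens_sum += len(self.seq_lens)  # cheap host-side bookkeeping

def alloc_kv_indices(batch):
    # Blocking alternative: batch.seq_lens.sum().item() would stall until the GPU catches up.
    return torch.empty(batch.seq_lens_sum, dtype=torch.int32, device=batch.seq_lens.device)

batch = Batch([3, 5, 2])
batch.decode_step()
print(alloc_kv_indices(batch).shape)  # torch.Size([13])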
@@ -621,7 +621,6 @@ Please cite our paper, [SGLang: Efficient Execution of Structured Language Model
We also learned from the design and reused code from the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), and [LMQL](https://github.com/eth-sri/lmql).
<p align="center">
<a href="#sglangtop" target="_blank">
<bold>Back To Top </bold>
......
@@ -25,7 +25,11 @@ class AttentionBackend(ABC):
raise NotImplementedError()
def init_forward_metadata_replay_cuda_graph(
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
):
"""Init the metadata for a forward pass for replying a cuda graph."""
raise NotImplementedError()
......
@@ -144,7 +144,11 @@ class DoubleSparseAttnBackend(AttentionBackend):
)
def init_forward_metadata_replay_cuda_graph(
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
):
self.cuda_graph_start_loc.zero_()
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
......
@@ -127,6 +127,7 @@ class FlashInferAttnBackend(AttentionBackend):
self.indices_updater_decode.update(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
)
self.forward_metadata = (self.decode_wrappers,)
else:
@@ -134,10 +135,7 @@ class FlashInferAttnBackend(AttentionBackend):
# Some heuristics to check whether to use ragged forward
use_ragged = False
if (
torch.sum(forward_batch.seq_lens).item() >= 4096
and self.num_wrappers == 1
):
if forward_batch.extend_num_tokens >= 4096 and self.num_wrappers == 1:
use_ragged = True
extend_no_prefix = not torch.any(forward_batch.extend_prefix_lens).item()
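The use_ragged heuristic above illustrates the same idea in miniature: torch.sum(forward_batch.seq_lens).item() copies a scalar back from the device and blocks, while extend_num_tokens is an integer the scheduler already holds on the host, so the comparison costs nothing. A toy comparison of the two styles, with made-up values:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
seq_lens = torch.tensor([1024, 2048, 1536], dtype=torch.int32, device=device)
extend_num_tokens = 4608  # hypothetical host-side count of tokens being extended

# Old-style check: the .item() call is a device-to-host copy and a sync point.
use_ragged_blocking = torch.sum(seq_lens).item() >= 4096

# New-style check: pure Python integer comparison, no GPU round trip.
use_ragged_non_blocking = extend_num_tokens >= 4096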
@@ -181,15 +179,25 @@
)
)
self.indices_updater_decode.update(req_pool_indices, seq_lens, decode_wrappers)
seq_lens_sum = seq_lens.sum().item()
self.indices_updater_decode.update(
req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers
)
self.cuda_graph_metadata[bs] = decode_wrappers
self.forward_metadata = (decode_wrappers,)
def init_forward_metadata_replay_cuda_graph(
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
):
self.indices_updater_decode.update(
req_pool_indices[:bs], seq_lens[:bs], self.cuda_graph_metadata[bs]
req_pool_indices[:bs],
seq_lens[:bs],
seq_lens_sum,
self.cuda_graph_metadata[bs],
)
def get_cuda_graph_seq_len_fill_value(self):
@@ -305,13 +313,30 @@ class FlashInferIndicesUpdaterDecode:
assert attn_backend.num_wrappers == 1
self.update = self.update_single_wrapper
def update_single_wrapper(self, req_pool_indices, seq_lens, decode_wrappers=None):
def update_single_wrapper(
self,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
decode_wrappers=None,
):
decode_wrappers = decode_wrappers or self.decode_wrappers
self.call_begin_forward(
decode_wrappers[0], req_pool_indices, seq_lens, self.kv_indptr[0], None
decode_wrappers[0],
req_pool_indices,
seq_lens,
seq_lens_sum,
self.kv_indptr[0],
None,
)
def update_sliding_window(self, req_pool_indices, seq_lens, decode_wrappers=None):
def update_sliding_window(
self,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
decode_wrappers=None,
):
decode_wrappers = decode_wrappers or self.decode_wrappers
for wrapper_id in range(2):
@@ -331,6 +356,7 @@ class FlashInferIndicesUpdaterDecode:
decode_wrappers[wrapper_id],
req_pool_indices,
paged_kernel_lens,
seq_lens_sum,
self.kv_indptr[wrapper_id],
kv_start_idx,
)
@@ -339,13 +365,18 @@ class FlashInferIndicesUpdaterDecode:
raise NotImplementedError()
def call_begin_forward(
self, wrapper, req_pool_indices, paged_kernel_lens, kv_indptr, kv_start_idx
self,
wrapper,
req_pool_indices,
paged_kernel_lens,
seq_lens_sum,
kv_indptr,
kv_start_idx,
):
bs = len(req_pool_indices)
kv_indptr = kv_indptr[: bs + 1]
# TODO: optimize the blocking call on kv_indptr[-1]
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
kv_indices = torch.empty(seq_lens_sum, dtype=torch.int32, device="cuda")
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
......
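The call_begin_forward hunk is where the payoff lands: kv_indices used to be sized by kv_indptr[-1], a scalar that lives on the GPU, so the allocation could not be issued until the preceding cumsum had actually finished. Sizing it with the host-side seq_lens_sum removes that stall. A standalone sketch of the before/after (lengths are illustrative, and the Triton index-filling kernel is omitted):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
paged_kernel_lens = torch.tensor([3, 5, 2], dtype=torch.int32, device=device)
seq_lens_sum = 10  # maintained by the scheduler batch, already a Python int

bs = len(paged_kernel_lens)
kv_indptr = torch.zeros(bs + 1, dtype=torch.int32, device=device)
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)

# Before: using kv_indptr[-1] as a size reads a scalar off the device and blocks.
# kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device=device)

# After: the host already knows the total, so the allocation is queued immediately.
kv_indices = torch.empty(seq_lens_sum, dtype=torch.int32, device=device)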
@@ -91,7 +91,11 @@ class TritonAttnBackend(AttentionBackend):
)
def init_forward_metadata_replay_cuda_graph(
self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
):
self.cuda_graph_start_loc.zero_()
self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
......
@@ -416,7 +416,6 @@ class ScheduleBatch:
req_to_token_pool: ReqToTokenPool = None
token_to_kv_pool: BaseTokenToKVPool = None
tree_cache: BasePrefixCache = None
forward_mode: ForwardMode = None
sampling_info: SamplingBatchInfo = None
@@ -424,9 +423,13 @@ class ScheduleBatch:
input_ids: torch.Tensor = None
req_pool_indices: torch.Tensor = None
seq_lens: torch.Tensor = None
# The output locations of the KV cache
out_cache_loc: torch.Tensor = None
output_ids: torch.Tensor = None
# The sum of all sequence lengths
seq_lens_sum: int = None
# For processing logprobs
return_logprob: bool = False
top_logprobs_nums: Optional[List[int]] = None
@@ -435,7 +438,6 @@ class ScheduleBatch:
prefix_lens: List[int] = None
extend_lens: List[int] = None
extend_num_tokens: int = None
running_bs: int = None
decoding_reqs: List[Req] = None
# Stream
@@ -549,10 +551,12 @@ class ScheduleBatch:
self.device, non_blocking=True
)
self.extend_num_tokens = extend_num_tokens
self.out_cache_loc = out_cache_loc
self.seq_lens_sum = sum(seq_lens)
if self.return_logprob:
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
self.extend_num_tokens = extend_num_tokens
self.prefix_lens = [len(r.prefix_indices) for r in reqs]
self.extend_lens = [r.extend_input_len for r in reqs]
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
@@ -571,12 +575,11 @@ class ScheduleBatch:
input_ids = torch.cat([self.input_ids, running_batch.input_ids])
out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
extend_num_tokens = self.extend_num_tokens + running_bs
self.merge_batch(running_batch)
self.input_ids = input_ids
self.out_cache_loc = out_cache_loc
self.extend_num_tokens = extend_num_tokens
self.extend_num_tokens += running_bs
# NOTE: prefix_indices is what has been cached, but we don't cache each decode step
self.prefix_lens.extend(
@@ -775,6 +778,7 @@ class ScheduleBatch:
(self.req_pool_indices, self.seq_lens), self.out_cache_loc
)
self.seq_lens.add_(1)
self.seq_lens_sum += bs
def filter_batch(
self,
@@ -805,6 +809,7 @@ class ScheduleBatch:
self.req_pool_indices = self.req_pool_indices[new_indices]
self.seq_lens = self.seq_lens[new_indices]
self.out_cache_loc = None
self.seq_lens_sum = self.seq_lens.sum().item()
self.output_ids = self.output_ids[new_indices]
self.return_logprob = any(req.return_logprob for req in self.reqs)
if self.return_logprob:
@@ -828,6 +833,7 @@ class ScheduleBatch:
)
self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
self.out_cache_loc = None
self.seq_lens_sum += other.seq_lens_sum
if self.output_ids is not None:
self.output_ids = torch.concat([self.output_ids, other.output_ids])
if self.return_logprob and other.return_logprob:
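Keeping the cached total correct is simple bookkeeping, as the ScheduleBatch hunks above show: a decode step adds exactly one token per running request, filtering out finished requests recomputes the sum from the surviving rows (one sync, but only when requests finish rather than on every step), and merging two batches adds their cached totals. Sketched as free functions with illustrative names:

import torch

def after_decode_step(seq_lens_sum: int, batch_size: int) -> int:
    # Every running request grew by exactly one token this step.
    return seq_lens_sum + batch_size

def after_filter(seq_lens: torch.Tensor, keep: torch.Tensor):
    kept = seq_lens[keep]
    # One device sync here is acceptable: it only happens when requests finish.
    return kept, int(kept.sum().item())

def after_merge(seq_lens_sum: int, other_seq_lens_sum: int) -> int:
    # Both totals are already plain Python ints, no tensors involved.
    return seq_lens_sum + other_seq_lens_sum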
@@ -873,9 +879,11 @@ class ScheduleBatch:
req_pool_indices=self.req_pool_indices,
seq_lens=self.seq_lens,
out_cache_loc=self.out_cache_loc,
seq_lens_sum=self.seq_lens_sum,
req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
return_logprob=self.return_logprob,
top_logprobs_nums=self.top_logprobs_nums,
extend_num_tokens=self.extend_num_tokens,
extend_seq_lens=extend_seq_lens,
extend_prefix_lens=extend_prefix_lens,
extend_logprob_start_lens=extend_logprob_start_lens,
@@ -917,6 +925,9 @@ class ModelWorkerBatch:
# The indices of output tokens in the token_to_kv_pool
out_cache_loc: torch.Tensor
# The sum of all sequence lengths
seq_lens_sum: int
# The memory pool operation records
req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]]
@@ -925,6 +936,7 @@ class ModelWorkerBatch:
top_logprobs_nums: Optional[List[int]]
# For extend
extend_num_tokens: Optional[int]
extend_seq_lens: Optional[List[int]]
extend_prefix_lens: Optional[List[int]]
extend_logprob_start_lens: Optional[List[int]]
......
@@ -188,6 +188,7 @@ class CudaGraphRunner:
req_pool_indices = self.req_pool_indices[:bs]
seq_lens = self.seq_lens[:bs]
out_cache_loc = self.out_cache_loc[:bs]
seq_lens_sum = seq_lens.sum().item()
# Attention backend
self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
@@ -206,6 +207,7 @@ class CudaGraphRunner:
token_to_kv_pool=self.model_runner.token_to_kv_pool,
attn_backend=self.model_runner.attn_backend,
out_cache_loc=out_cache_loc,
seq_lens_sum=seq_lens_sum,
return_logprob=False,
top_logprobs_nums=[0] * bs,
positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
@@ -252,7 +254,10 @@ class CudaGraphRunner:
# Attention backend
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
bs, self.req_pool_indices, self.seq_lens
bs,
self.req_pool_indices,
self.seq_lens,
forward_batch.seq_lens_sum,
)
# Replay
......
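In the CUDA-graph path the two call sites are treated differently: graph capture happens once per batch size, so computing seq_lens_sum with .sum().item() there is harmless, while graph replay sits in the decode hot loop and simply forwards the integer the batch already carries. A small sketch of that call shape, with a hypothetical stub standing in for the real attention backend:

import torch

class StubAttnBackend:
    # Hypothetical stand-in for an attention backend; only the call shape matters.
    def init_forward_metadata_replay_cuda_graph(self, bs, req_pool_indices, seq_lens, seq_lens_sum):
        assert isinstance(seq_lens_sum, int)  # plain int: no device read, no sync

device = "cuda" if torch.cuda.is_available() else "cpu"
seq_lens = torch.tensor([7, 9, 4], dtype=torch.int32, device=device)

# Capture path (runs once per batch size): an explicit sync is acceptable here.
capture_sum = int(seq_lens.sum().item())

# Replay path (decode hot loop): the batch's maintained total is forwarded as-is.
maintained_sum = 20  # e.g. the value the scheduler batch carries in the real code
StubAttnBackend().init_forward_metadata_replay_cuda_graph(
    bs=3,
    req_pool_indices=torch.arange(3, device=device),
    seq_lens=seq_lens,
    seq_lens_sum=maintained_sum,
)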
@@ -87,6 +87,9 @@ class ForwardBatch:
# The indices of output tokens in the token_to_kv_pool
out_cache_loc: torch.Tensor
# The sum of all sequence lengths
seq_lens_sum: int
# For logprob
return_logprob: bool = False
top_logprobs_nums: Optional[List[int]] = None
@@ -95,6 +98,7 @@ class ForwardBatch:
positions: torch.Tensor = None
# For extend
extend_num_tokens: Optional[int] = None
extend_seq_lens: Optional[torch.Tensor] = None
extend_prefix_lens: Optional[torch.Tensor] = None
extend_start_loc: Optional[torch.Tensor] = None
@@ -175,21 +179,6 @@ class ForwardBatch:
)
self.mrope_positions = self.mrope_positions.to(torch.int64)
def compute_positions(self, model_runner: ModelRunner, batch: ModelWorkerBatch):
device = model_runner.device
if self.forward_mode.is_decode():
self.positions = (self.seq_lens - 1).to(torch.int64)
else:
self.positions = torch.concat(
[
torch.arange(prefix_len, prefix_len + extend_len, device=device)
for prefix_len, extend_len in zip(
batch.extend_prefix_lens, batch.extend_seq_lens
)
],
axis=0,
)
@classmethod
def init_new(
cls,
@@ -205,6 +194,7 @@ class ForwardBatch:
req_pool_indices=batch.req_pool_indices,
seq_lens=batch.seq_lens,
out_cache_loc=batch.out_cache_loc,
seq_lens_sum=batch.seq_lens_sum,
return_logprob=batch.return_logprob,
top_logprobs_nums=batch.top_logprobs_nums,
lora_paths=batch.lora_paths,
@@ -213,7 +203,17 @@ class ForwardBatch:
# Init position information
if not ret.forward_mode.is_decode():
ret.positions = torch.concat(
[
torch.arange(prefix_len, prefix_len + extend_len, device=device)
for prefix_len, extend_len in zip(
batch.extend_prefix_lens, batch.extend_seq_lens
)
],
axis=0,
)
ret.image_inputs = batch.image_inputs
ret.extend_num_tokens = batch.extend_num_tokens
ret.extend_seq_lens = torch.tensor(
batch.extend_seq_lens, dtype=torch.int32
).to(device, non_blocking=True)
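The positions block moved into init_new builds, for each extend request, the absolute positions of its new tokens: starting at the length of its cached prefix and running for extend_len tokens. The same arithmetic in isolation, with made-up lengths:

import torch

extend_prefix_lens = [0, 4, 2]  # tokens already cached per request (illustrative)
extend_seq_lens = [3, 2, 5]     # new tokens being extended per request (illustrative)

positions = torch.concat(
    [
        torch.arange(prefix_len, prefix_len + extend_len)
        for prefix_len, extend_len in zip(extend_prefix_lens, extend_seq_lens)
    ],
    axis=0,
)
print(positions)  # tensor([0, 1, 2, 4, 5, 2, 3, 4, 5, 6])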
@@ -225,12 +225,8 @@ class ForwardBatch:
ret.extend_seq_lens_cpu = batch.extend_seq_lens
ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
# Init position information
is_mrope = model_runner.model_is_mrope
if is_mrope:
if model_runner.model_is_mrope:
ret.compute_mrope_positions(model_runner, batch)
else:
ret.compute_positions(model_runner, batch)
# Init attention information
ret.req_to_token_pool = model_runner.req_to_token_pool
......