"vscode:/vscode.git/clone" did not exist on "98111fbe3ebd429258923ae00c3e1c7b1be8dcec"
Unverified commit 519e20cf authored by Lianmin Zheng, committed by GitHub

Code clean up: Remove deprecated prefill; move InputMetadata to infer_batch.py (#609)

parent d9a69029
@@ -8,6 +8,7 @@ from torch import nn
from sglang.global_config import global_config
from sglang.srt.layers.extend_attention import extend_attention_fwd
from sglang.srt.layers.token_attention import token_attention_fwd
from sglang.srt.managers.controller.infer_batch import global_server_args_dict
from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
@@ -29,8 +30,6 @@ class RadixAttention(nn.Module):
self.scaling = scaling
self.layer_id = layer_id
from sglang.srt.managers.controller.model_runner import global_server_args_dict
if not global_server_args_dict.get("disable_flashinfer", False):
self.prefill_forward = self.prefill_forward_flashinfer
self.extend_forward = self.prefill_forward_flashinfer
@@ -141,9 +140,7 @@ class RadixAttention(nn.Module):
k = k.view(-1, self.tp_k_head_num, self.head_dim)
v = v.view(-1, self.tp_v_head_num, self.head_dim)
if input_metadata.forward_mode == ForwardMode.PREFILL:
return self.prefill_forward(q, k, v, input_metadata)
elif input_metadata.forward_mode == ForwardMode.EXTEND:
if input_metadata.forward_mode == ForwardMode.EXTEND:
return self.extend_forward(q, k, v, input_metadata)
elif input_metadata.forward_mode == ForwardMode.DECODE:
return self.decode_forward(q, k, v, input_metadata)
@@ -15,10 +15,16 @@ from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
# Store some global server args
global_server_args_dict = {}
class ForwardMode(IntEnum):
# Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
PREFILL = auto()
# Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
EXTEND = auto()
# Decode one token.
DECODE = auto()
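For orientation (not part of the diff), a minimal sketch of how a forward pass dispatches on ForwardMode after this cleanup; the dispatch function and its return values are illustrative assumptions, not SGLang code.

```python
from enum import IntEnum, auto

class ForwardMode(IntEnum):
    PREFILL = auto()  # deprecated: an EXTEND with an empty cached prefix covers it
    EXTEND = auto()   # compute KV for the uncached suffix of each sequence
    DECODE = auto()   # compute exactly one new token per sequence

def dispatch(mode: ForwardMode) -> str:
    # After this commit only EXTEND and DECODE are handled by the runner.
    if mode == ForwardMode.EXTEND:
        return "extend"
    elif mode == ForwardMode.DECODE:
        return "decode"
    raise ValueError(f"Invalid forward mode: {mode}")

assert dispatch(ForwardMode.EXTEND) == "extend"
```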
@@ -66,6 +72,8 @@ class FINISH_ABORT(BaseFinishReason):
class Req:
"""Store all inforamtion of a request."""
def __init__(self, rid, origin_input_text, origin_input_ids):
self.rid = rid
self.origin_input_text = origin_input_text
@@ -74,7 +82,7 @@ class Req:
self.output_ids = [] # Each decode stage's output ids
self.input_ids = None # input_ids = origin_input_ids + output_ids
# For incremental decode
# For incremental decoding
self.decoded_text = ""
self.surr_offset = None # Surrounding offset to defeat the cleanup algorithm
self.read_offset = None
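As a side note, a hedged sketch of the incremental-decoding idea behind decoded_text, surr_offset, and read_offset; the helper below is illustrative only (it assumes a HuggingFace-style tokenizer.decode) and is not the SGLang implementation.

```python
def incremental_detokenize(tokenizer, all_token_ids, surr_offset, read_offset):
    """Return (new_text, new_read_offset); illustrative only."""
    # Decode with a small surrounding window so multi-token merges stay stable.
    surr_text = tokenizer.decode(all_token_ids[surr_offset:read_offset])
    full_text = tokenizer.decode(all_token_ids[surr_offset:])
    if full_text.endswith("\ufffd"):
        # The newest token is an incomplete UTF-8 piece; wait for more tokens.
        return "", read_offset
    return full_text[len(surr_text):], len(all_token_ids)
```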
@@ -93,9 +101,8 @@ class Req:
self.sampling_params = None
self.stream = False
self.tokenizer = None
# Check finish
self.tokenizer = None
self.finished_reason = None
# Prefix info
@@ -252,6 +259,8 @@ class Req:
@dataclass
class Batch:
"""Store all inforamtion of a batch."""
reqs: List[Req]
req_to_token_pool: ReqToTokenPool
token_to_kv_pool: TokenToKVPool
@@ -692,3 +701,203 @@ def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor
] = 0.0
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
return probs_sort, probs_idx
@dataclass
class InputMetadata:
"""Store all inforamtion of a forward pass."""
forward_mode: ForwardMode
batch_size: int
total_num_tokens: int
max_seq_len: int
req_pool_indices: torch.Tensor
start_loc: torch.Tensor
seq_lens: torch.Tensor
prefix_lens: torch.Tensor
positions: torch.Tensor
req_to_token_pool: ReqToTokenPool
token_to_kv_pool: TokenToKVPool
# for extend
extend_seq_lens: torch.Tensor = None
extend_start_loc: torch.Tensor = None
max_extend_len: int = 0
out_cache_loc: torch.Tensor = None
out_cache_cont_start: torch.Tensor = None
out_cache_cont_end: torch.Tensor = None
other_kv_index: torch.Tensor = None
return_logprob: bool = False
top_logprobs_nums: List[int] = None
# for flashinfer
qo_indptr: torch.Tensor = None
kv_indptr: torch.Tensor = None
kv_indices: torch.Tensor = None
kv_last_page_len: torch.Tensor = None
flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
if self.forward_mode == ForwardMode.EXTEND:
paged_kernel_lens = self.prefix_lens
self.no_prefix = torch.all(self.prefix_lens == 0)
else:
paged_kernel_lens = self.seq_lens
self.kv_indptr = torch.zeros(
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
)
self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
self.kv_last_page_len = torch.ones(
(self.batch_size,), dtype=torch.int32, device="cuda"
)
req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
self.kv_indices = torch.cat(
[
self.req_to_token_pool.req_to_token[
req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
]
for i in range(self.batch_size)
],
dim=0,
).contiguous()
if self.forward_mode == ForwardMode.EXTEND:
# extend part
self.qo_indptr = torch.zeros(
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
)
self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
self.flashinfer_prefill_wrapper_ragged.end_forward()
self.flashinfer_prefill_wrapper_ragged.begin_forward(
self.qo_indptr,
self.qo_indptr.clone(),
num_qo_heads,
num_kv_heads,
head_dim,
)
# cached part
self.flashinfer_prefill_wrapper_paged.end_forward()
self.flashinfer_prefill_wrapper_paged.begin_forward(
self.qo_indptr,
self.kv_indptr,
self.kv_indices,
self.kv_last_page_len,
num_qo_heads,
num_kv_heads,
head_dim,
1,
)
else:
self.flashinfer_decode_wrapper.end_forward()
self.flashinfer_decode_wrapper.begin_forward(
self.kv_indptr,
self.kv_indices,
self.kv_last_page_len,
num_qo_heads,
num_kv_heads,
head_dim,
1,
pos_encoding_mode="NONE",
data_type=self.token_to_kv_pool.kv_data[0].dtype,
)
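To make the FlashInfer metadata above concrete, a small CPU-only illustration (FlashInfer itself is not needed) of the kv_indptr / kv_last_page_len layout for a page size of 1 in decode mode, where paged_kernel_lens is simply seq_lens:

```python
import torch

seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)  # cached tokens per request
batch_size = seq_lens.numel()

kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)
kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32)  # page size 1 -> all ones

# Request i's KV indices occupy kv_indices[kv_indptr[i]:kv_indptr[i + 1]].
print(kv_indptr.tolist())         # [0, 3, 8, 10]
print(kv_last_page_len.tolist())  # [1, 1, 1]
```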
def init_extend_args(self):
self.extend_seq_lens = self.seq_lens - self.prefix_lens
self.extend_start_loc = torch.zeros_like(self.seq_lens)
self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
self.max_extend_len = int(torch.max(self.extend_seq_lens))
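A worked example (CPU tensors, numbers chosen for illustration) of what init_extend_args computes when part of each request is already cached:

```python
import torch

seq_lens = torch.tensor([7, 4, 9])     # total tokens per request
prefix_lens = torch.tensor([3, 0, 5])  # tokens whose KV cache already exists

extend_seq_lens = seq_lens - prefix_lens                          # [4, 4, 4]
extend_start_loc = torch.zeros_like(seq_lens)
extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)  # [0, 4, 8]
max_extend_len = int(torch.max(extend_seq_lens))                  # 4
```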
@classmethod
def create(
cls,
model_runner,
tp_size,
forward_mode,
req_pool_indices,
seq_lens,
prefix_lens,
position_ids_offsets,
out_cache_loc,
out_cache_cont_start=None,
out_cache_cont_end=None,
top_logprobs_nums=None,
return_logprob=False,
flashinfer_prefill_wrapper_ragged=None,
flashinfer_prefill_wrapper_paged=None,
flashinfer_decode_wrapper=None,
):
batch_size = len(req_pool_indices)
start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
total_num_tokens = int(torch.sum(seq_lens))
max_seq_len = int(torch.max(seq_lens))
if forward_mode == ForwardMode.DECODE:
positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
other_kv_index = model_runner.req_to_token_pool.req_to_token[
req_pool_indices[0], seq_lens[0] - 1
].item()
else:
seq_lens_cpu = seq_lens.cpu().numpy()
prefix_lens_cpu = prefix_lens.cpu().numpy()
position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
positions = torch.tensor(
np.concatenate(
[
np.arange(
prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
seq_lens_cpu[i] + position_ids_offsets_cpu[i],
)
for i in range(batch_size)
],
axis=0,
),
device="cuda",
)
other_kv_index = None
ret = cls(
forward_mode=forward_mode,
batch_size=batch_size,
total_num_tokens=total_num_tokens,
max_seq_len=max_seq_len,
req_pool_indices=req_pool_indices,
start_loc=start_loc,
seq_lens=seq_lens,
prefix_lens=prefix_lens,
positions=positions,
req_to_token_pool=model_runner.req_to_token_pool,
token_to_kv_pool=model_runner.token_to_kv_pool,
out_cache_loc=out_cache_loc,
out_cache_cont_start=out_cache_cont_start,
out_cache_cont_end=out_cache_cont_end,
other_kv_index=other_kv_index,
return_logprob=return_logprob,
top_logprobs_nums=top_logprobs_nums,
flashinfer_prefill_wrapper_ragged=flashinfer_prefill_wrapper_ragged,
flashinfer_prefill_wrapper_paged=flashinfer_prefill_wrapper_paged,
flashinfer_decode_wrapper=flashinfer_decode_wrapper,
)
if forward_mode == ForwardMode.EXTEND:
ret.init_extend_args()
if not global_server_args_dict.get("disable_flashinfer", False):
ret.init_flashinfer_args(
model_runner.model_config.num_attention_heads // tp_size,
model_runner.model_config.get_num_kv_heads(tp_size),
model_runner.model_config.head_dim,
)
return ret
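To see what create() produces for positions, a small CPU-only example (values are illustrative): DECODE contributes one position per request at the end of its sequence, while EXTEND contributes a range covering only the uncached suffix of each request.

```python
import numpy as np
import torch

seq_lens = torch.tensor([5, 3])
prefix_lens = torch.tensor([2, 0])
position_ids_offsets = torch.tensor([0, 0])

# DECODE: one new token per request, positioned at the current sequence end.
decode_positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
print(decode_positions.tolist())  # [4, 2]

# EXTEND: per-request ranges over the uncached tokens, concatenated.
extend_positions = torch.tensor(
    np.concatenate(
        [np.arange(int(p), int(s)) for p, s in zip(prefix_lens, seq_lens)],
        axis=0,
    )
)
print(extend_positions.tolist())  # [2, 3, 4, 0, 1, 2]
```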
@@ -4,11 +4,9 @@ import importlib
import importlib.resources
import logging
import pkgutil
from dataclasses import dataclass
from functools import lru_cache
from typing import List, Optional, Type
from typing import Optional, Type
import numpy as np
import torch
import torch.nn as nn
from vllm.config import DeviceConfig, LoadConfig
@@ -17,7 +15,7 @@ from vllm.distributed import init_distributed_environment, initialize_model_para
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry
from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata, global_server_args_dict
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
@@ -29,210 +27,6 @@ from sglang.srt.utils import (
logger = logging.getLogger("srt.model_runner")
# for server args in model endpoints
global_server_args_dict = {}
@dataclass
class InputMetadata:
forward_mode: ForwardMode
batch_size: int
total_num_tokens: int
max_seq_len: int
req_pool_indices: torch.Tensor
start_loc: torch.Tensor
seq_lens: torch.Tensor
prefix_lens: torch.Tensor
positions: torch.Tensor
req_to_token_pool: ReqToTokenPool
token_to_kv_pool: TokenToKVPool
# for extend
extend_seq_lens: torch.Tensor = None
extend_start_loc: torch.Tensor = None
max_extend_len: int = 0
out_cache_loc: torch.Tensor = None
out_cache_cont_start: torch.Tensor = None
out_cache_cont_end: torch.Tensor = None
other_kv_index: torch.Tensor = None
return_logprob: bool = False
top_logprobs_nums: List[int] = None
# for flashinfer
qo_indptr: torch.Tensor = None
kv_indptr: torch.Tensor = None
kv_indices: torch.Tensor = None
kv_last_page_len: torch.Tensor = None
flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
if (
self.forward_mode == ForwardMode.PREFILL
or self.forward_mode == ForwardMode.EXTEND
):
paged_kernel_lens = self.prefix_lens
self.no_prefix = torch.all(self.prefix_lens == 0)
else:
paged_kernel_lens = self.seq_lens
self.kv_indptr = torch.zeros(
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
)
self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
self.kv_last_page_len = torch.ones(
(self.batch_size,), dtype=torch.int32, device="cuda"
)
req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
self.kv_indices = torch.cat(
[
self.req_to_token_pool.req_to_token[
req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
]
for i in range(self.batch_size)
],
dim=0,
).contiguous()
if (
self.forward_mode == ForwardMode.PREFILL
or self.forward_mode == ForwardMode.EXTEND
):
# extend part
self.qo_indptr = torch.zeros(
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
)
self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
self.flashinfer_prefill_wrapper_ragged.end_forward()
self.flashinfer_prefill_wrapper_ragged.begin_forward(
self.qo_indptr,
self.qo_indptr.clone(),
num_qo_heads,
num_kv_heads,
head_dim,
)
# cached part
self.flashinfer_prefill_wrapper_paged.end_forward()
self.flashinfer_prefill_wrapper_paged.begin_forward(
self.qo_indptr,
self.kv_indptr,
self.kv_indices,
self.kv_last_page_len,
num_qo_heads,
num_kv_heads,
head_dim,
1,
)
else:
self.flashinfer_decode_wrapper.end_forward()
self.flashinfer_decode_wrapper.begin_forward(
self.kv_indptr,
self.kv_indices,
self.kv_last_page_len,
num_qo_heads,
num_kv_heads,
head_dim,
1,
pos_encoding_mode="NONE",
data_type=self.token_to_kv_pool.kv_data[0].dtype,
)
def init_extend_args(self):
self.extend_seq_lens = self.seq_lens - self.prefix_lens
self.extend_start_loc = torch.zeros_like(self.seq_lens)
self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
self.max_extend_len = int(torch.max(self.extend_seq_lens))
@classmethod
def create(
cls,
model_runner,
tp_size,
forward_mode,
req_pool_indices,
seq_lens,
prefix_lens,
position_ids_offsets,
out_cache_loc,
out_cache_cont_start=None,
out_cache_cont_end=None,
top_logprobs_nums=None,
return_logprob=False,
flashinfer_prefill_wrapper_ragged=None,
flashinfer_prefill_wrapper_paged=None,
flashinfer_decode_wrapper=None,
):
batch_size = len(req_pool_indices)
start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
total_num_tokens = int(torch.sum(seq_lens))
max_seq_len = int(torch.max(seq_lens))
if forward_mode == ForwardMode.DECODE:
positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
other_kv_index = model_runner.req_to_token_pool.req_to_token[
req_pool_indices[0], seq_lens[0] - 1
].item()
else:
seq_lens_cpu = seq_lens.cpu().numpy()
prefix_lens_cpu = prefix_lens.cpu().numpy()
position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
positions = torch.tensor(
np.concatenate(
[
np.arange(
prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
seq_lens_cpu[i] + position_ids_offsets_cpu[i],
)
for i in range(batch_size)
],
axis=0,
),
device="cuda",
)
other_kv_index = None
ret = cls(
forward_mode=forward_mode,
batch_size=batch_size,
total_num_tokens=total_num_tokens,
max_seq_len=max_seq_len,
req_pool_indices=req_pool_indices,
start_loc=start_loc,
seq_lens=seq_lens,
prefix_lens=prefix_lens,
positions=positions,
req_to_token_pool=model_runner.req_to_token_pool,
token_to_kv_pool=model_runner.token_to_kv_pool,
out_cache_loc=out_cache_loc,
out_cache_cont_start=out_cache_cont_start,
out_cache_cont_end=out_cache_cont_end,
other_kv_index=other_kv_index,
return_logprob=return_logprob,
top_logprobs_nums=top_logprobs_nums,
flashinfer_prefill_wrapper_ragged=flashinfer_prefill_wrapper_ragged,
flashinfer_prefill_wrapper_paged=flashinfer_prefill_wrapper_paged,
flashinfer_decode_wrapper=flashinfer_decode_wrapper,
)
if forward_mode == ForwardMode.EXTEND:
ret.init_extend_args()
if not global_server_args_dict.get("disable_flashinfer", False):
ret.init_flashinfer_args(
model_runner.model_config.num_attention_heads // tp_size,
model_runner.model_config.get_num_kv_heads(tp_size),
model_runner.model_config.head_dim,
)
return ret
class ModelRunner:
def __init__(
@@ -245,6 +39,7 @@ class ModelRunner:
nccl_port: int,
server_args: ServerArgs,
):
# Parse args
self.model_config = model_config
self.mem_fraction_static = mem_fraction_static
self.gpu_id = gpu_id
@@ -256,7 +51,6 @@
monkey_patch_vllm_dummy_weight_loader()
# Init torch distributed
logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
torch.cuda.set_device(self.gpu_id)
logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
@@ -287,11 +81,8 @@
)
# Set some global args
global global_server_args_dict
global_server_args_dict = {
"disable_flashinfer": server_args.disable_flashinfer,
"attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
}
global_server_args_dict["disable_flashinfer"] = server_args.disable_flashinfer
global_server_args_dict["attention_reduce_in_fp32"] = server_args.attention_reduce_in_fp32
# Load the model and create memory pool
self.load_model()
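One detail worth noting: global_server_args_dict now lives in infer_batch.py and radix_attention.py binds to it with a top-level import, so model_runner.py must mutate the shared dict in place (as above) instead of rebinding the name. A minimal single-file sketch of the difference (names are placeholders, not SGLang code):

```python
# 'shared' stands in for infer_batch.global_server_args_dict.
shared = {}
alias = shared  # what "from infer_batch import global_server_args_dict" binds to

# Rebinding (the old pattern) creates a new dict; importers keep the old one.
shared = {"disable_flashinfer": True}
print(alias)    # {} -- the importing module still sees an empty dict

# In-place mutation (the new pattern) is visible through every import site.
shared = alias
shared["disable_flashinfer"] = True
print(alias)    # {'disable_flashinfer': True}
```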
@@ -425,27 +216,6 @@ class ModelRunner:
) = None
self.flashinfer_decode_wrapper = None
@torch.inference_mode()
def forward_prefill(self, batch: Batch):
input_metadata = InputMetadata.create(
self,
forward_mode=ForwardMode.PREFILL,
tp_size=self.tp_size,
req_pool_indices=batch.req_pool_indices,
seq_lens=batch.seq_lens,
prefix_lens=batch.prefix_lens,
position_ids_offsets=batch.position_ids_offsets,
out_cache_loc=batch.out_cache_loc,
top_logprobs_nums=batch.top_logprobs_nums,
return_logprob=batch.return_logprob,
flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
)
return self.model.forward(
batch.input_ids, input_metadata.positions, input_metadata
)
@torch.inference_mode()
def forward_extend(self, batch: Batch):
input_metadata = InputMetadata.create(
@@ -523,8 +293,6 @@ class ModelRunner:
return self.forward_decode(batch)
elif forward_mode == ForwardMode.EXTEND:
return self.forward_extend(batch)
elif forward_mode == ForwardMode.PREFILL:
return self.forward_prefill(batch)
else:
raise ValueError(f"Invaid forward mode: {forward_mode}")