Unverified commit 87e8c090, authored by Liangsheng Yin, committed by GitHub

Organize code (rename, movement) (#953)

parent ad56e684
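
This commit is a mechanical reorganization: the scheduler batch class `Batch` is renamed to `ScheduleBatch`, and `ForwardMode`, `InputMetadata`, `init_flashinfer_args`, and `init_triton_args` move out of `sglang.srt.managers.schedule_batch` / `sglang.srt.model_executor.model_runner` into a new `sglang.srt.model_executor.forward_batch_info` module, with imports updated across the model implementations. A minimal sketch of the import change for downstream code, based on the removed and added import lines in the hunks below:

```python
# Old locations (removed in this commit):
#   from sglang.srt.managers.schedule_batch import Batch, ForwardMode
#   from sglang.srt.model_executor.model_runner import InputMetadata

# New locations (added in this commit):
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
```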
@@ -50,8 +50,9 @@ import torch
 import torch.distributed as dist
 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.managers.schedule_batch import Batch, ForwardMode, Req
+from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.model_config import ModelConfig
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
@@ -188,7 +189,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
 def extend(reqs, model_runner):
-    batch = Batch.init_new(
+    batch = ScheduleBatch.init_new(
         reqs=reqs,
         req_to_token_pool=model_runner.req_to_token_pool,
         token_to_kv_pool=model_runner.token_to_kv_pool,
......
@@ -25,7 +25,7 @@ from vllm.distributed import (
     tensor_model_parallel_all_gather,
 )
-from sglang.srt.model_executor.model_runner import ForwardMode, InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 @dataclasses.dataclass
......
@@ -22,11 +22,8 @@ from torch import nn
 from sglang.global_config import global_config
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
-from sglang.srt.model_executor.model_runner import (
-    ForwardMode,
-    InputMetadata,
-    global_server_args_dict,
-)
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.model_executor.model_runner import global_server_args_dict
 class RadixAttention(nn.Module):
......
@@ -18,7 +18,6 @@ limitations under the License.
 import logging
 import warnings
 from dataclasses import dataclass
-from enum import IntEnum, auto
 from typing import List, Union
 import numpy as np
@@ -46,15 +45,6 @@ global_server_args_dict = {
 logger = logging.getLogger(__name__)
-class ForwardMode(IntEnum):
-    # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
-    PREFILL = auto()
-    # Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
-    EXTEND = auto()
-    # Decode one token.
-    DECODE = auto()
 class BaseFinishReason:
     def __init__(self, is_error: bool = False):
         self.is_error = is_error
@@ -284,7 +274,7 @@ class Req:
 @dataclass
-class Batch:
+class ScheduleBatch:
     """Store all information of a batch."""
     # Request, memory pool, and cache
@@ -673,7 +663,7 @@ class Batch:
             if self_val is not None:  # logit_bias can be None
                 setattr(self, item, self_val[new_indices])
-    def merge(self, other: "Batch"):
+    def merge(self, other: "ScheduleBatch"):
         self.reqs.extend(other.reqs)
         self.req_pool_indices = torch.concat(
@@ -770,229 +760,6 @@ class Batch:
         return batch_next_token_ids
-@dataclass
-class InputMetadata:
-    """Store all information of a forward pass."""
-
-    forward_mode: ForwardMode
-    batch_size: int
-    total_num_tokens: int
-    req_pool_indices: torch.Tensor
-    seq_lens: torch.Tensor
-    positions: torch.Tensor
-    req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: BaseTokenToKVPool
-
-    # For extend
-    extend_seq_lens: torch.Tensor
-    extend_start_loc: torch.Tensor
-    extend_no_prefix: bool
-
-    # Output location of the KV cache
-    out_cache_loc: torch.Tensor = None
-
-    # Output options
-    return_logprob: bool = False
-    top_logprobs_nums: List[int] = None
-
-    # Triton attention backend
-    triton_max_seq_len: int = 0
-    triton_max_extend_len: int = 0
-    triton_start_loc: torch.Tensor = None
-    triton_prefix_lens: torch.Tensor = None
-
-    # FlashInfer attention backend
-    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
-    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
-    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
-    flashinfer_use_ragged: bool = False
-
-    @classmethod
-    def create(
-        cls,
-        model_runner,
-        forward_mode,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        top_logprobs_nums=None,
-        return_logprob=False,
-        skip_flashinfer_init=False,
-    ):
-        flashinfer_use_ragged = False
-        if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
-            if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
-                flashinfer_use_ragged = True
-            init_flashinfer_args(
-                forward_mode,
-                model_runner,
-                req_pool_indices,
-                seq_lens,
-                prefix_lens,
-                model_runner.flashinfer_decode_wrapper,
-                flashinfer_use_ragged,
-            )
-
-        batch_size = len(req_pool_indices)
-
-        if forward_mode == ForwardMode.DECODE:
-            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
-            extend_seq_lens = extend_start_loc = extend_no_prefix = None
-            if not model_runner.server_args.disable_flashinfer:
-                # This variable is not needed in this case,
-                # we do not compute it to make it compatible with cuda graph.
-                total_num_tokens = None
-            else:
-                total_num_tokens = int(torch.sum(seq_lens))
-        else:
-            seq_lens_cpu = seq_lens.cpu().numpy()
-            prefix_lens_cpu = prefix_lens.cpu().numpy()
-            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
-            positions = torch.tensor(
-                np.concatenate(
-                    [
-                        np.arange(
-                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
-                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
-                        )
-                        for i in range(batch_size)
-                    ],
-                    axis=0,
-                ),
-                device="cuda",
-            )
-            extend_seq_lens = seq_lens - prefix_lens
-            extend_start_loc = torch.zeros_like(seq_lens)
-            extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
-            extend_no_prefix = torch.all(prefix_lens == 0)
-            total_num_tokens = int(torch.sum(seq_lens))
-
-        ret = cls(
-            forward_mode=forward_mode,
-            batch_size=batch_size,
-            total_num_tokens=total_num_tokens,
-            req_pool_indices=req_pool_indices,
-            seq_lens=seq_lens,
-            positions=positions,
-            req_to_token_pool=model_runner.req_to_token_pool,
-            token_to_kv_pool=model_runner.token_to_kv_pool,
-            out_cache_loc=out_cache_loc,
-            extend_seq_lens=extend_seq_lens,
-            extend_start_loc=extend_start_loc,
-            extend_no_prefix=extend_no_prefix,
-            return_logprob=return_logprob,
-            top_logprobs_nums=top_logprobs_nums,
-            flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
-            flashinfer_use_ragged=flashinfer_use_ragged,
-        )
-
-        if model_runner.server_args.disable_flashinfer:
-            (
-                ret.triton_max_seq_len,
-                ret.triton_max_extend_len,
-                ret.triton_start_loc,
-                ret.triton_prefix_lens,
-            ) = init_triton_args(forward_mode, seq_lens, prefix_lens)
-
-        return ret
-
-
-def init_flashinfer_args(
-    forward_mode,
-    model_runner,
-    req_pool_indices,
-    seq_lens,
-    prefix_lens,
-    flashinfer_decode_wrapper,
-    flashinfer_use_ragged=False,
-):
-    """Init auxiliary variables for FlashInfer attention backend."""
-    num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
-    num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
-    head_dim = model_runner.model_config.head_dim
-    batch_size = len(req_pool_indices)
-    total_num_tokens = int(torch.sum(seq_lens))
-
-    if flashinfer_use_ragged:
-        paged_kernel_lens = prefix_lens
-    else:
-        paged_kernel_lens = seq_lens
-
-    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
-    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-    req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-    paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-    kv_indices = torch.cat(
-        [
-            model_runner.req_to_token_pool.req_to_token[
-                req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
-            ]
-            for i in range(batch_size)
-        ],
-        dim=0,
-    ).contiguous()
-    kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
-
-    if forward_mode == ForwardMode.DECODE:
-        flashinfer_decode_wrapper.end_forward()
-        flashinfer_decode_wrapper.begin_forward(
-            kv_indptr,
-            kv_indices,
-            kv_last_page_len,
-            num_qo_heads,
-            num_kv_heads,
-            head_dim,
-            1,
-        )
-    else:
-        # extend part
-        qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
-        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
-
-        if flashinfer_use_ragged:
-            model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
-            model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
-                qo_indptr,
-                qo_indptr,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-            )
-
-        # cached part
-        model_runner.flashinfer_prefill_wrapper_paged.end_forward()
-        model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
-            qo_indptr,
-            kv_indptr,
-            kv_indices,
-            kv_last_page_len,
-            num_qo_heads,
-            num_kv_heads,
-            head_dim,
-            1,
-        )
-
-
-def init_triton_args(forward_mode, seq_lens, prefix_lens):
-    """Init auxiliary variables for triton attention backend."""
-    batch_size = len(seq_lens)
-    max_seq_len = int(torch.max(seq_lens))
-    start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
-    start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
-
-    if forward_mode == ForwardMode.DECODE:
-        max_extend_len = None
-    else:
-        extend_seq_lens = seq_lens - prefix_lens
-        max_extend_len = int(torch.max(extend_seq_lens))
-
-    return max_seq_len, max_extend_len, start_loc, prefix_lens
 def top_k_top_p_sampling_from_probs_torch(
     probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor
 ):
......
@@ -39,13 +39,13 @@ from sglang.srt.managers.policy_scheduler import PolicyScheduler
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
     BaseFinishReason,
-    Batch,
-    ForwardMode,
     Req,
+    ScheduleBatch,
 )
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.model_config import ModelConfig
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -172,7 +172,7 @@ class ModelTpServer:
         # Init running status
         self.waiting_queue: List[Req] = []
-        self.running_batch: Batch = None
+        self.running_batch: ScheduleBatch = None
         self.out_pyobjs = []
         self.decode_forward_ct = 0
         self.stream_interval = server_args.stream_interval
@@ -353,7 +353,7 @@ class ModelTpServer:
         )
         self.waiting_queue.append(req)
-    def get_new_prefill_batch(self) -> Optional[Batch]:
+    def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
         # TODO(lsyin): organize this function
         running_bs = (
             len(self.running_batch.reqs) if self.running_batch is not None else 0
@@ -526,7 +526,7 @@ class ModelTpServer:
             )
         # Return the new batch
-        new_batch = Batch.init_new(
+        new_batch = ScheduleBatch.init_new(
             can_run_list,
             self.req_to_token_pool,
             self.token_to_kv_pool,
@@ -535,7 +535,7 @@ class ModelTpServer:
         self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
         return new_batch
-    def forward_prefill_batch(self, batch: Batch):
+    def forward_prefill_batch(self, batch: ScheduleBatch):
         # Build batch tensors
         batch.prepare_for_extend(
             self.model_config.vocab_size, self.int_token_logit_bias
@@ -624,7 +624,7 @@ class ModelTpServer:
                 )
                 req.output_top_logprobs.append(output.output_top_logprobs[i])
-    def cache_filled_batch(self, batch: Batch):
+    def cache_filled_batch(self, batch: ScheduleBatch):
         req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
         for i, req in enumerate(batch.reqs):
             new_prefix_indices, new_last_node = self.tree_cache.cache_req(
@@ -641,7 +641,7 @@ class ModelTpServer:
                 # inflight request would get a new req idx
                 self.req_to_token_pool.free(int(req_pool_indices_cpu[i]))
-    def forward_decode_batch(self, batch: Batch):
+    def forward_decode_batch(self, batch: ScheduleBatch):
         # Check if decode out of memory
         if not batch.check_decode_mem():
             old_ratio = self.new_token_ratio
@@ -700,7 +700,7 @@ class ModelTpServer:
         self.handle_finished_requests(batch)
-    def handle_finished_requests(self, batch: Batch):
+    def handle_finished_requests(self, batch: ScheduleBatch):
         output_rids = []
         output_vids = []
         decoded_texts = []
@@ -800,7 +800,7 @@ class ModelTpServer:
             else:
                 batch.reqs = []
-    def filter_out_inflight(self, batch: Batch):
+    def filter_out_inflight(self, batch: ScheduleBatch):
         # TODO(lsyin): reduce the overhead, make a special version for this
         if self.current_inflight_req is None:
             return
......
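In the TP worker hunks above, only the type name changes: every method that previously took a `Batch` now takes a `ScheduleBatch`. An illustrative sketch (not the worker's literal event loop) of how those renamed methods fit together, using only names that appear in the diff; `self` is assumed to be a `ModelTpServer`:

```python
# Illustrative only: shows the renamed ScheduleBatch-typed methods from the hunks
# above; the surrounding control flow is a simplified assumption, not the real loop.
new_batch = self.get_new_prefill_batch()            # -> Optional[ScheduleBatch]
if new_batch is not None:
    self.forward_prefill_batch(new_batch)           # prefill / extend step
elif self.running_batch is not None:                # self.running_batch: ScheduleBatch
    self.forward_decode_batch(self.running_batch)   # single decode step
```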
@@ -29,8 +29,8 @@ from sglang.srt.layers.logits_processor import (
     LogitsMetadata,
     LogitsProcessor,
 )
-from sglang.srt.managers.schedule_batch import (
-    Batch,
+from sglang.srt.managers.schedule_batch import ScheduleBatch
+from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
     InputMetadata,
     init_flashinfer_args,
@@ -202,7 +202,7 @@ class CudaGraphRunner:
         self.graph_memory_pool = graph.pool()
         return graph, None, out, flashinfer_decode_wrapper
-    def replay(self, batch: Batch):
+    def replay(self, batch: ScheduleBatch):
         assert batch.out_cache_loc is not None
         raw_bs = len(batch.reqs)
......
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""ModelRunner runs the forward passes of the models."""
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import List
import numpy as np
import torch
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
class ForwardMode(IntEnum):
# Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
PREFILL = auto()
# Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
EXTEND = auto()
# Decode one token.
DECODE = auto()
@dataclass
class InputMetadata:
"""Store all inforamtion of a forward pass."""
forward_mode: ForwardMode
batch_size: int
total_num_tokens: int
req_pool_indices: torch.Tensor
seq_lens: torch.Tensor
positions: torch.Tensor
req_to_token_pool: ReqToTokenPool
token_to_kv_pool: BaseTokenToKVPool
# For extend
extend_seq_lens: torch.Tensor
extend_start_loc: torch.Tensor
extend_no_prefix: bool
# Output location of the KV cache
out_cache_loc: torch.Tensor = None
# Output options
return_logprob: bool = False
top_logprobs_nums: List[int] = None
# Trition attention backend
triton_max_seq_len: int = 0
triton_max_extend_len: int = 0
triton_start_loc: torch.Tensor = None
triton_prefix_lens: torch.Tensor = None
# FlashInfer attention backend
flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
flashinfer_use_ragged: bool = False
@classmethod
def create(
cls,
model_runner,
forward_mode,
req_pool_indices,
seq_lens,
prefix_lens,
position_ids_offsets,
out_cache_loc,
top_logprobs_nums=None,
return_logprob=False,
skip_flashinfer_init=False,
):
flashinfer_use_ragged = False
if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
flashinfer_use_ragged = True
init_flashinfer_args(
forward_mode,
model_runner,
req_pool_indices,
seq_lens,
prefix_lens,
model_runner.flashinfer_decode_wrapper,
flashinfer_use_ragged,
)
batch_size = len(req_pool_indices)
if forward_mode == ForwardMode.DECODE:
positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
extend_seq_lens = extend_start_loc = extend_no_prefix = None
if not model_runner.server_args.disable_flashinfer:
# This variable is not needed in this case,
# we do not compute it to make it compatbile with cuda graph.
total_num_tokens = None
else:
total_num_tokens = int(torch.sum(seq_lens))
else:
seq_lens_cpu = seq_lens.cpu().numpy()
prefix_lens_cpu = prefix_lens.cpu().numpy()
position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
positions = torch.tensor(
np.concatenate(
[
np.arange(
prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
seq_lens_cpu[i] + position_ids_offsets_cpu[i],
)
for i in range(batch_size)
],
axis=0,
),
device="cuda",
)
extend_seq_lens = seq_lens - prefix_lens
extend_start_loc = torch.zeros_like(seq_lens)
extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
extend_no_prefix = torch.all(prefix_lens == 0)
total_num_tokens = int(torch.sum(seq_lens))
ret = cls(
forward_mode=forward_mode,
batch_size=batch_size,
total_num_tokens=total_num_tokens,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
positions=positions,
req_to_token_pool=model_runner.req_to_token_pool,
token_to_kv_pool=model_runner.token_to_kv_pool,
out_cache_loc=out_cache_loc,
extend_seq_lens=extend_seq_lens,
extend_start_loc=extend_start_loc,
extend_no_prefix=extend_no_prefix,
return_logprob=return_logprob,
top_logprobs_nums=top_logprobs_nums,
flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
flashinfer_use_ragged=flashinfer_use_ragged,
)
if model_runner.server_args.disable_flashinfer:
(
ret.triton_max_seq_len,
ret.triton_max_extend_len,
ret.triton_start_loc,
ret.triton_prefix_lens,
) = init_triton_args(forward_mode, seq_lens, prefix_lens)
return ret
def init_flashinfer_args(
forward_mode,
model_runner,
req_pool_indices,
seq_lens,
prefix_lens,
flashinfer_decode_wrapper,
flashinfer_use_ragged=False,
):
"""Init auxiliary variables for FlashInfer attention backend."""
num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
head_dim = model_runner.model_config.head_dim
batch_size = len(req_pool_indices)
total_num_tokens = int(torch.sum(seq_lens))
if flashinfer_use_ragged:
paged_kernel_lens = prefix_lens
else:
paged_kernel_lens = seq_lens
kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
req_pool_indices_cpu = req_pool_indices.cpu().numpy()
paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
kv_indices = torch.cat(
[
model_runner.req_to_token_pool.req_to_token[
req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
]
for i in range(batch_size)
],
dim=0,
).contiguous()
kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
if forward_mode == ForwardMode.DECODE:
flashinfer_decode_wrapper.end_forward()
flashinfer_decode_wrapper.begin_forward(
kv_indptr,
kv_indices,
kv_last_page_len,
num_qo_heads,
num_kv_heads,
head_dim,
1,
)
else:
# extend part
qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
if flashinfer_use_ragged:
model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
qo_indptr,
qo_indptr,
num_qo_heads,
num_kv_heads,
head_dim,
)
# cached part
model_runner.flashinfer_prefill_wrapper_paged.end_forward()
model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
qo_indptr,
kv_indptr,
kv_indices,
kv_last_page_len,
num_qo_heads,
num_kv_heads,
head_dim,
1,
)
def init_triton_args(forward_mode, seq_lens, prefix_lens):
"""Init auxiliary variables for triton attention backend."""
batch_size = len(seq_lens)
max_seq_len = int(torch.max(seq_lens))
start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
if forward_mode == ForwardMode.DECODE:
max_extend_len = None
else:
extend_seq_lens = seq_lens - prefix_lens
max_extend_len = int(torch.max(extend_seq_lens))
return max_seq_len, max_extend_len, start_loc, prefix_lens
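
For orientation, a hedged sketch (not part of the commit) of how `InputMetadata.create` is invoked. The keyword arguments follow the `create()` signature above; `model_runner` is assumed to be an already-initialized sglang `ModelRunner`, and the tensors are illustrative placeholders:

```python
import torch

from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata

batch_size = 2
input_metadata = InputMetadata.create(
    model_runner,                                            # assumed pre-built ModelRunner
    forward_mode=ForwardMode.DECODE,
    req_pool_indices=torch.arange(batch_size, device="cuda"),
    seq_lens=torch.tensor([17, 33], device="cuda"),          # current length of each request
    prefix_lens=torch.tensor([16, 32], device="cuda"),       # tokens already in the KV cache
    position_ids_offsets=torch.zeros(batch_size, dtype=torch.int64, device="cuda"),
    out_cache_loc=torch.arange(batch_size, device="cuda"),   # KV-cache slots for the new tokens
)
```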
@@ -41,18 +41,14 @@ from vllm.distributed import (
 from vllm.model_executor.models import ModelRegistry
 from sglang.global_config import global_config
-from sglang.srt.managers.schedule_batch import (
-    Batch,
-    ForwardMode,
-    InputMetadata,
-    global_server_args_dict,
-)
+from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
 )
 from sglang.srt.model_config import AttentionArch
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
@@ -350,7 +346,7 @@ class ModelRunner:
         )
     @torch.inference_mode()
-    def forward_decode(self, batch: Batch):
+    def forward_decode(self, batch: ScheduleBatch):
         if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
             return self.cuda_graph_runner.replay(batch)
@@ -370,7 +366,7 @@ class ModelRunner:
         )
     @torch.inference_mode()
-    def forward_extend(self, batch: Batch):
+    def forward_extend(self, batch: ScheduleBatch):
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.EXTEND,
@@ -387,7 +383,7 @@ class ModelRunner:
         )
     @torch.inference_mode()
-    def forward_extend_multi_modal(self, batch: Batch):
+    def forward_extend_multi_modal(self, batch: ScheduleBatch):
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.EXTEND,
@@ -408,7 +404,7 @@ class ModelRunner:
             batch.image_offsets,
         )
-    def forward(self, batch: Batch, forward_mode: ForwardMode):
+    def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
         if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
             return self.forward_extend_multi_modal(batch)
         elif forward_mode == ForwardMode.DECODE:
......
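The `forward()` hunk above dispatches a `ScheduleBatch` by `ForwardMode`. A hedged usage sketch, assuming `model_runner` and a prepared `ScheduleBatch` named `batch` already exist:

```python
from sglang.srt.model_executor.forward_batch_info import ForwardMode

# Prefill / prefix-extension pass over the prompt tokens.
extend_output = model_runner.forward(batch, ForwardMode.EXTEND)

# Single-token decode pass; per the forward_decode hunk above, this may be
# replayed through the CUDA graph runner when the batch size allows it.
decode_output = model_runner.forward(batch, ForwardMode.DECODE)
```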
@@ -45,7 +45,7 @@ from vllm.transformers_utils.configs import ChatGLMConfig
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 LoraConfig = None
......
@@ -64,7 +64,7 @@ from vllm.model_executor.utils import set_weight_attrs
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 @torch.compile
......
@@ -45,7 +45,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 class DbrxRouter(nn.Module):
......
@@ -46,7 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.schedule_batch import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 class DeepseekMLP(nn.Module):
......
@@ -46,7 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 class DeepseekV2MLP(nn.Module):
......
@@ -37,7 +37,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 class GemmaMLP(nn.Module):
......
@@ -42,7 +42,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 class GemmaRMSNorm(CustomOp):
......
@@ -35,7 +35,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.schedule_batch import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 class GPTBigCodeAttention(nn.Module):
......
@@ -52,7 +52,7 @@ from vllm.utils import print_warning_once
 from sglang.srt.layers.fused_moe import fused_moe
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 use_fused = True
......
@@ -40,7 +40,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 class InternLM2MLP(nn.Module):
......
@@ -41,7 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 class LlamaMLP(nn.Module):
......
@@ -25,7 +25,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitProcessorOutput
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.models.llama2 import LlamaModel
......