"references/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "fe78a8ae2c4c86e53ba73082a587ac86bc87e671"
Unverified commit 3f0fe08d authored by Lianmin Zheng, committed by GitHub

Let ModelRunner take InputMetadata as input, instead of ScheduleBatch (#1541)

parent 55b974f9
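After this change, callers build an InputMetadata from a ScheduleBatch and pass only that to ModelRunner.forward; the ScheduleBatch itself is still used for sampling. A minimal sketch of the new calling convention, assuming a prepared ScheduleBatch `batch` and an initialized ModelRunner `model_runner` (names follow the diff below):

    # Convert the scheduled batch into per-forward metadata (new in this commit).
    input_metadata = batch.get_input_metadata()
    # ModelRunner.forward now consumes InputMetadata only; it attaches the KV pools,
    # attention backend, and LoRA state itself before dispatching to decode/extend.
    logits_output = model_runner.forward(input_metadata)
    # Sampling still takes the ScheduleBatch for its sampling parameters.
    next_token_ids = model_runner.sample(logits_output, batch)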
@@ -225,14 +225,16 @@ def extend(reqs, model_runner):
         tree_cache=None,
     )
     batch.prepare_for_extend(model_runner.model_config.vocab_size)
-    logits_output = model_runner.forward(batch)
+    input_metadata = batch.get_input_metadata()
+    logits_output = model_runner.forward(input_metadata)
     next_token_ids = model_runner.sample(logits_output, batch).tolist()
     return next_token_ids, logits_output.next_token_logits, batch


 def decode(input_token_ids, batch, model_runner):
     batch.prepare_for_decode(input_token_ids)
-    logits_output = model_runner.forward(batch)
+    input_metadata = batch.get_input_metadata()
+    logits_output = model_runner.forward(input_metadata)
     next_token_ids = model_runner.sample(logits_output, batch).tolist()
     return next_token_ids, logits_output.next_token_logits
...
@@ -15,7 +15,7 @@ import torch.nn as nn
 from sglang.global_config import global_config
 from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
-from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.utils import is_hip
@@ -37,9 +37,7 @@ class AttentionBackend(ABC):
     """The base class of attention backends"""

     @abstractmethod
-    def init_forward_metadata(
-        self, batch: ScheduleBatch, input_metadata: InputMetadata
-    ):
+    def init_forward_metadata(self, input_metadata: InputMetadata):
         """Init the metadata for a forward pass."""
         raise NotImplementedError()
@@ -133,12 +131,11 @@ class FlashInferAttnBackend(AttentionBackend):
         self.forward_metadata = None
         self.cuda_graph_metadata = {}

-    def init_forward_metadata(
-        self, batch: ScheduleBatch, input_metadata: InputMetadata
-    ):
+    def init_forward_metadata(self, input_metadata: InputMetadata):
         if input_metadata.forward_mode.is_decode():
             prefix_lens = None
             use_ragged = False
+            extend_no_prefix = False
             total_num_tokens = None
         else:
             prefix_lens = input_metadata.extend_prefix_lens
@@ -152,6 +149,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 use_ragged = True

             total_num_tokens = torch.sum(input_metadata.seq_lens).item()
+            extend_no_prefix = not torch.any(input_metadata.extend_prefix_lens).item()

         update_flashinfer_indices(
             input_metadata.forward_mode,
@@ -162,7 +160,12 @@ class FlashInferAttnBackend(AttentionBackend):
             use_ragged=use_ragged,
         )

-        self.forward_metadata = (use_ragged, total_num_tokens, self.decode_wrapper)
+        self.forward_metadata = (
+            use_ragged,
+            extend_no_prefix,
+            total_num_tokens,
+            self.decode_wrapper,
+        )

     def init_cuda_graph_state(self, max_bs: int):
         self.cuda_graph_kv_indptr = torch.zeros(
@@ -228,7 +231,7 @@ class FlashInferAttnBackend(AttentionBackend):
         self.cuda_graph_metadata[bs] = decode_wrapper

-        self.forward_metadata = (False, None, decode_wrapper)
+        self.forward_metadata = (False, False, None, decode_wrapper)

     def init_forward_metadata_replay_cuda_graph(
         self, bs: int, req_pool_indices, seq_lens
@@ -254,7 +257,9 @@ class FlashInferAttnBackend(AttentionBackend):
         else:
             prefill_wrapper_paged = self.prefill_wrapper_paged[1]

-        use_ragged, total_num_tokens, decode_wrapper = self.forward_metadata
+        use_ragged, extend_no_prefix, total_num_tokens, decode_wrapper = (
+            self.forward_metadata
+        )

         if not use_ragged:
             if k is not None:
@@ -280,7 +285,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 logits_soft_cap=layer.logit_cap,
             )

-            if input_metadata.extend_no_prefix:
+            if extend_no_prefix:
                 o = o1
             else:
                 o2, s2 = prefill_wrapper_paged.forward_return_lse(
@@ -300,7 +305,9 @@ class FlashInferAttnBackend(AttentionBackend):
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)

     def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
-        use_ragged, total_num_tokens, decode_wrapper = self.forward_metadata
+        use_ragged, extend_no_prefix, total_num_tokens, decode_wrapper = (
+            self.forward_metadata
+        )

         if isinstance(decode_wrapper, list):
             if layer.sliding_window_size != -1:
@@ -351,9 +358,7 @@ class TritonAttnBackend(AttentionBackend):
         self.cuda_graph_max_seq_len = model_runner.model_config.context_len

-    def init_forward_metadata(
-        self, batch: ScheduleBatch, input_metadata: InputMetadata
-    ):
+    def init_forward_metadata(self, input_metadata: InputMetadata):
         """Init auxiliary variables for triton attention backend."""

         if input_metadata.forward_mode.is_decode():
@@ -371,7 +376,7 @@ class TritonAttnBackend(AttentionBackend):
             max_extend_len = None
         else:
             start_loc = attn_logits = max_seq_len = None
-            prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
+            prefix_lens = input_metadata.extend_prefix_lens
             max_extend_len = torch.max(input_metadata.seq_lens - prefix_lens).item()

         self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len
...
@@ -18,13 +18,12 @@ limitations under the License.
 import re
-from dataclasses import dataclass

 import torch

 from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
 from sglang.srt.lora.lora_config import LoRAConfig
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.utils import is_hip, replace_submodule

 # ROCm: flashinfer available later
@@ -208,9 +207,9 @@ class LoRAManager:
             if lora_weight_name:
                 self.B_buffer[lora_weight_name][i][buffer_id].copy_(weights)

-    def prepare_lora_batch(self, batch, extend_seq_lens=None):
+    def prepare_lora_batch(self, input_metadata: InputMetadata):
         # load active loras into lora memory pool
-        cur_uids = set([req.lora_path for req in batch.reqs])
+        cur_uids = set(input_metadata.lora_paths)
         assert len(cur_uids) <= self.max_loras_per_batch
         i = 0
         evictable_uids = list(self.active_uids)
@@ -230,11 +229,15 @@ class LoRAManager:
             return

         # setup lora in forward modules
-        bs = len(batch.reqs)
-        seg_lens = extend_seq_lens if batch.forward_mode.is_extend() else torch.ones(bs)
+        bs = input_metadata.batch_size
+        seg_lens = (
+            input_metadata.extend_seq_lens
+            if input_metadata.forward_mode.is_extend()
+            else torch.ones(bs)
+        )
         weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
-        for i, req in enumerate(batch.reqs):
-            weight_indices[i] = self.buffer_id[req.lora_path]
+        for i, lora_path in enumerate(input_metadata.lora_paths):
+            weight_indices[i] = self.buffer_id[lora_path]

         for module_name, module in self.lora_modules:
             layer_id = get_layer_id(module_name)
...
@@ -29,7 +29,7 @@ from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
@@ -511,6 +511,9 @@ class ScheduleBatch:
         self.extend_logprob_start_lens_cpu = [r.extend_logprob_start_len for r in reqs]
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)

+    def get_input_metadata(self):
+        return InputMetadata.from_schedule_batch(self)

     def mix_with_running(self, running_batch: "ScheduleBatch"):
         self.forward_mode = ForwardMode.MIXED
         running_bs = running_batch.batch_size()
...
@@ -575,8 +575,9 @@ class Scheduler:
         if self.is_generation:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:
+                input_metadata = batch.get_input_metadata()
                 logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
-                    batch
+                    input_metadata, batch
                 )
             batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                 next_token_ids
@@ -640,7 +641,8 @@ class Scheduler:
             )
         else:
             assert batch.extend_num_tokens != 0
-            embeddings = self.tp_worker.forward_batch_embedding(batch)
+            input_metadata = batch.get_input_metadata()
+            embeddings = self.tp_worker.forward_batch_embedding(input_metadata)

         # Check finish conditions
         for i, req in enumerate(batch.reqs):
@@ -769,7 +771,10 @@ class Scheduler:
         batch.prepare_for_decode()

         # Forward and sample the next tokens
-        logits_output, next_token_ids = self.tp_worker.forward_batch_generation(batch)
+        input_metadata = batch.get_input_metadata()
+        logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
+            input_metadata, batch
+        )
         batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
             next_token_ids
         )
...
@@ -21,6 +21,7 @@ import logging
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.io_struct import UpdateWeightReqInput
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import broadcast_pyobj, is_multimodal_model, set_random_seed
@@ -105,13 +106,13 @@ class ModelTpWorker:
             self.random_seed,
         )

-    def forward_batch_generation(self, batch):
-        logits_output = self.model_runner.forward(batch)
+    def forward_batch_generation(self, input_metadata: InputMetadata, batch):
+        logits_output = self.model_runner.forward(input_metadata)
         next_token_ids = self.model_runner.sample(logits_output, batch)
         return logits_output, next_token_ids

-    def forward_batch_embedding(self, batch):
-        logits_output = self.model_runner.forward(batch)
+    def forward_batch_embedding(self, input_metadata: InputMetadata):
+        logits_output = self.model_runner.forward(input_metadata)
         embeddings = logits_output.embeddings.tolist()
         return embeddings
...
...@@ -31,7 +31,6 @@ from sglang.srt.layers.logits_processor import ( ...@@ -31,7 +31,6 @@ from sglang.srt.layers.logits_processor import (
LogitsProcessor, LogitsProcessor,
LogitsProcessorOutput, LogitsProcessorOutput,
) )
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.utils import monkey_patch_vllm_all_gather from sglang.srt.utils import monkey_patch_vllm_all_gather
...@@ -143,7 +142,6 @@ class CudaGraphRunner: ...@@ -143,7 +142,6 @@ class CudaGraphRunner:
self.seq_lens = torch.full( self.seq_lens = torch.full(
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32 (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
) )
self.position_ids_offsets = torch.ones((self.max_bs,), dtype=torch.int32)
self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32) self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)
# Capture # Capture
...@@ -189,7 +187,6 @@ class CudaGraphRunner: ...@@ -189,7 +187,6 @@ class CudaGraphRunner:
input_ids = self.input_ids[:bs] input_ids = self.input_ids[:bs]
req_pool_indices = self.req_pool_indices[:bs] req_pool_indices = self.req_pool_indices[:bs]
seq_lens = self.seq_lens[:bs] seq_lens = self.seq_lens[:bs]
position_ids_offsets = self.position_ids_offsets[:bs]
out_cache_loc = self.out_cache_loc[:bs] out_cache_loc = self.out_cache_loc[:bs]
# Attention backend # Attention backend
...@@ -202,6 +199,7 @@ class CudaGraphRunner: ...@@ -202,6 +199,7 @@ class CudaGraphRunner:
input_metadata = InputMetadata( input_metadata = InputMetadata(
forward_mode=ForwardMode.DECODE, forward_mode=ForwardMode.DECODE,
batch_size=bs, batch_size=bs,
input_ids=input_ids,
req_pool_indices=req_pool_indices, req_pool_indices=req_pool_indices,
seq_lens=seq_lens, seq_lens=seq_lens,
req_to_token_pool=self.model_runner.req_to_token_pool, req_to_token_pool=self.model_runner.req_to_token_pool,
...@@ -210,7 +208,7 @@ class CudaGraphRunner: ...@@ -210,7 +208,7 @@ class CudaGraphRunner:
out_cache_loc=out_cache_loc, out_cache_loc=out_cache_loc,
return_logprob=False, return_logprob=False,
top_logprobs_nums=[0] * bs, top_logprobs_nums=[0] * bs,
positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64), positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
) )
return forward(input_ids, input_metadata.positions, input_metadata) return forward(input_ids, input_metadata.positions, input_metadata)
...@@ -235,24 +233,22 @@ class CudaGraphRunner: ...@@ -235,24 +233,22 @@ class CudaGraphRunner:
self.graph_memory_pool = graph.pool() self.graph_memory_pool = graph.pool()
return graph, out return graph, out
def replay(self, batch: ScheduleBatch): def replay(self, input_metadata: InputMetadata):
assert batch.out_cache_loc is not None assert input_metadata.out_cache_loc is not None
raw_bs = len(batch.reqs) raw_bs = input_metadata.batch_size
# Pad # Pad
index = bisect.bisect_left(self.capture_bs, raw_bs) index = bisect.bisect_left(self.capture_bs, raw_bs)
bs = self.capture_bs[index] bs = self.capture_bs[index]
if bs != raw_bs: if bs != raw_bs:
self.seq_lens.fill_(self.seq_len_fill_value) self.seq_lens.fill_(self.seq_len_fill_value)
self.position_ids_offsets.fill_(1)
self.out_cache_loc.zero_() self.out_cache_loc.zero_()
# Common inputs # Common inputs
self.input_ids[:raw_bs] = batch.input_ids self.input_ids[:raw_bs] = input_metadata.input_ids
self.req_pool_indices[:raw_bs] = batch.req_pool_indices self.req_pool_indices[:raw_bs] = input_metadata.req_pool_indices
self.seq_lens[:raw_bs] = batch.seq_lens self.seq_lens[:raw_bs] = input_metadata.seq_lens
self.position_ids_offsets[:raw_bs] = batch.position_ids_offsets self.out_cache_loc[:raw_bs] = input_metadata.out_cache_loc
self.out_cache_loc[:raw_bs] = batch.out_cache_loc
# Attention backend # Attention backend
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph( self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
...@@ -275,15 +271,15 @@ class CudaGraphRunner: ...@@ -275,15 +271,15 @@ class CudaGraphRunner:
) )
# Extract logprobs # Extract logprobs
if batch.return_logprob: if input_metadata.return_logprob:
logits_output.next_token_logprobs = torch.nn.functional.log_softmax( logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
logits_output.next_token_logits, dim=-1 logits_output.next_token_logits, dim=-1
) )
return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums) return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
if return_top_logprob: if return_top_logprob:
logits_metadata = LogitsMetadata( logits_metadata = LogitsMetadata(
forward_mode=ForwardMode.DECODE, forward_mode=ForwardMode.DECODE,
top_logprobs_nums=batch.top_logprobs_nums, top_logprobs_nums=input_metadata.top_logprobs_nums,
) )
logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs( logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
logits_output.next_token_logprobs, logits_metadata logits_output.next_token_logprobs, logits_metadata
......
@@ -18,7 +18,7 @@ limitations under the License.
 """Meta data for a forward pass."""
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Set

 import numpy as np
 import torch
@@ -27,7 +27,6 @@ if TYPE_CHECKING:
     from sglang.srt.layers.attention_backend import AttentionBackend
     from sglang.srt.managers.schedule_batch import ImageInputs, ScheduleBatch
     from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
-    from sglang.srt.model_executor.model_runner import ModelRunner


 class ForwardMode(IntEnum):
@@ -37,7 +36,7 @@ class ForwardMode(IntEnum):
     EXTEND = auto()
     # Decode one token.
     DECODE = auto()
-    # Contains both PREFILL and EXTEND.
+    # Contains both EXTEND and DECODE.
     MIXED = auto()

     def is_prefill(self):
@@ -57,15 +56,17 @@ class ForwardMode(IntEnum):
 class InputMetadata:
     """Store all inforamtion of a forward pass."""

+    # The forward mode
     forward_mode: ForwardMode
+    # The batch size
     batch_size: int
+    # The input ids
+    input_ids: torch.Tensor
+    # The indices of requests in the req_to_token_pool
     req_pool_indices: torch.Tensor
+    # The sequence length
     seq_lens: torch.Tensor
-    req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: BaseTokenToKVPool
-    attn_backend: AttentionBackend
-
-    # Output location of the KV cache
+    # The indices of output tokens in the token_to_kv_pool
     out_cache_loc: torch.Tensor

     # Position information
@@ -75,7 +76,6 @@ class InputMetadata:
     extend_seq_lens: torch.Tensor = None
     extend_prefix_lens: torch.Tensor = None
     extend_start_loc: torch.Tensor = None
-    extend_no_prefix: bool = None

     # For logprob
     return_logprob: bool = False
@@ -86,82 +86,51 @@ class InputMetadata:
     # For multimodal
     image_inputs: List[ImageInputs] = None

-    def init_multimuldal_info(self, batch: ScheduleBatch):
-        self.image_inputs = [r.image_inputs for r in batch.reqs]
-
-    def compute_positions(self, batch: ScheduleBatch):
-        if self.forward_mode.is_decode():
-            if True:
-                self.positions = self.seq_lens - 1
-            else:
-                # Deprecated
-                self.positions = (self.seq_lens - 1) + batch.position_ids_offsets
-        else:
-            if True:
-                self.positions = torch.tensor(
-                    np.concatenate(
-                        [
-                            np.arange(batch.prefix_lens_cpu[i], len(req.fill_ids))
-                            for i, req in enumerate(batch.reqs)
-                        ],
-                        axis=0,
-                    ),
-                    device="cuda",
-                )
-            else:
-                # Deprecated
-                position_ids_offsets_cpu = batch.position_ids_offsets.cpu().numpy()
-                self.positions = torch.tensor(
-                    np.concatenate(
-                        [
-                            np.arange(
-                                batch.prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
-                                len(req.fill_ids) + position_ids_offsets_cpu[i],
-                            )
-                            for i, req in enumerate(batch.reqs)
-                        ],
-                        axis=0,
-                    ),
-                    device="cuda",
-                )
-
-        # Positions should be in long type
-        self.positions = self.positions.to(torch.int64)
-
-    def compute_extend_infos(self, batch: ScheduleBatch):
-        self.extend_seq_lens = torch.tensor(batch.extend_lens_cpu, device="cuda")
-        self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
-        self.extend_start_loc = torch.zeros_like(self.extend_seq_lens)
-        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-        self.extend_no_prefix = all(x == 0 for x in batch.prefix_lens_cpu)
-        self.extend_seq_lens_cpu = batch.extend_lens_cpu
-        self.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens_cpu
+    # For LoRA
+    lora_paths: List[str] = None
+
+    # Attention backend
+    req_to_token_pool: ReqToTokenPool = None
+    token_to_kv_pool: BaseTokenToKVPool = None
+    attn_backend: AttentionBackend = None

     @classmethod
     def from_schedule_batch(
         cls,
-        model_runner: "ModelRunner",
         batch: ScheduleBatch,
     ):
         ret = cls(
             forward_mode=batch.forward_mode,
             batch_size=batch.batch_size(),
+            input_ids=batch.input_ids,
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
-            req_to_token_pool=model_runner.req_to_token_pool,
-            token_to_kv_pool=model_runner.token_to_kv_pool,
-            attn_backend=model_runner.attn_backend,
             out_cache_loc=batch.out_cache_loc,
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
+            lora_paths=[req.lora_path for req in batch.reqs],
         )

-        ret.compute_positions(batch)
-
-        if not batch.forward_mode.is_decode():
-            ret.init_multimuldal_info(batch)
-            ret.compute_extend_infos(batch)
-
-        model_runner.attn_backend.init_forward_metadata(batch, ret)
+        if ret.forward_mode.is_decode():
+            ret.positions = (ret.seq_lens - 1).to(torch.int64)
+        else:
+            ret.positions = torch.tensor(
+                np.concatenate(
+                    [
+                        np.arange(batch.prefix_lens_cpu[i], len(req.fill_ids))
+                        for i, req in enumerate(batch.reqs)
+                    ],
+                    axis=0,
+                ),
+                device="cuda",
+            ).to(torch.int64)
+            ret.image_inputs = [r.image_inputs for r in batch.reqs]
+            ret.extend_seq_lens = torch.tensor(batch.extend_lens_cpu, device="cuda")
+            ret.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
+            ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens)
+            ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0)
+            ret.extend_seq_lens_cpu = batch.extend_lens_cpu
+            ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens_cpu

         return ret
@@ -466,46 +466,47 @@ class ModelRunner:
         logger.info("Capture cuda graph begin. This can take up to several minutes.")
         self.cuda_graph_runner = CudaGraphRunner(self)

-    def forward_decode(self, batch: ScheduleBatch):
-        if self.server_args.lora_paths is not None:
-            self.lora_manager.prepare_lora_batch(batch)
-
-        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
-            return self.cuda_graph_runner.replay(batch)
-
-        input_metadata = InputMetadata.from_schedule_batch(self, batch)
+    def forward_decode(self, input_metadata: InputMetadata):
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(
+            input_metadata.batch_size
+        ):
+            return self.cuda_graph_runner.replay(input_metadata)

         return self.model.forward(
-            batch.input_ids, input_metadata.positions, input_metadata
+            input_metadata.input_ids, input_metadata.positions, input_metadata
         )

-    def forward_extend(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.from_schedule_batch(self, batch)
-        if self.server_args.lora_paths is not None:
-            self.lora_manager.prepare_lora_batch(batch, input_metadata.extend_seq_lens)
-
+    def forward_extend(self, input_metadata: InputMetadata):
         if self.is_generation:
             return self.model.forward(
-                batch.input_ids, input_metadata.positions, input_metadata
+                input_metadata.input_ids, input_metadata.positions, input_metadata
             )
         else:
             # Only embedding models have get_embedding parameter
             return self.model.forward(
-                batch.input_ids,
+                input_metadata.input_ids,
                 input_metadata.positions,
                 input_metadata,
                 get_embedding=True,
             )

-    def forward(self, batch: ScheduleBatch) -> Tuple[LogitsProcessorOutput]:
-        assert batch.forward_mode is not None
+    def forward(self, input_metadata: InputMetadata) -> LogitsProcessorOutput:
+        # Attach attention information
+        input_metadata.req_to_token_pool = self.req_to_token_pool
+        input_metadata.token_to_kv_pool = self.token_to_kv_pool
+        input_metadata.attn_backend = self.attn_backend
+        input_metadata.attn_backend.init_forward_metadata(input_metadata)

+        # Attach lora information
+        if self.server_args.lora_paths is not None:
+            self.lora_manager.prepare_lora_batch(input_metadata)

-        if batch.forward_mode.is_decode():
-            return self.forward_decode(batch)
-        elif batch.forward_mode.is_extend():
-            return self.forward_extend(batch)
+        if input_metadata.forward_mode.is_decode():
+            return self.forward_decode(input_metadata)
+        elif input_metadata.forward_mode.is_extend():
+            return self.forward_extend(input_metadata)
         else:
-            raise ValueError(f"Invaid forward mode: {batch.forward_mode}")
+            raise ValueError(f"Invaid forward mode: {input_metadata.forward_mode}")

     def _apply_logits_bias(
         self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
...
@@ -71,10 +71,10 @@ class ModelOutput:
 class HFRunner:
     def __init__(
         self,
-        model_path,
-        torch_dtype,
-        model_type="generation",
-        output_str_only=False,
+        model_path: str,
+        torch_dtype: torch.dtype,
+        model_type: str = "generation",
+        output_str_only: bool = False,
     ):
         self.model_type = model_type
         self.output_str_only = output_str_only
@@ -244,15 +244,15 @@ class HFRunner:
 class SRTRunner:
     def __init__(
         self,
-        model_path,
-        torch_dtype,
-        model_type,
-        tp_size=1,
-        port=DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
-        lora_paths=None,
-        max_loras_per_batch=4,
-        disable_cuda_graph=False,
-        disable_radix_cache=False,
+        model_path: str,
+        torch_dtype: torch.dtype,
+        model_type: str,
+        tp_size: int = 1,
+        port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
+        lora_paths: List[str] = None,
+        max_loras_per_batch: int = 4,
+        disable_cuda_graph: bool = False,
+        disable_radix_cache: bool = False,
     ):
         self.model_type = model_type
         self.is_generation = model_type == "generation"
...
@@ -15,7 +15,6 @@ limitations under the License.
 import multiprocessing as mp
 import unittest
-import uuid

 import torch
@@ -85,9 +84,9 @@ class TestLoRA(unittest.TestCase):
         with SRTRunner(
             base_path,
-            tp_size=tp_size,
             torch_dtype=torch_dtype,
-            is_generation=True,
+            model_type="generation",
+            tp_size=tp_size,
             lora_paths=all_lora_paths,
             max_loras_per_batch=3,
             disable_cuda_graph=True,
...
@@ -7,6 +7,7 @@ suites = {
     "minimal": [
         "models/test_embedding_models.py",
         "models/test_generation_models.py",
+        # "models/test_lora.py",
         "models/test_reward_models.py",
         "sampling/penaltylib",
         "test_chunked_prefill.py",
...