Unverified Commit 36d5acfc authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Rename InputMetadata -> ForwardBatch (#1543)

parent 3f0fe08d
......@@ -30,6 +30,6 @@ To port a model from vLLM to SGLang, you can compare these two files [SGLang Lla
- Replace vllm's `LogitsProcessor` with SGLang's `LogitsProcessor`.
- Replace other vLLM layers with SGLang layers (e.g., `RMSNorm`, `SiluAndMul`).
- Remove `Sample`.
- Change `forward()` functions, and add `input_metadata`.
- Change `forward()` functions, and add `forward_batch`.
- Add `EntryClass` at the end.
......@@ -225,16 +225,16 @@ def extend(reqs, model_runner):
tree_cache=None,
)
batch.prepare_for_extend(model_runner.model_config.vocab_size)
input_metadata = batch.get_input_metadata()
logits_output = model_runner.forward(input_metadata)
forward_batch = batch.get_forward_batch()
logits_output = model_runner.forward(forward_batch)
next_token_ids = model_runner.sample(logits_output, batch).tolist()
return next_token_ids, logits_output.next_token_logits, batch
def decode(input_token_ids, batch, model_runner):
batch.prepare_for_decode(input_token_ids)
input_metadata = batch.get_input_metadata()
logits_output = model_runner.forward(input_metadata)
forward_batch = batch.get_forward_batch()
logits_output = model_runner.forward(forward_batch)
next_token_ids = model_runner.sample(logits_output, batch).tolist()
return next_token_ids, logits_output.next_token_logits
......
......@@ -16,7 +16,7 @@ import torch.nn as nn
from sglang.global_config import global_config
from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import is_hip
if TYPE_CHECKING:
......@@ -37,7 +37,7 @@ class AttentionBackend(ABC):
"""The base class of attention backends"""
@abstractmethod
def init_forward_metadata(self, input_metadata: InputMetadata):
def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Init the metadata for a forward pass."""
raise NotImplementedError()
......@@ -61,18 +61,18 @@ class AttentionBackend(ABC):
"""Get the fill value for padded seq lens. Typically, it is 0 or 1."""
raise NotImplementedError()
def forward(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
def forward(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
"""Run forward on an attention layer."""
if input_metadata.forward_mode.is_decode():
return self.forward_decode(q, k, v, layer, input_metadata)
if forward_batch.forward_mode.is_decode():
return self.forward_decode(q, k, v, layer, forward_batch)
else:
return self.forward_extend(q, k, v, layer, input_metadata)
return self.forward_extend(q, k, v, layer, forward_batch)
def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
"""Run a forward for decode."""
raise NotImplementedError()
def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
"""Run a forward for extend."""
raise NotImplementedError()
......@@ -131,31 +131,31 @@ class FlashInferAttnBackend(AttentionBackend):
self.forward_metadata = None
self.cuda_graph_metadata = {}
def init_forward_metadata(self, input_metadata: InputMetadata):
if input_metadata.forward_mode.is_decode():
def init_forward_metadata(self, forward_batch: ForwardBatch):
if forward_batch.forward_mode.is_decode():
prefix_lens = None
use_ragged = False
extend_no_prefix = False
total_num_tokens = None
else:
prefix_lens = input_metadata.extend_prefix_lens
prefix_lens = forward_batch.extend_prefix_lens
# Some heuristics to check whether to use ragged forward
use_ragged = False
if (
torch.sum(input_metadata.seq_lens).item() >= 4096
torch.sum(forward_batch.seq_lens).item() >= 4096
and self.model_runner.sliding_window_size is None
):
use_ragged = True
total_num_tokens = torch.sum(input_metadata.seq_lens).item()
extend_no_prefix = not torch.any(input_metadata.extend_prefix_lens).item()
total_num_tokens = torch.sum(forward_batch.seq_lens).item()
extend_no_prefix = not torch.any(forward_batch.extend_prefix_lens).item()
update_flashinfer_indices(
input_metadata.forward_mode,
forward_batch.forward_mode,
self.model_runner,
input_metadata.req_pool_indices,
input_metadata.seq_lens,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
prefix_lens,
use_ragged=use_ragged,
)
......@@ -248,7 +248,7 @@ class FlashInferAttnBackend(AttentionBackend):
def get_cuda_graph_seq_len_fill_value(self):
return 0
def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
if not isinstance(self.prefill_wrapper_paged, list):
prefill_wrapper_paged = self.prefill_wrapper_paged
else:
......@@ -264,12 +264,12 @@ class FlashInferAttnBackend(AttentionBackend):
if not use_ragged:
if k is not None:
assert v is not None
input_metadata.token_to_kv_pool.set_kv_buffer(
layer.layer_id, input_metadata.out_cache_loc, k, v
forward_batch.token_to_kv_pool.set_kv_buffer(
layer.layer_id, forward_batch.out_cache_loc, k, v
)
o = prefill_wrapper_paged.forward(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
causal=True,
sm_scale=layer.scaling,
window_left=layer.sliding_window_size,
......@@ -290,7 +290,7 @@ class FlashInferAttnBackend(AttentionBackend):
else:
o2, s2 = prefill_wrapper_paged.forward_return_lse(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
causal=False,
sm_scale=layer.scaling,
logits_soft_cap=layer.logit_cap,
......@@ -298,13 +298,13 @@ class FlashInferAttnBackend(AttentionBackend):
o, _ = merge_state(o1, s1, o2, s2)
input_metadata.token_to_kv_pool.set_kv_buffer(
layer.layer_id, input_metadata.out_cache_loc, k, v
forward_batch.token_to_kv_pool.set_kv_buffer(
layer.layer_id, forward_batch.out_cache_loc, k, v
)
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
use_ragged, extend_no_prefix, total_num_tokens, decode_wrapper = (
self.forward_metadata
)
......@@ -317,13 +317,13 @@ class FlashInferAttnBackend(AttentionBackend):
if k is not None:
assert v is not None
input_metadata.token_to_kv_pool.set_kv_buffer(
layer.layer_id, input_metadata.out_cache_loc, k, v
forward_batch.token_to_kv_pool.set_kv_buffer(
layer.layer_id, forward_batch.out_cache_loc, k, v
)
o = decode_wrapper.forward(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
sm_scale=layer.scaling,
logits_soft_cap=layer.logit_cap,
)
......@@ -358,26 +358,26 @@ class TritonAttnBackend(AttentionBackend):
self.cuda_graph_max_seq_len = model_runner.model_config.context_len
def init_forward_metadata(self, input_metadata: InputMetadata):
def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Init auxiliary variables for triton attention backend."""
if input_metadata.forward_mode.is_decode():
start_loc = torch.zeros_like(input_metadata.seq_lens, dtype=torch.int32)
start_loc[1:] = torch.cumsum(input_metadata.seq_lens[:-1], dim=0)
if forward_batch.forward_mode.is_decode():
start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)
total_num_tokens = torch.sum(input_metadata.seq_lens).item()
total_num_tokens = torch.sum(forward_batch.seq_lens).item()
attn_logits = torch.empty(
(self.num_head, total_num_tokens),
dtype=self.reduce_dtype,
device="cuda",
)
max_seq_len = torch.max(input_metadata.seq_lens).item()
max_seq_len = torch.max(forward_batch.seq_lens).item()
max_extend_len = None
else:
start_loc = attn_logits = max_seq_len = None
prefix_lens = input_metadata.extend_prefix_lens
max_extend_len = torch.max(input_metadata.seq_lens - prefix_lens).item()
prefix_lens = forward_batch.extend_prefix_lens
max_extend_len = torch.max(forward_batch.seq_lens - prefix_lens).item()
self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len
......@@ -415,15 +415,15 @@ class TritonAttnBackend(AttentionBackend):
def get_cuda_graph_seq_len_fill_value(self):
return 1
def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
# TODO: reuse the buffer across layers
if layer.qk_head_dim != layer.v_head_dim:
o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
else:
o = torch.empty_like(q)
input_metadata.token_to_kv_pool.set_kv_buffer(
layer.layer_id, input_metadata.out_cache_loc, k, v
forward_batch.token_to_kv_pool.set_kv_buffer(
layer.layer_id, forward_batch.out_cache_loc, k, v
)
start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
......@@ -432,20 +432,20 @@ class TritonAttnBackend(AttentionBackend):
k.contiguous(),
v.contiguous(),
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
input_metadata.req_to_token_pool.req_to_token,
input_metadata.req_pool_indices,
input_metadata.seq_lens,
input_metadata.extend_seq_lens,
input_metadata.extend_start_loc,
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
forward_batch.req_to_token_pool.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.extend_seq_lens,
forward_batch.extend_start_loc,
max_extend_len,
layer.scaling,
layer.logit_cap,
)
return o
def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
# During torch.compile, there is a bug in rotary_emb that causes the
# output value to have a 3D tensor shape. This reshapes the output correctly.
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
......@@ -458,19 +458,19 @@ class TritonAttnBackend(AttentionBackend):
start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
input_metadata.token_to_kv_pool.set_kv_buffer(
layer.layer_id, input_metadata.out_cache_loc, k, v
forward_batch.token_to_kv_pool.set_kv_buffer(
layer.layer_id, forward_batch.out_cache_loc, k, v
)
self.decode_attention_fwd(
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
input_metadata.req_to_token_pool.req_to_token,
input_metadata.req_pool_indices,
forward_batch.req_to_token_pool.req_to_token,
forward_batch.req_pool_indices,
start_loc,
input_metadata.seq_lens,
forward_batch.seq_lens,
attn_logits,
max_seq_len,
layer.scaling,
......
......@@ -25,7 +25,7 @@ from vllm.distributed import (
tensor_model_parallel_all_gather,
)
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
@dataclasses.dataclass
......@@ -61,26 +61,26 @@ class LogitsMetadata:
extend_logprob_pruned_lens_cpu: Optional[List[int]] = None
@classmethod
def from_input_metadata(cls, input_metadata: InputMetadata):
return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
if input_metadata.forward_mode.is_extend():
def from_forward_batch(cls, forward_batch: ForwardBatch):
return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
if forward_batch.forward_mode.is_extend():
extend_logprob_pruned_lens_cpu = [
extend_len - start_len
for extend_len, start_len in zip(
input_metadata.extend_seq_lens,
input_metadata.extend_logprob_start_lens_cpu,
forward_batch.extend_seq_lens,
forward_batch.extend_logprob_start_lens_cpu,
)
]
else:
extend_logprob_pruned_lens_cpu = None
return cls(
forward_mode=input_metadata.forward_mode,
top_logprobs_nums=input_metadata.top_logprobs_nums,
return_logprob=input_metadata.return_logprob,
forward_mode=forward_batch.forward_mode,
top_logprobs_nums=forward_batch.top_logprobs_nums,
return_logprob=forward_batch.return_logprob,
return_top_logprob=return_top_logprob,
extend_seq_lens=input_metadata.extend_seq_lens,
extend_seq_lens_cpu=input_metadata.extend_seq_lens_cpu,
extend_logprob_start_lens_cpu=input_metadata.extend_logprob_start_lens_cpu,
extend_seq_lens=forward_batch.extend_seq_lens,
extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu,
extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
)
......@@ -162,10 +162,10 @@ class LogitsProcessor(nn.Module):
input_ids,
hidden_states,
weight,
logits_metadata: Union[LogitsMetadata, InputMetadata],
logits_metadata: Union[LogitsMetadata, ForwardBatch],
):
if isinstance(logits_metadata, InputMetadata):
logits_metadata = LogitsMetadata.from_input_metadata(logits_metadata)
if isinstance(logits_metadata, ForwardBatch):
logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
assert isinstance(logits_metadata, LogitsMetadata)
# Get the last hidden states and last logits for the next token prediction
......
......@@ -7,7 +7,7 @@ from enum import IntEnum
import torch
import torch.nn as nn
from sglang.srt.model_executor.model_runner import InputMetadata
from sglang.srt.model_executor.model_runner import ForwardBatch
class PoolingType(IntEnum):
......@@ -36,10 +36,10 @@ class Pooler(nn.Module):
self.normalize = normalize
def forward(
self, hidden_states: torch.Tensor, input_metadata: InputMetadata
self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
) -> EmbeddingPoolerOutput:
if self.pooling_type == PoolingType.LAST:
last_token_indices = torch.cumsum(input_metadata.extend_seq_lens, dim=0) - 1
last_token_indices = torch.cumsum(forward_batch.extend_seq_lens, dim=0) - 1
pooled_data = hidden_states[last_token_indices]
else:
raise ValueError(f"Invalid pooling type: {self.pooling_type}")
......
......@@ -17,7 +17,7 @@ limitations under the License.
from torch import nn
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class RadixAttention(nn.Module):
......@@ -48,11 +48,11 @@ class RadixAttention(nn.Module):
self.logit_cap = logit_cap
self.sliding_window_size = sliding_window_size or -1
def forward(self, q, k, v, input_metadata: InputMetadata):
def forward(self, q, k, v, forward_batch: ForwardBatch):
if k is not None:
# For cross-layer sharing, kv can be None
assert v is not None
k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
return input_metadata.attn_backend.forward(q, k, v, self, input_metadata)
return forward_batch.attn_backend.forward(q, k, v, self, forward_batch)
......@@ -40,7 +40,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
)
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
class BaseLayerWithLoRA(nn.Module):
......
......@@ -23,7 +23,7 @@ import torch
from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import is_hip, replace_submodule
# ROCm: flashinfer available later
......@@ -207,9 +207,9 @@ class LoRAManager:
if lora_weight_name:
self.B_buffer[lora_weight_name][i][buffer_id].copy_(weights)
def prepare_lora_batch(self, input_metadata: InputMetadata):
def prepare_lora_batch(self, forward_batch: ForwardBatch):
# load active loras into lora memory pool
cur_uids = set(input_metadata.lora_paths)
cur_uids = set(forward_batch.lora_paths)
assert len(cur_uids) <= self.max_loras_per_batch
i = 0
evictable_uids = list(self.active_uids)
......@@ -229,14 +229,14 @@ class LoRAManager:
return
# setup lora in forward modules
bs = input_metadata.batch_size
bs = forward_batch.batch_size
seg_lens = (
input_metadata.extend_seq_lens
if input_metadata.forward_mode.is_extend()
forward_batch.extend_seq_lens
if forward_batch.forward_mode.is_extend()
else torch.ones(bs)
)
weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
for i, lora_path in enumerate(input_metadata.lora_paths):
for i, lora_path in enumerate(forward_batch.lora_paths):
weight_indices[i] = self.buffer_id[lora_path]
for module_name, module in self.lora_modules:
......
......@@ -29,7 +29,7 @@ from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
......@@ -511,8 +511,8 @@ class ScheduleBatch:
self.extend_logprob_start_lens_cpu = [r.extend_logprob_start_len for r in reqs]
self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)
def get_input_metadata(self):
return InputMetadata.from_schedule_batch(self)
def get_forward_batch(self):
return ForwardBatch.from_schedule_batch(self)
def mix_with_running(self, running_batch: "ScheduleBatch"):
self.forward_mode = ForwardMode.MIXED
......
......@@ -32,7 +32,7 @@ from sglang.srt.mem_cache.radix_cache import TreeNode
CLIP_MAX_NEW_TOKENS = int(os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS", "4096"))
class SchedulerPolicy:
class SchedulePolicy:
def __init__(self, policy: str, tree_cache: BasePrefixCache):
if tree_cache.disable and policy in ["lpm", "dfs-weight"]:
# LPM and DFS-weight is meaningless when the tree cache is disabled.
......
......@@ -50,8 +50,8 @@ from sglang.srt.managers.schedule_batch import (
Req,
ScheduleBatch,
)
from sglang.srt.managers.scheduler_policy import PrefillAdder, SchedulerPolicy
from sglang.srt.managers.tp_worker import ModelTpWorker
from sglang.srt.managers.schedule_policy import PrefillAdder, SchedulePolicy
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.server_args import PortArgs, ServerArgs
......@@ -134,7 +134,7 @@ class Scheduler:
)
# Launch a tensor parallel worker
self.tp_worker = ModelTpWorker(
self.tp_worker = TpModelWorker(
gpu_id=gpu_id,
tp_rank=tp_rank,
server_args=server_args,
......@@ -179,7 +179,7 @@ class Scheduler:
disable=server_args.disable_radix_cache,
)
self.tree_cache_metrics = {"total": 0, "hit": 0}
self.policy = SchedulerPolicy(self.schedule_policy, self.tree_cache)
self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
# Init running status
self.waiting_queue: List[Req] = []
......@@ -575,9 +575,9 @@ class Scheduler:
if self.is_generation:
# Forward and sample the next tokens
if batch.extend_num_tokens != 0:
input_metadata = batch.get_input_metadata()
forward_batch = batch.get_forward_batch()
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
input_metadata, batch
forward_batch, batch
)
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
next_token_ids
......@@ -641,8 +641,8 @@ class Scheduler:
)
else:
assert batch.extend_num_tokens != 0
input_metadata = batch.get_input_metadata()
embeddings = self.tp_worker.forward_batch_embedding(input_metadata)
forward_batch = batch.get_forward_batch()
embeddings = self.tp_worker.forward_batch_embedding(forward_batch)
# Check finish conditions
for i, req in enumerate(batch.reqs):
......@@ -771,9 +771,9 @@ class Scheduler:
batch.prepare_for_decode()
# Forward and sample the next tokens
input_metadata = batch.get_input_metadata()
forward_batch = batch.get_forward_batch()
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
input_metadata, batch
forward_batch, batch
)
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
next_token_ids
......
......@@ -21,7 +21,7 @@ import logging
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import UpdateWeightReqInput
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import broadcast_pyobj, is_multimodal_model, set_random_seed
......@@ -29,7 +29,9 @@ from sglang.srt.utils import broadcast_pyobj, is_multimodal_model, set_random_se
logger = logging.getLogger(__name__)
class ModelTpWorker:
class TpModelWorker:
"""A tensor parallel model worker."""
def __init__(
self,
gpu_id: int,
......@@ -106,13 +108,13 @@ class ModelTpWorker:
self.random_seed,
)
def forward_batch_generation(self, input_metadata: InputMetadata, batch):
logits_output = self.model_runner.forward(input_metadata)
def forward_batch_generation(self, forward_batch: ForwardBatch, batch):
logits_output = self.model_runner.forward(forward_batch)
next_token_ids = self.model_runner.sample(logits_output, batch)
return logits_output, next_token_ids
def forward_batch_embedding(self, input_metadata: InputMetadata):
logits_output = self.model_runner.forward(input_metadata)
def forward_batch_embedding(self, forward_batch: ForwardBatch):
logits_output = self.model_runner.forward(forward_batch)
embeddings = logits_output.embeddings.tolist()
return embeddings
......
......@@ -31,7 +31,7 @@ from sglang.srt.layers.logits_processor import (
LogitsProcessor,
LogitsProcessorOutput,
)
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import monkey_patch_vllm_all_gather
if TYPE_CHECKING:
......@@ -196,7 +196,7 @@ class CudaGraphRunner:
# Run and capture
def run_once():
input_metadata = InputMetadata(
forward_batch = ForwardBatch(
forward_mode=ForwardMode.DECODE,
batch_size=bs,
input_ids=input_ids,
......@@ -210,7 +210,7 @@ class CudaGraphRunner:
top_logprobs_nums=[0] * bs,
positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
)
return forward(input_ids, input_metadata.positions, input_metadata)
return forward(input_ids, forward_batch.positions, forward_batch)
for _ in range(2):
torch.cuda.synchronize()
......@@ -233,9 +233,9 @@ class CudaGraphRunner:
self.graph_memory_pool = graph.pool()
return graph, out
def replay(self, input_metadata: InputMetadata):
assert input_metadata.out_cache_loc is not None
raw_bs = input_metadata.batch_size
def replay(self, forward_batch: ForwardBatch):
assert forward_batch.out_cache_loc is not None
raw_bs = forward_batch.batch_size
# Pad
index = bisect.bisect_left(self.capture_bs, raw_bs)
......@@ -245,10 +245,10 @@ class CudaGraphRunner:
self.out_cache_loc.zero_()
# Common inputs
self.input_ids[:raw_bs] = input_metadata.input_ids
self.req_pool_indices[:raw_bs] = input_metadata.req_pool_indices
self.seq_lens[:raw_bs] = input_metadata.seq_lens
self.out_cache_loc[:raw_bs] = input_metadata.out_cache_loc
self.input_ids[:raw_bs] = forward_batch.input_ids
self.req_pool_indices[:raw_bs] = forward_batch.req_pool_indices
self.seq_lens[:raw_bs] = forward_batch.seq_lens
self.out_cache_loc[:raw_bs] = forward_batch.out_cache_loc
# Attention backend
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
......@@ -271,15 +271,15 @@ class CudaGraphRunner:
)
# Extract logprobs
if input_metadata.return_logprob:
if forward_batch.return_logprob:
logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
logits_output.next_token_logits, dim=-1
)
return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
if return_top_logprob:
logits_metadata = LogitsMetadata(
forward_mode=ForwardMode.DECODE,
top_logprobs_nums=input_metadata.top_logprobs_nums,
top_logprobs_nums=forward_batch.top_logprobs_nums,
)
logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
logits_output.next_token_logprobs, logits_metadata
......
......@@ -18,7 +18,7 @@ limitations under the License.
"""Meta data for a forward pass."""
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import TYPE_CHECKING, List, Set
from typing import TYPE_CHECKING, List
import numpy as np
import torch
......@@ -53,8 +53,8 @@ class ForwardMode(IntEnum):
@dataclass
class InputMetadata:
"""Store all inforamtion of a forward pass."""
class ForwardBatch:
"""Store all inputs of a forward pass."""
# The forward mode
forward_mode: ForwardMode
......
......@@ -48,7 +48,7 @@ from sglang.srt.mem_cache.memory_pool import (
MLATokenToKVPool,
ReqToTokenPool,
)
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
......@@ -466,47 +466,47 @@ class ModelRunner:
logger.info("Capture cuda graph begin. This can take up to several minutes.")
self.cuda_graph_runner = CudaGraphRunner(self)
def forward_decode(self, input_metadata: InputMetadata):
def forward_decode(self, forward_batch: ForwardBatch):
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(
input_metadata.batch_size
forward_batch.batch_size
):
return self.cuda_graph_runner.replay(input_metadata)
return self.cuda_graph_runner.replay(forward_batch)
return self.model.forward(
input_metadata.input_ids, input_metadata.positions, input_metadata
forward_batch.input_ids, forward_batch.positions, forward_batch
)
def forward_extend(self, input_metadata: InputMetadata):
def forward_extend(self, forward_batch: ForwardBatch):
if self.is_generation:
return self.model.forward(
input_metadata.input_ids, input_metadata.positions, input_metadata
forward_batch.input_ids, forward_batch.positions, forward_batch
)
else:
# Only embedding models have get_embedding parameter
return self.model.forward(
input_metadata.input_ids,
input_metadata.positions,
input_metadata,
forward_batch.input_ids,
forward_batch.positions,
forward_batch,
get_embedding=True,
)
def forward(self, input_metadata: InputMetadata) -> LogitsProcessorOutput:
def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
# Attach attention information
input_metadata.req_to_token_pool = self.req_to_token_pool
input_metadata.token_to_kv_pool = self.token_to_kv_pool
input_metadata.attn_backend = self.attn_backend
input_metadata.attn_backend.init_forward_metadata(input_metadata)
forward_batch.req_to_token_pool = self.req_to_token_pool
forward_batch.token_to_kv_pool = self.token_to_kv_pool
forward_batch.attn_backend = self.attn_backend
forward_batch.attn_backend.init_forward_metadata(forward_batch)
# Attach lora information
if self.server_args.lora_paths is not None:
self.lora_manager.prepare_lora_batch(input_metadata)
self.lora_manager.prepare_lora_batch(forward_batch)
if input_metadata.forward_mode.is_decode():
return self.forward_decode(input_metadata)
elif input_metadata.forward_mode.is_extend():
return self.forward_extend(input_metadata)
if forward_batch.forward_mode.is_decode():
return self.forward_decode(forward_batch)
elif forward_batch.forward_mode.is_extend():
return self.forward_extend(forward_batch)
else:
raise ValueError(f"Invaid forward mode: {input_metadata.forward_mode}")
raise ValueError(f"Invaid forward mode: {forward_batch.forward_mode}")
def _apply_logits_bias(
self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
......
......@@ -46,7 +46,7 @@ from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
......@@ -189,13 +189,13 @@ class BaiChuanAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.W_pack(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.postion_embedding != "ALIBI":
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output
......@@ -237,7 +237,7 @@ class BaiChuanDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
......@@ -249,7 +249,7 @@ class BaiChuanDecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
# Fully Connected
......@@ -292,7 +292,7 @@ class BaiChuanModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
......@@ -301,7 +301,7 @@ class BaiChuanModel(nn.Module):
hidden_states, residual = layer(
positions,
hidden_states,
input_metadata,
forward_batch,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
......@@ -350,11 +350,11 @@ class BaiChuanBaseForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata)
hidden_states = self.model(input_ids, positions, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
......@@ -42,7 +42,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
LoraConfig = None
......@@ -118,7 +118,7 @@ class GLMAttention(nn.Module):
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
......@@ -127,7 +127,7 @@ class GLMAttention(nn.Module):
q,
k,
v,
input_metadata,
forward_batch,
)
attn_output, _ = self.dense(context_layer)
return attn_output
......@@ -220,7 +220,7 @@ class GLMBlock(nn.Module):
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
# hidden_states: [num_tokens, h]
# Layer norm at the beginning of the transformer layer.
......@@ -229,7 +229,7 @@ class GLMBlock(nn.Module):
attention_output = self.self_attention(
hidden_states=layernorm_output,
position_ids=position_ids,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
# Residual connection.
......@@ -288,14 +288,14 @@ class GLMTransformer(nn.Module):
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
for i in range(self.num_layers):
layer = self.layers[i]
hidden_states = layer(
hidden_states=hidden_states,
position_ids=position_ids,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
# Final layer norm.
if self.post_layer_norm:
......@@ -328,7 +328,7 @@ class ChatGLMModel(nn.Module):
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
inputs_embeds = self.embedding(input_ids)
......@@ -336,7 +336,7 @@ class ChatGLMModel(nn.Module):
hidden_states = self.encoder(
hidden_states=inputs_embeds,
position_ids=position_ids,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
return hidden_states
......@@ -376,11 +376,11 @@ class ChatGLMForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, input_metadata)
hidden_states = self.transformer(input_ids, positions, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
......@@ -63,7 +63,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import set_weight_attrs
......@@ -220,14 +220,14 @@ class CohereAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.use_qk_norm:
q, k = self._apply_qk_norm(q, k)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output
......@@ -255,7 +255,7 @@ class CohereDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
......@@ -264,7 +264,7 @@ class CohereDecoderLayer(nn.Module):
hidden_states_attention = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
hidden_states_mlp = self.mlp(hidden_states)
# Add everything together
......@@ -299,7 +299,7 @@ class CohereModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
......@@ -308,7 +308,7 @@ class CohereModel(nn.Module):
hidden_states, residual = layer(
positions,
hidden_states,
input_metadata,
forward_batch,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
......@@ -333,15 +333,15 @@ class CohereForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.model(
input_ids,
positions,
input_metadata,
forward_batch,
)
return self.logits_processor(
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
......@@ -44,7 +44,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import set_weight_attrs
......@@ -249,14 +249,14 @@ class DbrxAttention(nn.Module):
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.Wqkv(hidden_states)
if self.clip_qkv is not None:
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
hidden_states, _ = self.out_proj(attn_output)
return hidden_states
......@@ -278,14 +278,14 @@ class DbrxFusedNormAttention(nn.Module):
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.norm_1(hidden_states)
x = self.attn(
position_ids=position_ids,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
hidden_states = residual + x
residual = hidden_states
......@@ -310,12 +310,12 @@ class DbrxBlock(nn.Module):
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states, residual = self.norm_attn_norm(
position_ids=position_ids,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
hidden_states = self.ffn(hidden_states)
hidden_states = hidden_states + residual
......@@ -349,7 +349,7 @@ class DbrxModel(nn.Module):
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
......@@ -358,7 +358,7 @@ class DbrxModel(nn.Module):
hidden_states = input_embeds
for i in range(len(self.blocks)):
block = self.blocks[i]
hidden_states = block(position_ids, hidden_states, input_metadata)
hidden_states = block(position_ids, hidden_states, forward_batch)
hidden_states = self.norm_f(hidden_states)
return hidden_states
......@@ -388,11 +388,11 @@ class DbrxForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, input_metadata)
hidden_states = self.transformer(input_ids, positions, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
......@@ -46,7 +46,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class DeepseekMLP(nn.Module):
......@@ -246,12 +246,12 @@ class DeepseekAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output
......@@ -303,7 +303,7 @@ class DeepseekDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
......@@ -315,7 +315,7 @@ class DeepseekDecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
# Fully Connected
......@@ -356,14 +356,14 @@ class DeepseekModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, input_metadata, residual
positions, hidden_states, forward_batch, residual
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
......@@ -391,11 +391,11 @@ class DeepseekForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata)
hidden_states = self.model(input_ids, positions, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment