"docs/vscode:/vscode.git/clone" did not exist on "bdaa130997e28486ac20ebfa21908e9009e8ec97"
Unverified commit 36d5acfc, authored by Lianmin Zheng, committed by GitHub

Rename InputMetadata -> ForwardBatch (#1543)

parent 3f0fe08d
@@ -30,6 +30,6 @@ To port a model from vLLM to SGLang, you can compare these two files [SGLang Lla
 - Replace vllm's `LogitsProcessor` with SGLang's `LogitsProcessor`.
 - Replace other vLLM layers with SGLang layers (e.g., `RMSNorm`, `SiluAndMul`).
 - Remove `Sample`.
-- Change `forward()` functions, and add `input_metadata`.
+- Change `forward()` functions, and add `forward_batch`.
 - Add `EntryClass` at the end.
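
Taken together, the checklist yields a model file shaped roughly like the sketch below. `MyModel` and `MyModelForCausalLM` are hypothetical placeholders, not names from this commit; the real model diffs later in this commit (e.g. BaiChuan, ChatGLM, Cohere) follow the same pattern:

```python
import torch
from torch import nn

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.model_executor.forward_batch_info import ForwardBatch


class MyModel(nn.Module):
    """Hypothetical backbone; each decoder layer receives forward_batch."""

    def __init__(self, config):
        super().__init__()
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([])  # decoder layers would go here

    def forward(self, input_ids, positions, forward_batch: ForwardBatch):
        hidden_states = self.embed_tokens(input_ids)
        for layer in self.layers:
            hidden_states = layer(positions, hidden_states, forward_batch)
        return hidden_states


class MyModelForCausalLM(nn.Module):
    def __init__(self, config, quant_config=None):
        super().__init__()
        self.model = MyModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.logits_processor = LogitsProcessor(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,  # renamed from input_metadata by this commit
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, forward_batch)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, forward_batch
        )


# SGLang discovers the model class via EntryClass at the end of the file.
EntryClass = MyModelForCausalLM
```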
@@ -225,16 +225,16 @@ def extend(reqs, model_runner):
         tree_cache=None,
     )
     batch.prepare_for_extend(model_runner.model_config.vocab_size)
-    input_metadata = batch.get_input_metadata()
-    logits_output = model_runner.forward(input_metadata)
+    forward_batch = batch.get_forward_batch()
+    logits_output = model_runner.forward(forward_batch)
     next_token_ids = model_runner.sample(logits_output, batch).tolist()
     return next_token_ids, logits_output.next_token_logits, batch


 def decode(input_token_ids, batch, model_runner):
     batch.prepare_for_decode(input_token_ids)
-    input_metadata = batch.get_input_metadata()
-    logits_output = model_runner.forward(input_metadata)
+    forward_batch = batch.get_forward_batch()
+    logits_output = model_runner.forward(forward_batch)
     next_token_ids = model_runner.sample(logits_output, batch).tolist()
     return next_token_ids, logits_output.next_token_logits
...
@@ -16,7 +16,7 @@ import torch.nn as nn
 from sglang.global_config import global_config
 from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_hip

 if TYPE_CHECKING:
@@ -37,7 +37,7 @@ class AttentionBackend(ABC):
     """The base class of attention backends"""

     @abstractmethod
-    def init_forward_metadata(self, input_metadata: InputMetadata):
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init the metadata for a forward pass."""
         raise NotImplementedError()
@@ -61,18 +61,18 @@ class AttentionBackend(ABC):
         """Get the fill value for padded seq lens. Typically, it is 0 or 1."""
         raise NotImplementedError()

-    def forward(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+    def forward(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
         """Run forward on an attention layer."""
-        if input_metadata.forward_mode.is_decode():
-            return self.forward_decode(q, k, v, layer, input_metadata)
+        if forward_batch.forward_mode.is_decode():
+            return self.forward_decode(q, k, v, layer, forward_batch)
         else:
-            return self.forward_extend(q, k, v, layer, input_metadata)
+            return self.forward_extend(q, k, v, layer, forward_batch)

-    def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
         """Run a forward for decode."""
         raise NotImplementedError()

-    def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+    def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
         """Run a forward for extend."""
         raise NotImplementedError()
@@ -131,31 +131,31 @@ class FlashInferAttnBackend(AttentionBackend):
         self.forward_metadata = None
         self.cuda_graph_metadata = {}

-    def init_forward_metadata(self, input_metadata: InputMetadata):
-        if input_metadata.forward_mode.is_decode():
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        if forward_batch.forward_mode.is_decode():
             prefix_lens = None
             use_ragged = False
             extend_no_prefix = False
             total_num_tokens = None
         else:
-            prefix_lens = input_metadata.extend_prefix_lens
+            prefix_lens = forward_batch.extend_prefix_lens

             # Some heuristics to check whether to use ragged forward
             use_ragged = False
             if (
-                torch.sum(input_metadata.seq_lens).item() >= 4096
+                torch.sum(forward_batch.seq_lens).item() >= 4096
                 and self.model_runner.sliding_window_size is None
             ):
                 use_ragged = True

-            total_num_tokens = torch.sum(input_metadata.seq_lens).item()
-            extend_no_prefix = not torch.any(input_metadata.extend_prefix_lens).item()
+            total_num_tokens = torch.sum(forward_batch.seq_lens).item()
+            extend_no_prefix = not torch.any(forward_batch.extend_prefix_lens).item()

         update_flashinfer_indices(
-            input_metadata.forward_mode,
+            forward_batch.forward_mode,
             self.model_runner,
-            input_metadata.req_pool_indices,
-            input_metadata.seq_lens,
+            forward_batch.req_pool_indices,
+            forward_batch.seq_lens,
             prefix_lens,
             use_ragged=use_ragged,
         )
@@ -248,7 +248,7 @@ class FlashInferAttnBackend(AttentionBackend):
     def get_cuda_graph_seq_len_fill_value(self):
         return 0

-    def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+    def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
         if not isinstance(self.prefill_wrapper_paged, list):
             prefill_wrapper_paged = self.prefill_wrapper_paged
         else:
@@ -264,12 +264,12 @@ class FlashInferAttnBackend(AttentionBackend):
         if not use_ragged:
             if k is not None:
                 assert v is not None
-                input_metadata.token_to_kv_pool.set_kv_buffer(
-                    layer.layer_id, input_metadata.out_cache_loc, k, v
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer.layer_id, forward_batch.out_cache_loc, k, v
                 )

             o = prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                 causal=True,
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
@@ -290,7 +290,7 @@ class FlashInferAttnBackend(AttentionBackend):
             else:
                 o2, s2 = prefill_wrapper_paged.forward_return_lse(
                     q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                    input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                    forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                     causal=False,
                     sm_scale=layer.scaling,
                     logits_soft_cap=layer.logit_cap,
@@ -298,13 +298,13 @@ class FlashInferAttnBackend(AttentionBackend):
                 o, _ = merge_state(o1, s1, o2, s2)

-            input_metadata.token_to_kv_pool.set_kv_buffer(
-                layer.layer_id, input_metadata.out_cache_loc, k, v
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer.layer_id, forward_batch.out_cache_loc, k, v
             )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)

-    def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
         use_ragged, extend_no_prefix, total_num_tokens, decode_wrapper = (
             self.forward_metadata
         )
@@ -317,13 +317,13 @@ class FlashInferAttnBackend(AttentionBackend):
         if k is not None:
             assert v is not None
-            input_metadata.token_to_kv_pool.set_kv_buffer(
-                layer.layer_id, input_metadata.out_cache_loc, k, v
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer.layer_id, forward_batch.out_cache_loc, k, v
             )

         o = decode_wrapper.forward(
             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-            input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+            forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
             sm_scale=layer.scaling,
             logits_soft_cap=layer.logit_cap,
         )
@@ -358,26 +358,26 @@ class TritonAttnBackend(AttentionBackend):
         self.cuda_graph_max_seq_len = model_runner.model_config.context_len

-    def init_forward_metadata(self, input_metadata: InputMetadata):
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""

-        if input_metadata.forward_mode.is_decode():
-            start_loc = torch.zeros_like(input_metadata.seq_lens, dtype=torch.int32)
-            start_loc[1:] = torch.cumsum(input_metadata.seq_lens[:-1], dim=0)
+        if forward_batch.forward_mode.is_decode():
+            start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
+            start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)

-            total_num_tokens = torch.sum(input_metadata.seq_lens).item()
+            total_num_tokens = torch.sum(forward_batch.seq_lens).item()
             attn_logits = torch.empty(
                 (self.num_head, total_num_tokens),
                 dtype=self.reduce_dtype,
                 device="cuda",
             )

-            max_seq_len = torch.max(input_metadata.seq_lens).item()
+            max_seq_len = torch.max(forward_batch.seq_lens).item()
             max_extend_len = None
         else:
             start_loc = attn_logits = max_seq_len = None
-            prefix_lens = input_metadata.extend_prefix_lens
-            max_extend_len = torch.max(input_metadata.seq_lens - prefix_lens).item()
+            prefix_lens = forward_batch.extend_prefix_lens
+            max_extend_len = torch.max(forward_batch.seq_lens - prefix_lens).item()

         self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len
@@ -415,15 +415,15 @@ class TritonAttnBackend(AttentionBackend):
     def get_cuda_graph_seq_len_fill_value(self):
         return 1

-    def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+    def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
         # TODO: reuse the buffer across layers
         if layer.qk_head_dim != layer.v_head_dim:
             o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
         else:
             o = torch.empty_like(q)

-        input_metadata.token_to_kv_pool.set_kv_buffer(
-            layer.layer_id, input_metadata.out_cache_loc, k, v
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            layer.layer_id, forward_batch.out_cache_loc, k, v
         )

         start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
@@ -432,20 +432,20 @@ class TritonAttnBackend(AttentionBackend):
             k.contiguous(),
             v.contiguous(),
             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
-            input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
-            input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
-            input_metadata.req_to_token_pool.req_to_token,
-            input_metadata.req_pool_indices,
-            input_metadata.seq_lens,
-            input_metadata.extend_seq_lens,
-            input_metadata.extend_start_loc,
+            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.req_pool_indices,
+            forward_batch.seq_lens,
+            forward_batch.extend_seq_lens,
+            forward_batch.extend_start_loc,
             max_extend_len,
             layer.scaling,
             layer.logit_cap,
         )
         return o

-    def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
+    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
         # During torch.compile, there is a bug in rotary_emb that causes the
         # output value to have a 3D tensor shape. This reshapes the output correctly.
         q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
@@ -458,19 +458,19 @@ class TritonAttnBackend(AttentionBackend):
         start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata

-        input_metadata.token_to_kv_pool.set_kv_buffer(
-            layer.layer_id, input_metadata.out_cache_loc, k, v
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            layer.layer_id, forward_batch.out_cache_loc, k, v
         )

         self.decode_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
-            input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id),
-            input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id),
+            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
-            input_metadata.req_to_token_pool.req_to_token,
-            input_metadata.req_pool_indices,
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.req_pool_indices,
             start_loc,
-            input_metadata.seq_lens,
+            forward_batch.seq_lens,
             attn_logits,
             max_seq_len,
             layer.scaling,
...
@@ -25,7 +25,7 @@ from vllm.distributed import (
     tensor_model_parallel_all_gather,
 )

-from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode


 @dataclasses.dataclass
@@ -61,26 +61,26 @@ class LogitsMetadata:
     extend_logprob_pruned_lens_cpu: Optional[List[int]] = None

     @classmethod
-    def from_input_metadata(cls, input_metadata: InputMetadata):
-        return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
-        if input_metadata.forward_mode.is_extend():
+    def from_forward_batch(cls, forward_batch: ForwardBatch):
+        return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
+        if forward_batch.forward_mode.is_extend():
             extend_logprob_pruned_lens_cpu = [
                 extend_len - start_len
                 for extend_len, start_len in zip(
-                    input_metadata.extend_seq_lens,
-                    input_metadata.extend_logprob_start_lens_cpu,
+                    forward_batch.extend_seq_lens,
+                    forward_batch.extend_logprob_start_lens_cpu,
                 )
             ]
         else:
             extend_logprob_pruned_lens_cpu = None
         return cls(
-            forward_mode=input_metadata.forward_mode,
-            top_logprobs_nums=input_metadata.top_logprobs_nums,
-            return_logprob=input_metadata.return_logprob,
+            forward_mode=forward_batch.forward_mode,
+            top_logprobs_nums=forward_batch.top_logprobs_nums,
+            return_logprob=forward_batch.return_logprob,
             return_top_logprob=return_top_logprob,
-            extend_seq_lens=input_metadata.extend_seq_lens,
-            extend_seq_lens_cpu=input_metadata.extend_seq_lens_cpu,
-            extend_logprob_start_lens_cpu=input_metadata.extend_logprob_start_lens_cpu,
+            extend_seq_lens=forward_batch.extend_seq_lens,
+            extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
+            extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu,
             extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
         )
@@ -162,10 +162,10 @@ class LogitsProcessor(nn.Module):
         input_ids,
         hidden_states,
         weight,
-        logits_metadata: Union[LogitsMetadata, InputMetadata],
+        logits_metadata: Union[LogitsMetadata, ForwardBatch],
     ):
-        if isinstance(logits_metadata, InputMetadata):
-            logits_metadata = LogitsMetadata.from_input_metadata(logits_metadata)
+        if isinstance(logits_metadata, ForwardBatch):
+            logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
         assert isinstance(logits_metadata, LogitsMetadata)

         # Get the last hidden states and last logits for the next token prediction
...
@@ -7,7 +7,7 @@ from enum import IntEnum
 import torch
 import torch.nn as nn

-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.model_runner import ForwardBatch


 class PoolingType(IntEnum):
@@ -36,10 +36,10 @@ class Pooler(nn.Module):
         self.normalize = normalize

     def forward(
-        self, hidden_states: torch.Tensor, input_metadata: InputMetadata
+        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
     ) -> EmbeddingPoolerOutput:
         if self.pooling_type == PoolingType.LAST:
-            last_token_indices = torch.cumsum(input_metadata.extend_seq_lens, dim=0) - 1
+            last_token_indices = torch.cumsum(forward_batch.extend_seq_lens, dim=0) - 1
             pooled_data = hidden_states[last_token_indices]
         else:
             raise ValueError(f"Invalid pooling type: {self.pooling_type}")
...
@@ -17,7 +17,7 @@ limitations under the License.
 from torch import nn

-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch


 class RadixAttention(nn.Module):
@@ -48,11 +48,11 @@ class RadixAttention(nn.Module):
         self.logit_cap = logit_cap
         self.sliding_window_size = sliding_window_size or -1

-    def forward(self, q, k, v, input_metadata: InputMetadata):
+    def forward(self, q, k, v, forward_batch: ForwardBatch):
         if k is not None:
             # For cross-layer sharing, kv can be None
             assert v is not None
             k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
             v = v.view(-1, self.tp_v_head_num, self.v_head_dim)

-        return input_metadata.attn_backend.forward(q, k, v, self, input_metadata)
+        return forward_batch.attn_backend.forward(q, k, v, self, forward_batch)
@@ -40,7 +40,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.loader import DefaultModelLoader

-from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode


 class BaseLayerWithLoRA(nn.Module):
...
@@ -23,7 +23,7 @@ import torch
 from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
 from sglang.srt.lora.lora_config import LoRAConfig
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import is_hip, replace_submodule

 # ROCm: flashinfer available later
@@ -207,9 +207,9 @@ class LoRAManager:
                 if lora_weight_name:
                     self.B_buffer[lora_weight_name][i][buffer_id].copy_(weights)

-    def prepare_lora_batch(self, input_metadata: InputMetadata):
+    def prepare_lora_batch(self, forward_batch: ForwardBatch):
         # load active loras into lora memory pool
-        cur_uids = set(input_metadata.lora_paths)
+        cur_uids = set(forward_batch.lora_paths)
         assert len(cur_uids) <= self.max_loras_per_batch
         i = 0
         evictable_uids = list(self.active_uids)
@@ -229,14 +229,14 @@ class LoRAManager:
             return

         # setup lora in forward modules
-        bs = input_metadata.batch_size
+        bs = forward_batch.batch_size
         seg_lens = (
-            input_metadata.extend_seq_lens
-            if input_metadata.forward_mode.is_extend()
+            forward_batch.extend_seq_lens
+            if forward_batch.forward_mode.is_extend()
             else torch.ones(bs)
         )
         weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
-        for i, lora_path in enumerate(input_metadata.lora_paths):
+        for i, lora_path in enumerate(forward_batch.lora_paths):
             weight_indices[i] = self.buffer_id[lora_path]

         for module_name, module in self.lora_modules:
...
@@ -29,7 +29,7 @@ from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
@@ -511,8 +511,8 @@ class ScheduleBatch:
         self.extend_logprob_start_lens_cpu = [r.extend_logprob_start_len for r in reqs]
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)

-    def get_input_metadata(self):
-        return InputMetadata.from_schedule_batch(self)
+    def get_forward_batch(self):
+        return ForwardBatch.from_schedule_batch(self)

     def mix_with_running(self, running_batch: "ScheduleBatch"):
         self.forward_mode = ForwardMode.MIXED
...
@@ -32,7 +32,7 @@ from sglang.srt.mem_cache.radix_cache import TreeNode
 CLIP_MAX_NEW_TOKENS = int(os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS", "4096"))


-class SchedulerPolicy:
+class SchedulePolicy:
     def __init__(self, policy: str, tree_cache: BasePrefixCache):
         if tree_cache.disable and policy in ["lpm", "dfs-weight"]:
             # LPM and DFS-weight is meaningless when the tree cache is disabled.
...
@@ -50,8 +50,8 @@ from sglang.srt.managers.schedule_batch import (
     Req,
     ScheduleBatch,
 )
-from sglang.srt.managers.scheduler_policy import PrefillAdder, SchedulerPolicy
-from sglang.srt.managers.tp_worker import ModelTpWorker
+from sglang.srt.managers.schedule_policy import PrefillAdder, SchedulePolicy
+from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -134,7 +134,7 @@ class Scheduler:
         )

         # Launch a tensor parallel worker
-        self.tp_worker = ModelTpWorker(
+        self.tp_worker = TpModelWorker(
             gpu_id=gpu_id,
             tp_rank=tp_rank,
             server_args=server_args,
@@ -179,7 +179,7 @@ class Scheduler:
             disable=server_args.disable_radix_cache,
         )
         self.tree_cache_metrics = {"total": 0, "hit": 0}
-        self.policy = SchedulerPolicy(self.schedule_policy, self.tree_cache)
+        self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)

         # Init running status
         self.waiting_queue: List[Req] = []
@@ -575,9 +575,9 @@ class Scheduler:
         if self.is_generation:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:
-                input_metadata = batch.get_input_metadata()
+                forward_batch = batch.get_forward_batch()
                 logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
-                    input_metadata, batch
+                    forward_batch, batch
                 )
                 batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                     next_token_ids
@@ -641,8 +641,8 @@
             )
         else:
             assert batch.extend_num_tokens != 0
-            input_metadata = batch.get_input_metadata()
-            embeddings = self.tp_worker.forward_batch_embedding(input_metadata)
+            forward_batch = batch.get_forward_batch()
+            embeddings = self.tp_worker.forward_batch_embedding(forward_batch)

         # Check finish conditions
         for i, req in enumerate(batch.reqs):
@@ -771,9 +771,9 @@
         batch.prepare_for_decode()

         # Forward and sample the next tokens
-        input_metadata = batch.get_input_metadata()
+        forward_batch = batch.get_forward_batch()
         logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
-            input_metadata, batch
+            forward_batch, batch
         )
         batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
             next_token_ids
...
@@ -21,7 +21,7 @@ import logging
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.io_struct import UpdateWeightReqInput
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import broadcast_pyobj, is_multimodal_model, set_random_seed
@@ -29,7 +29,9 @@ from sglang.srt.utils import broadcast_pyobj, is_multimodal_model, set_random_se
 logger = logging.getLogger(__name__)


-class ModelTpWorker:
+class TpModelWorker:
+    """A tensor parallel model worker."""
+
     def __init__(
         self,
         gpu_id: int,
@@ -106,13 +108,13 @@ class ModelTpWorker:
             self.random_seed,
         )

-    def forward_batch_generation(self, input_metadata: InputMetadata, batch):
-        logits_output = self.model_runner.forward(input_metadata)
+    def forward_batch_generation(self, forward_batch: ForwardBatch, batch):
+        logits_output = self.model_runner.forward(forward_batch)
         next_token_ids = self.model_runner.sample(logits_output, batch)
         return logits_output, next_token_ids

-    def forward_batch_embedding(self, input_metadata: InputMetadata):
-        logits_output = self.model_runner.forward(input_metadata)
+    def forward_batch_embedding(self, forward_batch: ForwardBatch):
+        logits_output = self.model_runner.forward(forward_batch)
         embeddings = logits_output.embeddings.tolist()
         return embeddings
...
@@ -31,7 +31,7 @@ from sglang.srt.layers.logits_processor import (
     LogitsProcessor,
     LogitsProcessorOutput,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import monkey_patch_vllm_all_gather

 if TYPE_CHECKING:
@@ -196,7 +196,7 @@ class CudaGraphRunner:
         # Run and capture
         def run_once():
-            input_metadata = InputMetadata(
+            forward_batch = ForwardBatch(
                 forward_mode=ForwardMode.DECODE,
                 batch_size=bs,
                 input_ids=input_ids,
@@ -210,7 +210,7 @@
                 top_logprobs_nums=[0] * bs,
                 positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
             )
-            return forward(input_ids, input_metadata.positions, input_metadata)
+            return forward(input_ids, forward_batch.positions, forward_batch)

         for _ in range(2):
             torch.cuda.synchronize()
@@ -233,9 +233,9 @@
         self.graph_memory_pool = graph.pool()
         return graph, out

-    def replay(self, input_metadata: InputMetadata):
-        assert input_metadata.out_cache_loc is not None
-        raw_bs = input_metadata.batch_size
+    def replay(self, forward_batch: ForwardBatch):
+        assert forward_batch.out_cache_loc is not None
+        raw_bs = forward_batch.batch_size

         # Pad
         index = bisect.bisect_left(self.capture_bs, raw_bs)
@@ -245,10 +245,10 @@
             self.out_cache_loc.zero_()

         # Common inputs
-        self.input_ids[:raw_bs] = input_metadata.input_ids
-        self.req_pool_indices[:raw_bs] = input_metadata.req_pool_indices
-        self.seq_lens[:raw_bs] = input_metadata.seq_lens
-        self.out_cache_loc[:raw_bs] = input_metadata.out_cache_loc
+        self.input_ids[:raw_bs] = forward_batch.input_ids
+        self.req_pool_indices[:raw_bs] = forward_batch.req_pool_indices
+        self.seq_lens[:raw_bs] = forward_batch.seq_lens
+        self.out_cache_loc[:raw_bs] = forward_batch.out_cache_loc

         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
@@ -271,15 +271,15 @@
         )

         # Extract logprobs
-        if input_metadata.return_logprob:
+        if forward_batch.return_logprob:
             logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
                 logits_output.next_token_logits, dim=-1
             )
-            return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
+            return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
             if return_top_logprob:
                 logits_metadata = LogitsMetadata(
                     forward_mode=ForwardMode.DECODE,
-                    top_logprobs_nums=input_metadata.top_logprobs_nums,
+                    top_logprobs_nums=forward_batch.top_logprobs_nums,
                 )
                 logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
                     logits_output.next_token_logprobs, logits_metadata
...
@@ -18,7 +18,7 @@ limitations under the License.
 """Meta data for a forward pass."""

 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List, Set
+from typing import TYPE_CHECKING, List

 import numpy as np
 import torch
@@ -53,8 +53,8 @@ class ForwardMode(IntEnum):

 @dataclass
-class InputMetadata:
-    """Store all inforamtion of a forward pass."""
+class ForwardBatch:
+    """Store all inputs of a forward pass."""

     # The forward mode
     forward_mode: ForwardMode
...
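
The `ForwardBatch` definition is truncated in the hunk above. As a rough, partial reconstruction, the fields that the other hunks in this commit read or write look roughly like this (a hedged sketch, not the complete class; pool and backend handle types are simplified to `object`):

```python
from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class ForwardBatch:
    # Only fields that this commit's hunks demonstrably touch;
    # the real class defines more (see the truncated hunk above).
    forward_mode: "ForwardMode"  # decode / extend / mixed
    batch_size: int
    input_ids: torch.Tensor
    positions: Optional[torch.Tensor] = None
    seq_lens: Optional[torch.Tensor] = None
    req_pool_indices: Optional[torch.Tensor] = None
    out_cache_loc: Optional[torch.Tensor] = None
    # Extend-mode bookkeeping
    extend_prefix_lens: Optional[torch.Tensor] = None
    extend_seq_lens: Optional[torch.Tensor] = None
    extend_start_loc: Optional[torch.Tensor] = None
    # Logprob bookkeeping
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None
    # Attached by ModelRunner.forward() before the model runs
    req_to_token_pool: Optional[object] = None
    token_to_kv_pool: Optional[object] = None
    attn_backend: Optional[object] = None
```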
@@ -48,7 +48,7 @@ from sglang.srt.mem_cache.memory_pool import (
     MLATokenToKVPool,
     ReqToTokenPool,
 )
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -466,47 +466,47 @@ class ModelRunner:
         logger.info("Capture cuda graph begin. This can take up to several minutes.")
         self.cuda_graph_runner = CudaGraphRunner(self)

-    def forward_decode(self, input_metadata: InputMetadata):
+    def forward_decode(self, forward_batch: ForwardBatch):
         if self.cuda_graph_runner and self.cuda_graph_runner.can_run(
-            input_metadata.batch_size
+            forward_batch.batch_size
         ):
-            return self.cuda_graph_runner.replay(input_metadata)
+            return self.cuda_graph_runner.replay(forward_batch)

         return self.model.forward(
-            input_metadata.input_ids, input_metadata.positions, input_metadata
+            forward_batch.input_ids, forward_batch.positions, forward_batch
         )

-    def forward_extend(self, input_metadata: InputMetadata):
+    def forward_extend(self, forward_batch: ForwardBatch):
         if self.is_generation:
             return self.model.forward(
-                input_metadata.input_ids, input_metadata.positions, input_metadata
+                forward_batch.input_ids, forward_batch.positions, forward_batch
             )
         else:
             # Only embedding models have get_embedding parameter
             return self.model.forward(
-                input_metadata.input_ids,
-                input_metadata.positions,
-                input_metadata,
+                forward_batch.input_ids,
+                forward_batch.positions,
+                forward_batch,
                 get_embedding=True,
             )

-    def forward(self, input_metadata: InputMetadata) -> LogitsProcessorOutput:
+    def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
         # Attach attention information
-        input_metadata.req_to_token_pool = self.req_to_token_pool
-        input_metadata.token_to_kv_pool = self.token_to_kv_pool
-        input_metadata.attn_backend = self.attn_backend
-        input_metadata.attn_backend.init_forward_metadata(input_metadata)
+        forward_batch.req_to_token_pool = self.req_to_token_pool
+        forward_batch.token_to_kv_pool = self.token_to_kv_pool
+        forward_batch.attn_backend = self.attn_backend
+        forward_batch.attn_backend.init_forward_metadata(forward_batch)

         # Attach lora information
         if self.server_args.lora_paths is not None:
-            self.lora_manager.prepare_lora_batch(input_metadata)
+            self.lora_manager.prepare_lora_batch(forward_batch)

-        if input_metadata.forward_mode.is_decode():
-            return self.forward_decode(input_metadata)
-        elif input_metadata.forward_mode.is_extend():
-            return self.forward_extend(input_metadata)
+        if forward_batch.forward_mode.is_decode():
+            return self.forward_decode(forward_batch)
+        elif forward_batch.forward_mode.is_extend():
+            return self.forward_extend(forward_batch)
         else:
-            raise ValueError(f"Invaid forward mode: {input_metadata.forward_mode}")
+            raise ValueError(f"Invaid forward mode: {forward_batch.forward_mode}")

     def _apply_logits_bias(
         self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
...
@@ -46,7 +46,7 @@ from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch


 def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
@@ -189,13 +189,13 @@ class BaiChuanAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.W_pack(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         if self.postion_embedding != "ALIBI":
             q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output
@@ -237,7 +237,7 @@ class BaiChuanDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -249,7 +249,7 @@ class BaiChuanDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )

         # Fully Connected
@@ -292,7 +292,7 @@ class BaiChuanModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         residual = None
@@ -301,7 +301,7 @@ class BaiChuanModel(nn.Module):
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                input_metadata,
+                forward_batch,
                 residual,
             )
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -350,11 +350,11 @@ class BaiChuanBaseForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata)
+        hidden_states = self.model(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
...
@@ -42,7 +42,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch

 LoraConfig = None
@@ -118,7 +118,7 @@ class GLMAttention(nn.Module):
         self,
         hidden_states: torch.Tensor,
         position_ids: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.query_key_value(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
@@ -127,7 +127,7 @@ class GLMAttention(nn.Module):
             q,
             k,
             v,
-            input_metadata,
+            forward_batch,
         )
         attn_output, _ = self.dense(context_layer)
         return attn_output
@@ -220,7 +220,7 @@ class GLMBlock(nn.Module):
         self,
         hidden_states: torch.Tensor,
         position_ids: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         # hidden_states: [num_tokens, h]
         # Layer norm at the beginning of the transformer layer.
@@ -229,7 +229,7 @@ class GLMBlock(nn.Module):
         attention_output = self.self_attention(
             hidden_states=layernorm_output,
             position_ids=position_ids,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )

         # Residual connection.
@@ -288,14 +288,14 @@ class GLMTransformer(nn.Module):
         self,
         hidden_states: torch.Tensor,
         position_ids: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         for i in range(self.num_layers):
             layer = self.layers[i]
             hidden_states = layer(
                 hidden_states=hidden_states,
                 position_ids=position_ids,
-                input_metadata=input_metadata,
+                forward_batch=forward_batch,
             )
         # Final layer norm.
         if self.post_layer_norm:
@@ -328,7 +328,7 @@ class ChatGLMModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         inputs_embeds = self.embedding(input_ids)
@@ -336,7 +336,7 @@ class ChatGLMModel(nn.Module):
         hidden_states = self.encoder(
             hidden_states=inputs_embeds,
             position_ids=position_ids,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )
         return hidden_states
@@ -376,11 +376,11 @@ class ChatGLMForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        hidden_states = self.transformer(input_ids, positions, input_metadata)
+        hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
...
@@ -63,7 +63,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import set_weight_attrs
@@ -220,14 +220,14 @@ class CohereAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         if self.use_qk_norm:
             q, k = self._apply_qk_norm(q, k)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output
@@ -255,7 +255,7 @@ class CohereDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -264,7 +264,7 @@ class CohereDecoderLayer(nn.Module):
         hidden_states_attention = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )
         hidden_states_mlp = self.mlp(hidden_states)
         # Add everything together
@@ -299,7 +299,7 @@ class CohereModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         residual = None
@@ -308,7 +308,7 @@ class CohereModel(nn.Module):
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                input_metadata,
+                forward_batch,
                 residual,
             )
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -333,15 +333,15 @@ class CohereForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states = self.model(
             input_ids,
             positions,
-            input_metadata,
+            forward_batch,
         )
         return self.logits_processor(
-            input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
+            input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
...
...@@ -44,7 +44,7 @@ from sglang.srt.layers.linear import ( ...@@ -44,7 +44,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import set_weight_attrs from sglang.srt.utils import set_weight_attrs
...@@ -249,14 +249,14 @@ class DbrxAttention(nn.Module): ...@@ -249,14 +249,14 @@ class DbrxAttention(nn.Module):
self, self,
position_ids: torch.Tensor, position_ids: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.Wqkv(hidden_states) qkv, _ = self.Wqkv(hidden_states)
if self.clip_qkv is not None: if self.clip_qkv is not None:
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(position_ids, q, k) q, k = self.rotary_emb(position_ids, q, k)
attn_output = self.attn(q, k, v, input_metadata) attn_output = self.attn(q, k, v, forward_batch)
hidden_states, _ = self.out_proj(attn_output) hidden_states, _ = self.out_proj(attn_output)
return hidden_states return hidden_states
...@@ -278,14 +278,14 @@ class DbrxFusedNormAttention(nn.Module): ...@@ -278,14 +278,14 @@ class DbrxFusedNormAttention(nn.Module):
self, self,
position_ids: torch.Tensor, position_ids: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
) -> torch.Tensor: ) -> torch.Tensor:
residual = hidden_states residual = hidden_states
hidden_states = self.norm_1(hidden_states) hidden_states = self.norm_1(hidden_states)
x = self.attn( x = self.attn(
position_ids=position_ids, position_ids=position_ids,
hidden_states=hidden_states, hidden_states=hidden_states,
input_metadata=input_metadata, forward_batch=forward_batch,
) )
hidden_states = residual + x hidden_states = residual + x
residual = hidden_states residual = hidden_states
...@@ -310,12 +310,12 @@ class DbrxBlock(nn.Module): ...@@ -310,12 +310,12 @@ class DbrxBlock(nn.Module):
self, self,
position_ids: torch.Tensor, position_ids: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states, residual = self.norm_attn_norm( hidden_states, residual = self.norm_attn_norm(
position_ids=position_ids, position_ids=position_ids,
hidden_states=hidden_states, hidden_states=hidden_states,
input_metadata=input_metadata, forward_batch=forward_batch,
) )
hidden_states = self.ffn(hidden_states) hidden_states = self.ffn(hidden_states)
hidden_states = hidden_states + residual hidden_states = hidden_states + residual
...@@ -349,7 +349,7 @@ class DbrxModel(nn.Module): ...@@ -349,7 +349,7 @@ class DbrxModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
if input_embeds is None: if input_embeds is None:
...@@ -358,7 +358,7 @@ class DbrxModel(nn.Module): ...@@ -358,7 +358,7 @@ class DbrxModel(nn.Module):
hidden_states = input_embeds hidden_states = input_embeds
for i in range(len(self.blocks)): for i in range(len(self.blocks)):
block = self.blocks[i] block = self.blocks[i]
hidden_states = block(position_ids, hidden_states, input_metadata) hidden_states = block(position_ids, hidden_states, forward_batch)
hidden_states = self.norm_f(hidden_states) hidden_states = self.norm_f(hidden_states)
return hidden_states return hidden_states
...@@ -388,11 +388,11 @@ class DbrxForCausalLM(nn.Module): ...@@ -388,11 +388,11 @@ class DbrxForCausalLM(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, input_metadata) hidden_states = self.transformer(input_ids, positions, forward_batch)
return self.logits_processor( return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, forward_batch
) )
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
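In the Dbrx hunks the substantive change is again the final argument of the attention call: `self.attn(q, k, v, forward_batch)` instead of `self.attn(q, k, v, input_metadata)`. A hedged sketch of that call shape with a stub attention module (toy plain attention, not SGLang's RadixAttention, and the helper name is hypothetical):

from typing import Optional

import torch
import torch.nn as nn


class StubAttention(nn.Module):
    """Stand-in for RadixAttention: accepts (q, k, v, forward_batch)."""

    def forward(self, q, k, v, forward_batch):
        # A real backend would read paged-KV indices and forward mode from
        # forward_batch; this stub does plain scaled dot-product attention.
        scale = q.shape[-1] ** -0.5
        attn = torch.softmax(q @ k.transpose(-1, -2) * scale, dim=-1)
        return attn @ v


def dbrx_style_attention(
    qkv: torch.Tensor,
    q_size: int,
    kv_size: int,
    clip_qkv: Optional[float],
    attn: StubAttention,
    forward_batch,
) -> torch.Tensor:
    # Mirrors the DbrxAttention.forward body shown in the diff above.
    if clip_qkv is not None:
        qkv = qkv.clamp(min=-clip_qkv, max=clip_qkv)
    q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
    return attn(q, k, v, forward_batch)  # was: attn(q, k, v, input_metadata)

The attention backend is the only consumer of the batch object in these layers, which is why the surrounding norm, MoE, and residual logic is untouched by the rename.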
...@@ -46,7 +46,7 @@ from sglang.srt.layers.linear import ( ...@@ -46,7 +46,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class DeepseekMLP(nn.Module): class DeepseekMLP(nn.Module):
...@@ -246,12 +246,12 @@ class DeepseekAttention(nn.Module): ...@@ -246,12 +246,12 @@ class DeepseekAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata) attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -303,7 +303,7 @@ class DeepseekDecoderLayer(nn.Module): ...@@ -303,7 +303,7 @@ class DeepseekDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
# Self Attention # Self Attention
...@@ -315,7 +315,7 @@ class DeepseekDecoderLayer(nn.Module): ...@@ -315,7 +315,7 @@ class DeepseekDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
input_metadata=input_metadata, forward_batch=forward_batch,
) )
# Fully Connected # Fully Connected
...@@ -356,14 +356,14 @@ class DeepseekModel(nn.Module): ...@@ -356,14 +356,14 @@ class DeepseekModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
residual = None residual = None
for i in range(len(self.layers)): for i in range(len(self.layers)):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, hidden_states, input_metadata, residual positions, hidden_states, forward_batch, residual
) )
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
...@@ -391,11 +391,11 @@ class DeepseekForCausalLM(nn.Module): ...@@ -391,11 +391,11 @@ class DeepseekForCausalLM(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata) hidden_states = self.model(input_ids, positions, forward_batch)
return self.logits_processor( return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, forward_batch
) )
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
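Since the Deepseek hunks repeat the identical substitution, a practical follow-up after a tree-wide rename like this is to scan for stale references to the old name. A small illustrative script; the `ROOT` path and the default argument are assumptions about the checkout layout, not part of the PR:

import pathlib

ROOT = pathlib.Path("python/sglang/srt")  # assumed repo layout


def find_stale_references(root: pathlib.Path, old_name: str = "input_metadata"):
    # Collect (file, line number, line) for every remaining occurrence.
    hits = []
    for path in root.rglob("*.py"):
        text = path.read_text(errors="ignore")
        for lineno, line in enumerate(text.splitlines(), start=1):
            if old_name in line:
                hits.append((path, lineno, line.strip()))
    return hits


if __name__ == "__main__":
    for path, lineno, line in find_stale_references(ROOT):
        print(f"{path}:{lineno}: {line}")

An empty result confirms every signature and call site was migrated to `forward_batch`.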