Unverified Commit 69b3bb9a authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Unify forward mode (#1360)

parent 689ff588
......@@ -60,7 +60,6 @@ import torch.distributed as dist
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.model_config import ModelConfig
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
......@@ -208,14 +207,14 @@ def extend(reqs, model_runner):
tree_cache=None,
)
batch.prepare_for_extend(model_runner.model_config.vocab_size)
sample_output, logits_output = model_runner.forward(batch, ForwardMode.EXTEND)
sample_output, logits_output = model_runner.forward(batch)
next_token_ids = sample_output.batch_next_token_ids.tolist()
return next_token_ids, logits_output.next_token_logits, batch
def decode(input_token_ids, batch, model_runner):
batch.prepare_for_decode(input_token_ids)
sample_output, logits_output = model_runner.forward(batch, ForwardMode.DECODE)
sample_output, logits_output = model_runner.forward(batch)
next_token_ids = sample_output.batch_next_token_ids.tolist()
return next_token_ids, logits_output.next_token_logits
......
......@@ -103,7 +103,7 @@ class LogitsProcessor(nn.Module):
@staticmethod
def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
if logits_metadata.forward_mode == ForwardMode.DECODE:
if logits_metadata.forward_mode.is_decode():
output_top_logprobs = []
max_k = max(logits_metadata.top_logprobs_nums)
ret = all_logprobs.topk(max_k, dim=1)
......@@ -163,7 +163,7 @@ class LogitsProcessor(nn.Module):
assert isinstance(logits_metadata, LogitsMetadata)
# Get the last hidden states and last logits for the next token prediction
if logits_metadata.forward_mode == ForwardMode.DECODE:
if logits_metadata.forward_mode.is_decode():
last_index = None
last_hidden = hidden_states
else:
......@@ -195,7 +195,7 @@ class LogitsProcessor(nn.Module):
)
else:
# When logprob is requested, compute the logits for all tokens.
if logits_metadata.forward_mode == ForwardMode.DECODE:
if logits_metadata.forward_mode.is_decode():
last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)
# Get the logprob of top-k tokens
......
......@@ -197,9 +197,9 @@ class RadixAttention(nn.Module):
k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
if input_metadata.forward_mode == ForwardMode.EXTEND:
if input_metadata.forward_mode.is_extend():
return self.extend_forward(q, k, v, input_metadata)
elif input_metadata.forward_mode == ForwardMode.DECODE:
elif input_metadata.forward_mode.is_decode():
return self.decode_forward(q, k, v, input_metadata)
def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
......
......@@ -29,6 +29,7 @@ from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
if TYPE_CHECKING:
......@@ -334,6 +335,8 @@ class ScheduleBatch:
token_to_kv_pool: BaseTokenToKVPool
tree_cache: BasePrefixCache
forward_mode: ForwardMode = None
# Batched arguments to model runner
input_ids: torch.Tensor = None
req_pool_indices: torch.Tensor = None
......@@ -397,6 +400,8 @@ class ScheduleBatch:
return out_cache_loc
def prepare_for_extend(self, vocab_size: int):
self.forward_mode = ForwardMode.EXTEND
bs = self.batch_size()
reqs = self.reqs
input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
......@@ -626,6 +631,8 @@ class ScheduleBatch:
return jump_forward_reqs
def prepare_for_decode(self, input_ids=None):
self.forward_mode = ForwardMode.DECODE
if input_ids is None:
input_ids = [
r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
......
......@@ -53,7 +53,6 @@ from sglang.srt.managers.schedule_batch import (
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.model_config import ModelConfig
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
......@@ -521,9 +520,7 @@ class ModelTpServer:
if self.model_runner.is_generation:
# Forward and sample the next tokens
if batch.extend_num_tokens != 0:
sample_output, logits_output = self.model_runner.forward(
batch, ForwardMode.EXTEND
)
sample_output, logits_output = self.model_runner.forward(batch)
next_token_ids = batch.check_sample_results(sample_output)
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
next_token_ids
......@@ -588,7 +585,7 @@ class ModelTpServer:
pt += req.extend_input_len
else:
assert batch.extend_num_tokens != 0
logits_output = self.model_runner.forward(batch, ForwardMode.EXTEND)
logits_output = self.model_runner.forward(batch)
embeddings = logits_output.embeddings.tolist()
# Check finish conditions
......@@ -699,9 +696,7 @@ class ModelTpServer:
batch.prepare_for_decode()
# Forward and sample the next tokens
sample_output, logits_output = self.model_runner.forward(
batch, ForwardMode.DECODE
)
sample_output, logits_output = self.model_runner.forward(batch)
next_token_ids = batch.check_sample_results(sample_output)
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
next_token_ids
......
......@@ -25,10 +25,9 @@ import torch
import triton
import triton.language as tl
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
......@@ -41,6 +40,15 @@ class ForwardMode(IntEnum):
# Decode one token.
DECODE = auto()
def is_prefill(self):
return self == ForwardMode.PREFILL
def is_extend(self):
return self == ForwardMode.EXTEND
def is_decode(self):
return self == ForwardMode.DECODE
@dataclass
class InputMetadata:
......@@ -102,7 +110,7 @@ class InputMetadata:
def compute_positions(self, batch: ScheduleBatch):
position_ids_offsets = batch.position_ids_offsets
if self.forward_mode == ForwardMode.DECODE:
if self.forward_mode.is_decode():
if True:
self.positions = self.seq_lens - 1
else:
......@@ -141,7 +149,7 @@ class InputMetadata:
self.positions = self.positions.to(torch.int64)
def compute_extend_infos(self, batch: ScheduleBatch):
if self.forward_mode == ForwardMode.DECODE:
if self.forward_mode.is_decode():
self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
self.extend_seq_lens_cpu = self.logprob_start_lens_cpu = None
else:
......@@ -173,10 +181,9 @@ class InputMetadata:
cls,
model_runner: "ModelRunner",
batch: ScheduleBatch,
forward_mode: ForwardMode,
):
ret = cls(
forward_mode=forward_mode,
forward_mode=batch.forward_mode,
sampling_info=batch.sampling_info,
batch_size=batch.batch_size(),
req_pool_indices=batch.req_pool_indices,
......@@ -194,13 +201,11 @@ class InputMetadata:
ret.compute_extend_infos(batch)
if (
forward_mode != ForwardMode.DECODE
or model_runner.server_args.disable_flashinfer
):
fm = batch.forward_mode
if not fm.is_decode() or model_runner.server_args.disable_flashinfer:
ret.total_num_tokens = int(torch.sum(ret.seq_lens))
if forward_mode != ForwardMode.DECODE:
if not fm.is_decode():
ret.init_multimuldal_info(batch)
if model_runner.server_args.disable_flashinfer:
......@@ -209,7 +214,7 @@ class InputMetadata:
flashinfer_use_ragged = False
if not model_runner.server_args.disable_flashinfer:
if (
forward_mode != ForwardMode.DECODE
not fm.is_decode()
and int(torch.sum(ret.seq_lens)) > 4096
and model_runner.sliding_window_size is None
):
......@@ -226,7 +231,7 @@ class InputMetadata:
self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)
if self.forward_mode == ForwardMode.DECODE:
if self.forward_mode.is_decode():
self.triton_max_extend_len = None
else:
self.triton_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
......@@ -239,7 +244,7 @@ class InputMetadata:
prefix_lens_cpu,
flashinfer_use_ragged,
):
if self.forward_mode == ForwardMode.DECODE:
if self.forward_mode.is_decode():
prefix_lens = None
else:
prefix_lens = self.extend_prefix_lens
......@@ -339,7 +344,7 @@ def update_flashinfer_indices(
kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
if forward_mode == ForwardMode.DECODE:
if forward_mode.is_decode():
# CUDA graph uses different flashinfer_decode_wrapper
if flashinfer_decode_wrapper is None:
flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
......@@ -388,7 +393,7 @@ def update_flashinfer_indices(
kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
for wrapper_id in range(2):
if wrapper_id == 0:
if forward_mode == ForwardMode.DECODE:
if forward_mode.is_decode():
paged_kernel_lens = torch.minimum(
seq_lens, torch.tensor(model_runner.sliding_window_size + 1)
)
......@@ -418,7 +423,7 @@ def update_flashinfer_indices(
kv_indices,
)
if forward_mode == ForwardMode.DECODE:
if forward_mode.is_decode():
# CUDA graph uses different flashinfer_decode_wrapper
if flashinfer_decode_wrapper is None:
flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
......
......@@ -530,11 +530,7 @@ class ModelRunner:
):
return self.cuda_graph_runner.replay(batch)
input_metadata = InputMetadata.from_schedule_batch(
self,
batch,
ForwardMode.DECODE,
)
input_metadata = InputMetadata.from_schedule_batch(self, batch)
return self.model.forward(
batch.input_ids, input_metadata.positions, input_metadata
......@@ -542,11 +538,7 @@ class ModelRunner:
@torch.inference_mode()
def forward_extend(self, batch: ScheduleBatch):
input_metadata = InputMetadata.from_schedule_batch(
self,
batch,
forward_mode=ForwardMode.EXTEND,
)
input_metadata = InputMetadata.from_schedule_batch(self, batch)
if self.is_generation:
return self.model.forward(
batch.input_ids, input_metadata.positions, input_metadata
......@@ -562,11 +554,7 @@ class ModelRunner:
@torch.inference_mode()
def forward_extend_multi_modal(self, batch: ScheduleBatch):
input_metadata = InputMetadata.from_schedule_batch(
self,
batch,
forward_mode=ForwardMode.EXTEND,
)
input_metadata = InputMetadata.from_schedule_batch(self, batch)
return self.model.forward(
batch.input_ids,
input_metadata.positions,
......@@ -577,16 +565,18 @@ class ModelRunner:
)
def forward(
self, batch: ScheduleBatch, forward_mode: ForwardMode
self, batch: ScheduleBatch
) -> Tuple[SampleOutput, LogitsProcessorOutput]:
if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
assert batch.forward_mode is not None
if self.is_multimodal_model and batch.forward_mode.is_extend():
return self.forward_extend_multi_modal(batch)
elif forward_mode == ForwardMode.DECODE:
elif batch.forward_mode.is_decode():
return self.forward_decode(batch)
elif forward_mode == ForwardMode.EXTEND:
elif batch.forward_mode.is_extend():
return self.forward_extend(batch)
else:
raise ValueError(f"Invaid forward mode: {forward_mode}")
raise ValueError(f"Invaid forward mode: {batch.forward_mode}")
@lru_cache()
......
......@@ -136,7 +136,7 @@ class LlavaBaseForCausalLM(nn.Module):
image_sizes: Optional[List[List[int]]] = None,
image_offsets: Optional[List[int]] = None,
) -> torch.Tensor:
if input_metadata.forward_mode == ForwardMode.EXTEND:
if input_metadata.forward_mode.is_extend():
bs = input_metadata.batch_size
# Got List[List[str]] extend it to List[str]
# The length of the List should be equal to batch size
......@@ -357,7 +357,7 @@ class LlavaBaseForCausalLM(nn.Module):
return self.language_model(
input_ids, positions, input_metadata, input_embeds=input_embeds
)
elif input_metadata.forward_mode == ForwardMode.DECODE:
elif input_metadata.forward_mode.is_decode():
return self.language_model(input_ids, positions, input_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
......@@ -116,7 +116,7 @@ class LlavaVidForCausalLM(nn.Module):
image_sizes: Optional[List[List[int]]] = None,
image_offsets: Optional[List[int]] = None,
) -> torch.Tensor:
if input_metadata.forward_mode == ForwardMode.EXTEND:
if input_metadata.forward_mode.is_extend():
bs = input_metadata.batch_size
# Embed text inputs
......@@ -199,7 +199,7 @@ class LlavaVidForCausalLM(nn.Module):
return self.language_model(
input_ids, positions, input_metadata, input_embeds=input_embeds
)
elif input_metadata.forward_mode == ForwardMode.DECODE:
elif input_metadata.forward_mode.is_decode():
return self.language_model(input_ids, positions, input_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment