Unverified Commit 1ac304ee authored by Liangsheng Yin, committed by GitHub

Adjust `InputMetadata` and `ScheduleBatch` (#981)

parent 20a4f927
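In short, this commit drops `prefix_lens` and the multimodal fields from `ScheduleBatch`, moves position, extend, multimodal, FlashInfer, and Triton bookkeeping into `InputMetadata` (new `from_schedule_batch`, `compute_positions`, `compute_extend_infos`, `init_triton_args`, and `init_flashinfer_handlers` methods), and renames `init_flashinfer_args` to `update_flashinfer_indices`. A minimal sketch of the caller-side pattern this enables, mirroring the `ModelRunner` hunks near the end of the diff (illustrative only, not part of the commit):

# Sketch of the new call pattern; assumes the names introduced in this diff.
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata

def forward_decode_sketch(model_runner, batch):
    # All per-forward metadata (positions, extend info, attention-backend args)
    # is now derived from the ScheduleBatch in a single call.
    input_metadata = InputMetadata.from_schedule_batch(
        model_runner, batch, ForwardMode.DECODE
    )
    return model_runner.model.forward(
        batch.input_ids, input_metadata.positions, input_metadata
    )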
@@ -307,7 +307,6 @@ class ScheduleBatch:
     input_ids: torch.Tensor = None
     req_pool_indices: torch.Tensor = None
     seq_lens: torch.Tensor = None
-    prefix_lens: torch.Tensor = None
     position_ids_offsets: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
     extend_num_tokens: int = None
@@ -316,11 +315,6 @@ class ScheduleBatch:
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
 
-    # For multimodal
-    pixel_values: List[torch.Tensor] = None
-    image_sizes: List[List[int]] = None
-    image_offsets: List[int] = None
-
     # Batched sampling params
     temperatures: torch.Tensor = None
     top_ps: torch.Tensor = None
@@ -412,59 +406,40 @@ class ScheduleBatch:
             self.logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
 
     def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
-        device = "cuda"
         bs = self.batch_size()
         reqs = self.reqs
         input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
-        prefix_indices = [r.prefix_indices for r in reqs]
-
-        # Handle prefix
-        extend_lens = []
-        prefix_lens = []
+        extend_num_tokens = sum(len(ids) for ids in input_ids)
         seq_lens = []
 
+        # Allocate memory
         req_pool_indices_cpu = self.alloc_req_slots(bs)
+        out_cache_loc = self.alloc_token_slots(extend_num_tokens)
+
+        pt = 0
         for i, req in enumerate(reqs):
             req.req_pool_idx = req_pool_indices_cpu[i]
-            extend_lens.append(len(input_ids[i]))
+            pre_len, seq_len = len(req.prefix_indices), len(req.input_ids)
+            ext_len = seq_len - pre_len
+            seq_lens.append(seq_len)
 
-            if len(prefix_indices[i]) == 0:
-                prefix_lens.append(0)
-            else:
-                prefix_lens.append(len(prefix_indices[i]))
+            if pre_len > 0:
                 self.req_to_token_pool.req_to_token[req.req_pool_idx][
-                    : len(prefix_indices[i])
-                ] = prefix_indices[i]
+                    :pre_len
+                ] = req.prefix_indices
 
-            seq_lens.append(prefix_lens[-1] + extend_lens[-1])
-
-        # Allocate memory
-        seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
-        extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
-        out_cache_loc = self.alloc_token_slots(extend_num_tokens)
-
-        pt = 0
-        for i, req in enumerate(reqs):
-            self.req_to_token_pool.req_to_token[req.req_pool_idx][
-                prefix_lens[i] : prefix_lens[i] + extend_lens[i]
-            ] = out_cache_loc[pt : pt + extend_lens[i]]
-            pt += extend_lens[i]
+            self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
+                out_cache_loc[pt : pt + ext_len]
+            )
+            pt += ext_len
 
         # Set fields
         with torch.device("cuda"):
             self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
             self.req_pool_indices = torch.tensor(req_pool_indices_cpu)
             self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
-            self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int32)
-            self.pixel_values = [r.pixel_values for r in reqs]
-            self.image_sizes = [r.image_size for r in reqs]
-            self.image_offsets = [
-                (r.image_offset - p_len) if r.image_offset is not None else 0
-                for r, p_len in zip(reqs, prefix_lens)
-            ]
+            self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int64)
 
-        self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
         self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
         self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
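The rewritten `prepare_for_extend` above derives each request's prefix length and extend length directly from `req.prefix_indices` and `req.input_ids`, allocates all new KV slots up front, and fills `req_to_token` in a single pass. A self-contained toy illustration of that slot layout, with made-up lengths and slot indices (not taken from the commit):

import torch

# Hypothetical example: two requests with prefix lengths [2, 0] and total
# sequence lengths [5, 3], so the extend lengths are [3, 3].
prefix_indices = [torch.tensor([10, 11]), torch.tensor([], dtype=torch.int64)]
input_id_lens = [5, 3]
out_cache_loc = torch.arange(100, 106)  # 6 newly allocated KV-cache slots
req_to_token = torch.zeros(2, 8, dtype=torch.int64)

pt = 0
for i in range(2):
    pre_len, seq_len = len(prefix_indices[i]), input_id_lens[i]
    ext_len = seq_len - pre_len
    if pre_len > 0:
        # Reuse the already-cached prefix slots.
        req_to_token[i, :pre_len] = prefix_indices[i]
    # Assign freshly allocated slots to the extended tokens.
    req_to_token[i, pre_len:seq_len] = out_cache_loc[pt : pt + ext_len]
    pt += ext_len

print(req_to_token[:, :6])
# row 0 -> [10, 11, 100, 101, 102, 0], row 1 -> [103, 104, 105, 0, 0, 0]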
@@ -642,7 +617,6 @@ class ScheduleBatch:
         ]
         self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
         self.seq_lens.add_(1)
-        self.prefix_lens = None
 
         # Alloc mem
         bs = self.batch_size()
@@ -667,7 +641,6 @@ class ScheduleBatch:
         self.seq_lens = self.seq_lens[new_indices]
         self.input_ids = None
         self.req_pool_indices = self.req_pool_indices[new_indices]
-        self.prefix_lens = None
         self.position_ids_offsets = self.position_ids_offsets[new_indices]
         self.out_cache_loc = None
         self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
@@ -692,7 +665,6 @@ class ScheduleBatch:
             [self.req_pool_indices, other.req_pool_indices]
         )
         self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
-        self.prefix_lens = None
         self.position_ids_offsets = torch.concat(
             [self.position_ids_offsets, other.position_ids_offsets]
         )
......
@@ -33,7 +33,7 @@ from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
     InputMetadata,
-    init_flashinfer_args,
+    update_flashinfer_indices,
 )
 from sglang.srt.utils import monkey_patch_vllm_all_gather
@@ -165,7 +165,7 @@ class CudaGraphRunner:
             paged_kv_indices_buffer=self.flashinfer_kv_indices,
             paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
         )
-        init_flashinfer_args(
+        update_flashinfer_indices(
             ForwardMode.DECODE,
             self.model_runner,
             req_pool_indices,
@@ -176,19 +176,19 @@ class CudaGraphRunner:
 
         # Run and capture
         def run_once():
-            input_metadata = InputMetadata.create(
-                self.model_runner,
+            input_metadata = InputMetadata(
                 forward_mode=ForwardMode.DECODE,
+                batch_size=bs,
                 req_pool_indices=req_pool_indices,
                 seq_lens=seq_lens,
-                prefix_lens=None,
-                position_ids_offsets=position_ids_offsets,
+                req_to_token_pool=self.model_runner.req_to_token_pool,
+                token_to_kv_pool=self.model_runner.token_to_kv_pool,
                 out_cache_loc=out_cache_loc,
                 return_logprob=False,
                 top_logprobs_nums=0,
-                skip_flashinfer_init=True,
+                positions=(seq_lens - 1).to(torch.int64),
+                flashinfer_decode_wrapper=flashinfer_decode_wrapper,
             )
-            input_metadata.flashinfer_decode_wrapper = flashinfer_decode_wrapper
 
             return forward(input_ids, input_metadata.positions, input_metadata)
@@ -222,7 +222,7 @@ class CudaGraphRunner:
         self.out_cache_loc[:raw_bs] = batch.out_cache_loc
 
         # FlashInfer inputs
-        init_flashinfer_args(
+        update_flashinfer_indices(
             ForwardMode.DECODE,
             self.model_runner,
             self.req_pool_indices[:bs],
......
@@ -16,13 +16,17 @@ limitations under the License.
 """ModelRunner runs the forward passes of the models."""
 
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import List
+from typing import TYPE_CHECKING, List
 
 import numpy as np
 import torch
 
+from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 
+if TYPE_CHECKING:
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
 
 class ForwardMode(IntEnum):
     # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
@@ -39,25 +43,33 @@ class InputMetadata:
     forward_mode: ForwardMode
     batch_size: int
-    total_num_tokens: int
     req_pool_indices: torch.Tensor
     seq_lens: torch.Tensor
-    positions: torch.Tensor
     req_to_token_pool: ReqToTokenPool
     token_to_kv_pool: BaseTokenToKVPool
 
-    # For extend
-    extend_seq_lens: torch.Tensor
-    extend_start_loc: torch.Tensor
-    extend_no_prefix: bool
-
     # Output location of the KV cache
-    out_cache_loc: torch.Tensor = None
+    out_cache_loc: torch.Tensor
+
+    total_num_tokens: int = None
+
+    # Position information
+    positions: torch.Tensor = None
+
+    # For extend
+    extend_seq_lens: torch.Tensor = None
+    extend_start_loc: torch.Tensor = None
+    extend_no_prefix: bool = None
 
     # Output options
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
 
+    # For multimodal
+    pixel_values: List[torch.Tensor] = None
+    image_sizes: List[List[int]] = None
+    image_offsets: List[int] = None
+
     # Trition attention backend
     triton_max_seq_len: int = 0
     triton_max_extend_len: int = 0
@@ -70,107 +82,170 @@ class InputMetadata:
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
     flashinfer_use_ragged: bool = False
 
-    @classmethod
-    def create(
-        cls,
-        model_runner,
-        forward_mode,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        top_logprobs_nums=None,
-        return_logprob=False,
-        skip_flashinfer_init=False,
-    ):
-        flashinfer_use_ragged = False
-        if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
-            if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
-                flashinfer_use_ragged = True
-            init_flashinfer_args(
-                forward_mode,
-                model_runner,
-                req_pool_indices,
-                seq_lens,
-                prefix_lens,
-                model_runner.flashinfer_decode_wrapper,
-                flashinfer_use_ragged,
-            )
-
-        batch_size = len(req_pool_indices)
-
-        if forward_mode == ForwardMode.DECODE:
-            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
-            extend_seq_lens = extend_start_loc = extend_no_prefix = None
-            if not model_runner.server_args.disable_flashinfer:
-                # This variable is not needed in this case,
-                # we do not compute it to make it compatbile with cuda graph.
-                total_num_tokens = None
-            else:
-                total_num_tokens = int(torch.sum(seq_lens))
-        else:
-            seq_lens_cpu = seq_lens.cpu().numpy()
-            prefix_lens_cpu = prefix_lens.cpu().numpy()
-            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
-            positions = torch.tensor(
-                np.concatenate(
-                    [
-                        np.arange(
-                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
-                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
-                        )
-                        for i in range(batch_size)
-                    ],
-                    axis=0,
-                ),
-                device="cuda",
-            )
-            extend_seq_lens = seq_lens - prefix_lens
-            extend_start_loc = torch.zeros_like(seq_lens)
-            extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
-            extend_no_prefix = torch.all(prefix_lens == 0)
-            total_num_tokens = int(torch.sum(seq_lens))
+    def init_multimuldal_info(self, batch: ScheduleBatch):
+        reqs = batch.reqs
+        self.pixel_values = [r.pixel_values for r in reqs]
+        self.image_sizes = [r.image_size for r in reqs]
+        self.image_offsets = [
+            (
+                (r.image_offset - len(r.prefix_indices))
+                if r.image_offset is not None
+                else 0
+            )
+            for r in reqs
+        ]
+
+    def compute_positions(self, batch: ScheduleBatch):
+        position_ids_offsets = batch.position_ids_offsets
+
+        if self.forward_mode == ForwardMode.DECODE:
+            if True:
+                self.positions = self.seq_lens - 1
+            else:
+                # Deprecated
+                self.positions = (self.seq_lens - 1) + position_ids_offsets
+        else:
+            if True:
+                self.positions = torch.tensor(
+                    np.concatenate(
+                        [
+                            np.arange(len(req.prefix_indices), len(req.input_ids))
+                            for req in batch.reqs
+                        ],
+                        axis=0,
+                    ),
+                    device="cuda",
+                )
+            else:
+                # Deprecated
+                position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
+                self.positions = torch.tensor(
+                    np.concatenate(
+                        [
+                            np.arange(
+                                len(req.prefix_indices) + position_ids_offsets_cpu[i],
+                                len(req.input_ids) + position_ids_offsets_cpu[i],
+                            )
+                            for i, req in enumerate(batch.reqs)
+                        ],
+                        axis=0,
+                    ),
+                    device="cuda",
+                )
+
+        # Positions should be in long type
+        self.positions = self.positions.to(torch.int64)
+
+    def compute_extend_infos(self, batch: ScheduleBatch):
+        if self.forward_mode == ForwardMode.DECODE:
+            self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
+        else:
+            prefix_lens_cpu = [
+                len(r.input_ids) - len(r.prefix_indices) for r in batch.reqs
+            ]
+            self.extend_seq_lens = torch.tensor(prefix_lens_cpu, device="cuda")
+            self.extend_start_loc = torch.zeros_like(self.seq_lens)
+            self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
+            self.extend_no_prefix = all(x == 0 for x in prefix_lens_cpu)
+
+    def init_total_num_tokens(self, batch: ScheduleBatch):
+        self.total_num_tokens = sum(len(req.input_ids) for req in batch.reqs)
+
+    @classmethod
+    def from_schedule_batch(
+        cls,
+        model_runner: "ModelRunner",
+        batch: ScheduleBatch,
+        forward_mode: ForwardMode,
+    ):
         ret = cls(
             forward_mode=forward_mode,
-            batch_size=batch_size,
-            total_num_tokens=total_num_tokens,
-            req_pool_indices=req_pool_indices,
-            seq_lens=seq_lens,
-            positions=positions,
+            batch_size=batch.batch_size(),
+            req_pool_indices=batch.req_pool_indices,
+            seq_lens=batch.seq_lens,
             req_to_token_pool=model_runner.req_to_token_pool,
             token_to_kv_pool=model_runner.token_to_kv_pool,
-            out_cache_loc=out_cache_loc,
-            extend_seq_lens=extend_seq_lens,
-            extend_start_loc=extend_start_loc,
-            extend_no_prefix=extend_no_prefix,
-            return_logprob=return_logprob,
-            top_logprobs_nums=top_logprobs_nums,
-            flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
-            flashinfer_use_ragged=flashinfer_use_ragged,
+            out_cache_loc=batch.out_cache_loc,
+            return_logprob=batch.return_logprob,
+            top_logprobs_nums=batch.top_logprobs_nums,
         )
 
+        ret.compute_positions(batch)
+        ret.compute_extend_infos(batch)
+        ret.init_total_num_tokens(batch)
+
+        if forward_mode != ForwardMode.DECODE:
+            ret.init_multimuldal_info(batch)
+
+        prefix_lens = None
+        if forward_mode != ForwardMode.DECODE:
+            prefix_lens = torch.tensor(
+                [len(r.prefix_indices) for r in batch.reqs], device="cuda"
+            )
+
         if model_runner.server_args.disable_flashinfer:
-            (
-                ret.triton_max_seq_len,
-                ret.triton_max_extend_len,
-                ret.triton_start_loc,
-                ret.triton_prefix_lens,
-            ) = init_triton_args(forward_mode, seq_lens, prefix_lens)
+            ret.init_triton_args(batch, prefix_lens)
+
+        flashinfer_use_ragged = False
+        if not model_runner.server_args.disable_flashinfer:
+            if (
+                forward_mode != ForwardMode.DECODE
+                and int(torch.sum(ret.seq_lens)) > 4096
+            ):
+                flashinfer_use_ragged = True
+            ret.init_flashinfer_handlers(
+                model_runner, prefix_lens, flashinfer_use_ragged
+            )
 
         return ret
 
+    def init_triton_args(self, batch: ScheduleBatch, prefix_lens):
+        """Init auxiliary variables for triton attention backend."""
+        self.triton_max_seq_len = max(len(r.input_ids) for r in batch.reqs)
+        self.triton_prefix_lens = prefix_lens
+        self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
+        self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)
+
+        if self.forward_mode == ForwardMode.DECODE:
+            self.triton_max_extend_len = None
+        else:
+            extend_seq_lens = self.seq_lens - prefix_lens
+            self.triton_max_extend_len = int(torch.max(extend_seq_lens))
+
+    def init_flashinfer_handlers(
+        self, model_runner, prefix_lens, flashinfer_use_ragged
+    ):
+        update_flashinfer_indices(
+            self.forward_mode,
+            model_runner,
+            self.req_pool_indices,
+            self.seq_lens,
+            prefix_lens,
+            flashinfer_use_ragged=flashinfer_use_ragged,
+        )
+
+        (
+            self.flashinfer_prefill_wrapper_ragged,
+            self.flashinfer_prefill_wrapper_paged,
+            self.flashinfer_decode_wrapper,
+            self.flashinfer_use_ragged,
+        ) = (
+            model_runner.flashinfer_prefill_wrapper_ragged,
+            model_runner.flashinfer_prefill_wrapper_paged,
+            model_runner.flashinfer_decode_wrapper,
+            flashinfer_use_ragged,
+        )
+
 
-def init_flashinfer_args(
+def update_flashinfer_indices(
     forward_mode,
     model_runner,
     req_pool_indices,
     seq_lens,
     prefix_lens,
-    flashinfer_decode_wrapper,
+    flashinfer_decode_wrapper=None,
     flashinfer_use_ragged=False,
 ):
     """Init auxiliary variables for FlashInfer attention backend."""
@@ -178,7 +253,6 @@ init_flashinfer_args(
     num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
     head_dim = model_runner.model_config.head_dim
     batch_size = len(req_pool_indices)
-    total_num_tokens = int(torch.sum(seq_lens))
 
     if flashinfer_use_ragged:
         paged_kernel_lens = prefix_lens
@@ -201,6 +275,10 @@
         kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
 
     if forward_mode == ForwardMode.DECODE:
+        # CUDA graph uses different flashinfer_decode_wrapper
+        if flashinfer_decode_wrapper is None:
+            flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
+
         flashinfer_decode_wrapper.end_forward()
         flashinfer_decode_wrapper.begin_forward(
             kv_indptr,
@@ -238,19 +316,3 @@
             head_dim,
             1,
         )
-
-
-def init_triton_args(forward_mode, seq_lens, prefix_lens):
-    """Init auxiliary variables for triton attention backend."""
-    batch_size = len(seq_lens)
-    max_seq_len = int(torch.max(seq_lens))
-    start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
-    start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
-
-    if forward_mode == ForwardMode.DECODE:
-        max_extend_len = None
-    else:
-        extend_seq_lens = seq_lens - prefix_lens
-        max_extend_len = int(torch.max(extend_seq_lens))
-
-    return max_seq_len, max_extend_len, start_loc, prefix_lens
@@ -350,33 +350,18 @@ class ModelRunner:
         if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
             return self.cuda_graph_runner.replay(batch)
 
-        input_metadata = InputMetadata.create(
-            self,
-            forward_mode=ForwardMode.DECODE,
-            req_pool_indices=batch.req_pool_indices,
-            seq_lens=batch.seq_lens,
-            prefix_lens=batch.prefix_lens,
-            position_ids_offsets=batch.position_ids_offsets,
-            out_cache_loc=batch.out_cache_loc,
-            top_logprobs_nums=batch.top_logprobs_nums,
-            return_logprob=batch.return_logprob,
+        input_metadata = InputMetadata.from_schedule_batch(
+            self, batch, ForwardMode.DECODE
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
         )
 
     @torch.inference_mode()
     def forward_extend(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.create(
-            self,
-            forward_mode=ForwardMode.EXTEND,
-            req_pool_indices=batch.req_pool_indices,
-            seq_lens=batch.seq_lens,
-            prefix_lens=batch.prefix_lens,
-            position_ids_offsets=batch.position_ids_offsets,
-            out_cache_loc=batch.out_cache_loc,
-            top_logprobs_nums=batch.top_logprobs_nums,
-            return_logprob=batch.return_logprob,
+        input_metadata = InputMetadata.from_schedule_batch(
+            self, batch, forward_mode=ForwardMode.EXTEND
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -384,24 +369,16 @@ class ModelRunner:
     @torch.inference_mode()
     def forward_extend_multi_modal(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.create(
-            self,
-            forward_mode=ForwardMode.EXTEND,
-            req_pool_indices=batch.req_pool_indices,
-            seq_lens=batch.seq_lens,
-            prefix_lens=batch.prefix_lens,
-            position_ids_offsets=batch.position_ids_offsets,
-            out_cache_loc=batch.out_cache_loc,
-            return_logprob=batch.return_logprob,
-            top_logprobs_nums=batch.top_logprobs_nums,
+        input_metadata = InputMetadata.from_schedule_batch(
+            self, batch, forward_mode=ForwardMode.EXTEND
         )
         return self.model.forward(
             batch.input_ids,
             input_metadata.positions,
             input_metadata,
-            batch.pixel_values,
-            batch.image_sizes,
-            batch.image_offsets,
+            input_metadata.pixel_values,
+            input_metadata.image_sizes,
+            input_metadata.image_offsets,
         )
 
     def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
......