Unverified Commit 24e59f53 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

`model_runner` simplify (#329)

parent 75235419
......@@ -407,9 +407,7 @@ class ModelRpcServer:
prefill_logprobs,
normalized_logprobs,
last_logprobs,
) = self.model_runner.forward(
batch, ForwardMode.EXTEND, batch.return_logprob
)
) = self.model_runner.forward(batch, ForwardMode.EXTEND)
if prefill_logprobs is not None:
logprobs = prefill_logprobs.cpu().tolist()
normalized_logprobs = normalized_logprobs.cpu().tolist()
......@@ -496,9 +494,7 @@ class ModelRpcServer:
# Forward
logits, (_, _, last_logprobs) = self.model_runner.forward(
batch,
ForwardMode.DECODE,
batch.return_logprob,
batch, ForwardMode.DECODE
)
next_token_ids, _ = batch.sample(logits)
next_token_ids = next_token_ids.cpu().tolist()
......
......@@ -367,148 +367,88 @@ class ModelRunner:
)
@torch.inference_mode()
def forward_prefill(self, batch: Batch):
    """Run the model on `batch` in PREFILL mode and return its output.

    All scheduling tensors come straight off the batch; no gradient or
    autograd state is recorded (inference mode).
    """
    batch_fields = dict(
        req_pool_indices=batch.req_pool_indices,
        seq_lens=batch.seq_lens,
        prefix_lens=batch.prefix_lens,
        position_ids_offsets=batch.position_ids_offsets,
        out_cache_loc=batch.out_cache_loc,
        return_logprob=batch.return_logprob,
    )
    meta = InputMetadata.create(
        self,
        forward_mode=ForwardMode.PREFILL,
        tp_size=self.tp_size,
        **batch_fields,
    )
    return self.model.forward(batch.input_ids, meta.positions, meta)
@torch.inference_mode()
def forward_extend(self, batch: Batch):
    """Run the model on `batch` in EXTEND mode and return its output.

    All scheduling tensors come straight off the batch; no gradient or
    autograd state is recorded (inference mode).
    """
    batch_fields = dict(
        req_pool_indices=batch.req_pool_indices,
        seq_lens=batch.seq_lens,
        prefix_lens=batch.prefix_lens,
        position_ids_offsets=batch.position_ids_offsets,
        out_cache_loc=batch.out_cache_loc,
        return_logprob=batch.return_logprob,
    )
    meta = InputMetadata.create(
        self,
        forward_mode=ForwardMode.EXTEND,
        tp_size=self.tp_size,
        **batch_fields,
    )
    return self.model.forward(batch.input_ids, meta.positions, meta)
@torch.inference_mode()
def forward_decode(self, batch: Batch):
    """Run the model on `batch` in DECODE mode and return its output.

    In addition to the common scheduling tensors, DECODE also forwards
    the contiguous cache slot range (`out_cache_cont_start/end`) into
    the metadata. No autograd state is recorded (inference mode).
    """
    batch_fields = dict(
        req_pool_indices=batch.req_pool_indices,
        seq_lens=batch.seq_lens,
        prefix_lens=batch.prefix_lens,
        position_ids_offsets=batch.position_ids_offsets,
        out_cache_loc=batch.out_cache_loc,
        out_cache_cont_start=batch.out_cache_cont_start,
        out_cache_cont_end=batch.out_cache_cont_end,
        return_logprob=batch.return_logprob,
    )
    meta = InputMetadata.create(
        self,
        forward_mode=ForwardMode.DECODE,
        tp_size=self.tp_size,
        **batch_fields,
    )
    return self.model.forward(batch.input_ids, meta.positions, meta)
@torch.inference_mode()
def forward_extend_multi_modal(self, batch: Batch):
    """Run the multimodal model on `batch` in EXTEND mode.

    Same metadata as the text-only EXTEND path, but the model call also
    receives the batch's pixel values, image sizes, and image offsets.
    No autograd state is recorded (inference mode).
    """
    batch_fields = dict(
        req_pool_indices=batch.req_pool_indices,
        seq_lens=batch.seq_lens,
        prefix_lens=batch.prefix_lens,
        position_ids_offsets=batch.position_ids_offsets,
        out_cache_loc=batch.out_cache_loc,
        return_logprob=batch.return_logprob,
    )
    meta = InputMetadata.create(
        self,
        forward_mode=ForwardMode.EXTEND,
        tp_size=self.tp_size,
        **batch_fields,
    )
    return self.model.forward(
        batch.input_ids,
        meta.positions,
        meta,
        batch.pixel_values,
        batch.image_sizes,
        batch.image_offsets,
    )
def forward(self, batch: Batch, forward_mode: ForwardMode, return_logprob=False):
def forward(self, batch: Batch, forward_mode: ForwardMode):
if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
kwargs = {
"input_ids": batch.input_ids,
"pixel_values": batch.pixel_values,
"image_sizes": batch.image_sizes,
"image_offsets": batch.image_offsets,
"req_pool_indices": batch.req_pool_indices,
"seq_lens": batch.seq_lens,
"prefix_lens": batch.prefix_lens,
"position_ids_offsets": batch.position_ids_offsets,
"out_cache_loc": batch.out_cache_loc,
"return_logprob": return_logprob,
}
return self.forward_extend_multi_modal(**kwargs)
else:
kwargs = {
"input_ids": batch.input_ids,
"req_pool_indices": batch.req_pool_indices,
"seq_lens": batch.seq_lens,
"prefix_lens": batch.prefix_lens,
"position_ids_offsets": batch.position_ids_offsets,
"out_cache_loc": batch.out_cache_loc,
"return_logprob": return_logprob,
}
if forward_mode == ForwardMode.DECODE:
kwargs["out_cache_cont_start"] = batch.out_cache_cont_start
kwargs["out_cache_cont_end"] = batch.out_cache_cont_end
return self.forward_decode(**kwargs)
return self.forward_extend_multi_modal(batch)
elif forward_mode == ForwardMode.DECODE:
return self.forward_decode(batch)
elif forward_mode == ForwardMode.EXTEND:
return self.forward_extend(**kwargs)
return self.forward_extend(batch)
elif forward_mode == ForwardMode.PREFILL:
return self.forward_prefill(**kwargs)
return self.forward_prefill(batch)
else:
raise ValueError(f"Invaid forward mode: {forward_mode}")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment