Unverified Commit 24e59f53 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

`model_runner` simplify (#329)

parent 75235419
...@@ -407,9 +407,7 @@ class ModelRpcServer: ...@@ -407,9 +407,7 @@ class ModelRpcServer:
prefill_logprobs, prefill_logprobs,
normalized_logprobs, normalized_logprobs,
last_logprobs, last_logprobs,
) = self.model_runner.forward( ) = self.model_runner.forward(batch, ForwardMode.EXTEND)
batch, ForwardMode.EXTEND, batch.return_logprob
)
if prefill_logprobs is not None: if prefill_logprobs is not None:
logprobs = prefill_logprobs.cpu().tolist() logprobs = prefill_logprobs.cpu().tolist()
normalized_logprobs = normalized_logprobs.cpu().tolist() normalized_logprobs = normalized_logprobs.cpu().tolist()
...@@ -496,9 +494,7 @@ class ModelRpcServer: ...@@ -496,9 +494,7 @@ class ModelRpcServer:
# Forward # Forward
logits, (_, _, last_logprobs) = self.model_runner.forward( logits, (_, _, last_logprobs) = self.model_runner.forward(
batch, batch, ForwardMode.DECODE
ForwardMode.DECODE,
batch.return_logprob,
) )
next_token_ids, _ = batch.sample(logits) next_token_ids, _ = batch.sample(logits)
next_token_ids = next_token_ids.cpu().tolist() next_token_ids = next_token_ids.cpu().tolist()
......
...@@ -367,148 +367,88 @@ class ModelRunner: ...@@ -367,148 +367,88 @@ class ModelRunner:
) )
@torch.inference_mode()
def forward_prefill(self, batch: Batch):
    """Run a PREFILL forward pass over *batch*.

    Builds the per-step ``InputMetadata`` from the batch's bookkeeping
    tensors, then invokes the underlying model. Returns whatever the
    model's ``forward`` returns (logits, and logprob tuple when
    ``batch.return_logprob`` is set — see the model implementation).
    """
    meta = InputMetadata.create(
        self,
        forward_mode=ForwardMode.PREFILL,
        tp_size=self.tp_size,
        req_pool_indices=batch.req_pool_indices,
        seq_lens=batch.seq_lens,
        prefix_lens=batch.prefix_lens,
        position_ids_offsets=batch.position_ids_offsets,
        out_cache_loc=batch.out_cache_loc,
        return_logprob=batch.return_logprob,
    )
    return self.model.forward(batch.input_ids, meta.positions, meta)
@torch.inference_mode()
def forward_extend(self, batch: Batch):
    """Run an EXTEND forward pass over *batch*.

    Identical plumbing to the prefill path except for the forward mode:
    assemble ``InputMetadata`` from the batch, then call the model.
    """
    # Gather the metadata fields from the batch in one place so the
    # create() call below stays a simple pass-through.
    fields = dict(
        forward_mode=ForwardMode.EXTEND,
        tp_size=self.tp_size,
        req_pool_indices=batch.req_pool_indices,
        seq_lens=batch.seq_lens,
        prefix_lens=batch.prefix_lens,
        position_ids_offsets=batch.position_ids_offsets,
        out_cache_loc=batch.out_cache_loc,
        return_logprob=batch.return_logprob,
    )
    metadata = InputMetadata.create(self, **fields)
    return self.model.forward(batch.input_ids, metadata.positions, metadata)
@torch.inference_mode()
def forward_decode(self, batch: Batch):
    """Run a DECODE (single-token generation) forward pass over *batch*.

    Differs from the prefill/extend paths only in that the contiguous
    KV-cache slot range (``out_cache_cont_start``/``end``) is also
    threaded into the metadata.
    """
    meta = InputMetadata.create(
        self,
        forward_mode=ForwardMode.DECODE,
        tp_size=self.tp_size,
        req_pool_indices=batch.req_pool_indices,
        seq_lens=batch.seq_lens,
        prefix_lens=batch.prefix_lens,
        position_ids_offsets=batch.position_ids_offsets,
        out_cache_loc=batch.out_cache_loc,
        # Decode-only fields: contiguous cache slot range for this step.
        out_cache_cont_start=batch.out_cache_cont_start,
        out_cache_cont_end=batch.out_cache_cont_end,
        return_logprob=batch.return_logprob,
    )
    return self.model.forward(batch.input_ids, meta.positions, meta)
@torch.inference_mode()
def forward_extend_multi_modal(self, batch: Batch):
    """Run an EXTEND forward pass for a multimodal model.

    Same metadata construction as the text-only extend path, but the
    model call additionally receives the batch's image inputs
    (``pixel_values``, ``image_sizes``, ``image_offsets``).
    """
    metadata = InputMetadata.create(
        self,
        forward_mode=ForwardMode.EXTEND,
        tp_size=self.tp_size,
        req_pool_indices=batch.req_pool_indices,
        seq_lens=batch.seq_lens,
        prefix_lens=batch.prefix_lens,
        position_ids_offsets=batch.position_ids_offsets,
        out_cache_loc=batch.out_cache_loc,
        return_logprob=batch.return_logprob,
    )
    return self.model.forward(
        batch.input_ids,
        metadata.positions,
        metadata,
        batch.pixel_values,
        batch.image_sizes,
        batch.image_offsets,
    )
def forward(self, batch: Batch, forward_mode: ForwardMode):
    """Dispatch *batch* to the forward implementation for *forward_mode*.

    Multimodal models take a dedicated EXTEND path so that image inputs
    are threaded through to the model; all other modes map one-to-one
    onto their ``forward_*`` helpers.

    Args:
        batch: the batch to run; carries all per-request tensors
            (including ``return_logprob``) consumed by the helpers.
        forward_mode: which phase to execute (PREFILL / EXTEND / DECODE).

    Returns:
        Whatever the selected ``forward_*`` helper returns.

    Raises:
        ValueError: if *forward_mode* is not a recognized mode.
    """
    # The multimodal check must come first: EXTEND on a multimodal model
    # needs the image-aware path, not the plain text one.
    if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
        return self.forward_extend_multi_modal(batch)
    elif forward_mode == ForwardMode.DECODE:
        return self.forward_decode(batch)
    elif forward_mode == ForwardMode.EXTEND:
        return self.forward_extend(batch)
    elif forward_mode == ForwardMode.PREFILL:
        return self.forward_prefill(batch)
    else:
        # Fixed typo in the error message ("Invaid" -> "Invalid").
        raise ValueError(f"Invalid forward mode: {forward_mode}")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment