Unverified Commit 4a292f67 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

[Minor] Add some utility functions (#1671)

parent cd0be748
...@@ -587,6 +587,8 @@ async def benchmark( ...@@ -587,6 +587,8 @@ async def benchmark(
else: else:
print("Initial test run completed. Starting main benchmark run...") print("Initial test run completed. Starting main benchmark run...")
time.sleep(1.5)
pbar = None if disable_tqdm else tqdm(total=len(input_requests)) pbar = None if disable_tqdm else tqdm(total=len(input_requests))
benchmark_start_time = time.perf_counter() benchmark_start_time = time.perf_counter()
......
...@@ -392,6 +392,9 @@ class Req: ...@@ -392,6 +392,9 @@ class Req:
return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, " return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
bid = 0
@dataclass @dataclass
class ScheduleBatch: class ScheduleBatch:
"""Store all inforamtion of a batch.""" """Store all inforamtion of a batch."""
...@@ -828,7 +831,11 @@ class ScheduleBatch: ...@@ -828,7 +831,11 @@ class ScheduleBatch:
else: else:
self.sampling_info.regex_fsms = None self.sampling_info.regex_fsms = None
global bid
bid += 1
return ModelWorkerBatch( return ModelWorkerBatch(
bid=bid,
forward_mode=self.forward_mode, forward_mode=self.forward_mode,
input_ids=self.input_ids, input_ids=self.input_ids,
req_pool_indices=self.req_pool_indices, req_pool_indices=self.req_pool_indices,
...@@ -865,6 +872,8 @@ class ScheduleBatch: ...@@ -865,6 +872,8 @@ class ScheduleBatch:
@dataclass @dataclass
class ModelWorkerBatch: class ModelWorkerBatch:
# The batch id
bid: int
# The forward mode # The forward mode
forward_mode: ForwardMode forward_mode: ForwardMode
# The input ids # The input ids
...@@ -893,3 +902,21 @@ class ModelWorkerBatch: ...@@ -893,3 +902,21 @@ class ModelWorkerBatch:
# Sampling info # Sampling info
sampling_info: SamplingBatchInfo sampling_info: SamplingBatchInfo
def copy(self):
    """Return a new batch object duplicating this one.

    ``input_ids`` is replaced by ``self.input_ids.clone()`` and
    ``sampling_info`` by ``self.sampling_info.copy()``; every other field
    is carried over by reference, so the returned object shares those
    values with the original.

    The copy is built with ``dataclasses.replace`` instead of listing each
    field by hand, so the field list cannot drift out of sync with the
    dataclass definition: a field added later is carried over
    automatically rather than being dropped or reset to its default.

    Returns:
        A new instance of the same dataclass as ``self``.
    """
    # Function-scope import keeps this method self-contained regardless of
    # how the surrounding module imports the ``dataclass`` decorator.
    import dataclasses

    return dataclasses.replace(
        self,
        # Same evaluation order as before: clone the ids first, then copy
        # the sampling info.
        input_ids=self.input_ids.clone(),
        sampling_info=self.sampling_info.copy(),
    )
...@@ -710,7 +710,7 @@ class Scheduler: ...@@ -710,7 +710,7 @@ class Scheduler:
next_token_ids next_token_ids
) )
if logits_output: if batch.return_logprob:
# Move logprobs to cpu # Move logprobs to cpu
if logits_output.next_token_logprobs is not None: if logits_output.next_token_logprobs is not None:
logits_output.next_token_logprobs = ( logits_output.next_token_logprobs = (
...@@ -786,7 +786,7 @@ class Scheduler: ...@@ -786,7 +786,7 @@ class Scheduler:
self.num_generated_tokens += len(batch.reqs) self.num_generated_tokens += len(batch.reqs)
# Move logprobs to cpu # Move logprobs to cpu
if logits_output.next_token_logprobs is not None: if batch.return_logprob:
next_token_logprobs = logits_output.next_token_logprobs[ next_token_logprobs = logits_output.next_token_logprobs[
torch.arange(len(next_token_ids), device=next_token_ids.device), torch.arange(len(next_token_ids), device=next_token_ids.device),
next_token_ids, next_token_ids,
......
...@@ -202,3 +202,14 @@ class SamplingBatchInfo: ...@@ -202,3 +202,14 @@ class SamplingBatchInfo:
self.logit_bias = SamplingBatchInfo.merge_bias_tensor( self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
self.logit_bias, other.logit_bias, len(self), len(other), self.device self.logit_bias, other.logit_bias, len(self), len(other), self.device
) )
def copy(self):
    """Build a new SamplingBatchInfo from a fixed subset of this one's fields.

    Only the attributes named in ``forwarded`` are passed to the
    constructor, and they are passed by reference (nothing is cloned).
    """
    # NOTE(review): attributes not listed here are not forwarded to the new
    # object — confirm that leaving them out is intentional.
    forwarded = (
        "temperatures",
        "top_ps",
        "top_ks",
        "min_ps",
        "need_min_p_sampling",
        "vocab_size",
        "device",
    )
    init_kwargs = {attr: getattr(self, attr) for attr in forwarded}
    return SamplingBatchInfo(**init_kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment