"vscode:/vscode.git/clone" did not exist on "308e52a37e63b31dad299ad7642ea9ba8de60333"
Unverified commit 2854a5ea, authored by Lianmin Zheng and committed by GitHub

Fix the overhead due to penalizer in bench_latency (#1496)
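This change drops the unconditional per-step `cumulate_input_tokens` call from `ScheduleBatch.prepare_for_decode` and hoists `@torch.inference_mode()` up to `ModelTpServer.exposed_step`. It also fixes the decode-loop accounting in `bench_latency`: the extend (prefill) pass already yields the first output token, so only `output_len - 1` decode steps are needed. A minimal sketch of that accounting, using hypothetical `extend`/`decode` callables rather than the benchmark's real helpers:

```python
# Sketch only: `extend` and `decode` are hypothetical stand-ins for the
# prefill and decode steps used in bench_latency.
def generate(extend, decode, input_ids, output_len):
    # Prefill consumes the prompt and already produces output token #1.
    next_token = extend(input_ids)
    output_ids = [next_token]

    # Each decode step adds one token, so output_len - 1 further iterations
    # give exactly output_len generated tokens (no extra, wasted step).
    for _ in range(output_len - 1):
        next_token = decode(next_token)
        output_ids.append(next_token)

    assert len(output_ids) == output_len
    return output_ids
```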

parent 42a2d82b
@@ -260,7 +260,7 @@ def correctness_test(
 
     # Decode
     output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
-    for _ in range(bench_args.output_len[0]):
+    for _ in range(bench_args.output_len[0] - 1):
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
         for i in range(len(reqs)):
             output_ids[i].append(next_token_ids[i])
@@ -311,7 +311,7 @@ def latency_test_run_once(
 
     # Decode
     decode_latencies = []
-    for i in range(output_len):
+    for i in range(output_len - 1):
         torch.cuda.synchronize()
         tic = time.time()
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
...
@@ -429,7 +429,7 @@ class ScheduleBatch:
 
     def prepare_for_extend(self, vocab_size: int):
         self.forward_mode = ForwardMode.EXTEND
-        bs = self.batch_size()
+        bs = len(self.reqs)
         reqs = self.reqs
         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
         extend_num_tokens = sum(len(ids) for ids in input_ids)
@@ -509,7 +509,7 @@ class ScheduleBatch:
         self.extend_logprob_start_lens_cpu.extend([0] * running_bs)
 
     def check_decode_mem(self):
-        bs = self.batch_size()
+        bs = len(self.reqs)
         if self.token_to_kv_pool.available_size() >= bs:
             return True
 
@@ -680,14 +680,12 @@ class ScheduleBatch:
                 r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
                 for r in self.reqs
             ]
-        else:
-            self.sampling_info.penalizer_orchestrator.cumulate_input_tokens(input_ids)
 
         self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
         self.seq_lens.add_(1)
 
         # Alloc mem
-        bs = self.batch_size()
+        bs = len(self.reqs)
         self.out_cache_loc = self.alloc_token_slots(bs)
 
         self.req_to_token_pool.req_to_token[
...
@@ -215,6 +215,7 @@ class ModelTpServer:
         self.new_token_ratio_decay = global_config.new_token_ratio_decay
         self.do_not_get_new_batch = False
 
+    @torch.inference_mode()
     def exposed_step(self, recv_reqs: List):
         try:
             # Recv requests
@@ -246,7 +247,6 @@ class ModelTpServer:
             self.out_pyobjs = []
             return ret
 
-    @torch.inference_mode()
     def forward_step(self):
         if self.do_not_get_new_batch and self.current_inflight_req is None:
             new_batch = None
...
@@ -97,14 +97,12 @@ class InputMetadata:
         self.modalities = [r.modalities for r in reqs]
 
     def compute_positions(self, batch: ScheduleBatch):
-        position_ids_offsets = batch.position_ids_offsets
-
         if self.forward_mode.is_decode():
             if True:
                 self.positions = self.seq_lens - 1
             else:
                 # Deprecated
-                self.positions = (self.seq_lens - 1) + position_ids_offsets
+                self.positions = (self.seq_lens - 1) + batch.position_ids_offsets
         else:
             if True:
                 self.positions = torch.tensor(
@@ -119,7 +117,7 @@ class InputMetadata:
                 )
             else:
                 # Deprecated
-                position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
+                position_ids_offsets_cpu = batch.position_ids_offsets.cpu().numpy()
                 self.positions = torch.tensor(
                     np.concatenate(
                         [
...
@@ -467,7 +467,6 @@ class ModelRunner:
         logger.info("Capture cuda graph begin. This can take up to several minutes.")
         self.cuda_graph_runner = CudaGraphRunner(self)
 
-    @torch.inference_mode()
     def forward_decode(self, batch: ScheduleBatch):
         if self.server_args.lora_paths is not None:
             self.lora_manager.prepare_lora_batch(batch)
@@ -481,7 +480,6 @@ class ModelRunner:
             batch.input_ids, input_metadata.positions, input_metadata
         )
 
-    @torch.inference_mode()
     def forward_extend(self, batch: ScheduleBatch):
         input_metadata = InputMetadata.from_schedule_batch(self, batch)
         if self.server_args.lora_paths is not None:
@@ -500,7 +498,6 @@ class ModelRunner:
             get_embedding=True,
         )
 
-    @torch.inference_mode()
     def forward_extend_multi_modal(self, batch: ScheduleBatch):
         input_metadata = InputMetadata.from_schedule_batch(self, batch)
         return self.model.forward(
...
@@ -45,7 +45,7 @@ def normal_text(args):
         "The capital of the United Kindom is",
         "Today is a sunny day and I like",
     ]
-    max_new_tokens = 17
+    max_new_tokens = 16
     torch.cuda.set_device(0)
...
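The `@torch.inference_mode()` decorators removed from `forward_step`, `forward_decode`, `forward_extend`, and `forward_extend_multi_modal` are subsumed by the single decorator added on `exposed_step`: everything invoked from that entry point already runs with autograd disabled. A rough sketch of the pattern (class and method names below are simplified placeholders, not the real ModelTpServer/ModelRunner interfaces):

```python
import torch

class Server:  # placeholder, not the actual ModelTpServer
    @torch.inference_mode()
    def step(self, requests):
        # The inference-mode context established here covers every call made
        # below, so inner forward methods no longer need their own decorators.
        return [self.forward(r) for r in requests]

    def forward(self, request):
        # Hypothetical inner forward pass; it inherits inference mode because
        # it executes inside step()'s context.
        return request
```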