"tests/vscode:/vscode.git/clone" did not exist on "334e6434d255013fbf1b3746f683f0a3d4c7822d"
Unverified Commit 9244f27f authored by Lianmin Zheng, committed by GitHub

[Minor] Improve the style and fix flaky tests (#1584)

parent 2422de51
@@ -747,7 +747,9 @@ class ScheduleBatch:
             return
 
         self.reqs = [self.reqs[i] for i in unfinished_indices]
-        new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
+        new_indices = torch.tensor(
+            unfinished_indices, dtype=torch.int32, device=self.seq_lens.device
+        )
         self.req_pool_indices = self.req_pool_indices[new_indices]
         self.seq_lens = self.seq_lens[new_indices]
         self.out_cache_loc = None
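The hunk above swaps the hard-coded device="cuda" for the device of an existing batch tensor, so filtering works wherever the batch lives. A minimal sketch of the pattern (function and variable names here are illustrative, not from the codebase):

import torch

def filter_rows(values, keep):
    # Build the index tensor on the same device as the data tensor instead
    # of hard-coding "cuda", so the code also runs on CPU or other backends.
    idx = torch.tensor(keep, dtype=torch.long, device=values.device)
    return values[idx]

vals = torch.arange(6)               # CPU tensor; a CUDA tensor works the same way
print(filter_rows(vals, [0, 2, 5]))  # tensor([0, 2, 5])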
@@ -218,6 +218,7 @@ class PrefillAdder:
         if not insert_sort:
             self.req_states.append((tokens_left, tokens_occupied))
         else:
+            i = 0
             for i in range(len(self.req_states)):
                 if tokens_left <= self.req_states[i][0]:
                     break
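The added `i = 0` is not a style tweak: when `self.req_states` is empty the `for` body never executes, so a later use of `i` (the insertion position) would raise `UnboundLocalError`. A stripped-down illustration with hypothetical names:

def insert_sorted(states, new_state):
    # Without `i = 0`, an empty `states` list leaves `i` undefined and the
    # `insert` below raises UnboundLocalError.
    i = 0
    for i in range(len(states)):
        if new_state[0] <= states[i][0]:
            break
    states.insert(i, new_state)

states = []
insert_sorted(states, (5, 2))  # empty list: loop is skipped, i stays 0
insert_sorted(states, (3, 1))
print(states)  # [(3, 1), (5, 2)]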
@@ -144,7 +144,7 @@ class Scheduler:
         )
         self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
 
-        # Get token and memory info from the tp worker
+        # Get token and memory info from the model worker
         (
             self.max_total_num_tokens,
             self.max_prefill_tokens,
@@ -976,7 +976,7 @@ def run_scheduler_process(
     port_args: PortArgs,
     gpu_id: int,
     tp_rank: int,
-    pipe_writer: multiprocessing.connection.Connection,
+    pipe_writer,
 ):
     configure_logger(server_args, prefix=f" TP{tp_rank}")
     suppress_other_loggers()
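Dropping the `multiprocessing.connection.Connection` annotation leaves `pipe_writer` untyped, presumably because the concrete pipe class varies across platforms and multiprocessing contexts. A hypothetical caller, assuming the child signals readiness through the writer end:

import multiprocessing as mp

def run_scheduler_process(pipe_writer):  # simplified stand-in for the real signature
    # ... initialize the scheduler, then tell the parent we are ready ...
    pipe_writer.send("ready")

if __name__ == "__main__":
    reader, writer = mp.Pipe(duplex=False)
    proc = mp.Process(target=run_scheduler_process, args=(writer,))
    proc.start()
    print(reader.recv())  # "ready"
    proc.join()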
@@ -31,10 +31,13 @@ class ReqToTokenPool:
         self.size = size
         self.max_context_len = max_context_len
         self.device = device
-        self.free_slots = list(range(size))
         self.req_to_token = torch.empty(
             (size, max_context_len), dtype=torch.int32, device=device
         )
+        self.free_slots = list(range(size))
+
+    def available_size(self):
+        return len(self.free_slots)
 
     def alloc(self, need_size: int) -> List[int]:
         if need_size > len(self.free_slots):
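Besides moving the `free_slots` initialization after the tensor allocation, the hunk adds an `available_size()` accessor. A self-contained sketch of the same free-list pool pattern, with an `alloc`/`free` pair filled in as an assumption (only the lines shown in the hunk come from the source):

from typing import List, Optional
import torch

class SimpleTokenPool:
    """Free-list pool mapping request slots to per-token rows."""

    def __init__(self, size: int, max_context_len: int, device: str = "cpu"):
        self.req_to_token = torch.empty(
            (size, max_context_len), dtype=torch.int32, device=device
        )
        self.free_slots = list(range(size))

    def available_size(self) -> int:
        return len(self.free_slots)

    def alloc(self, need_size: int) -> Optional[List[int]]:
        if need_size > len(self.free_slots):
            return None  # not enough free slots
        selected = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return selected

    def free(self, slots: List[int]) -> None:
        self.free_slots.extend(slots)

pool = SimpleTokenPool(size=4, max_context_len=8)
slots = pool.alloc(2)
print(slots, pool.available_size())  # [0, 1] 2
pool.free(slots)
print(pool.available_size())         # 4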
@@ -40,7 +40,7 @@ class SamplingBatchInfo:
     @classmethod
     def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
         reqs = batch.reqs
-        with torch.device("cuda"):
+        with batch.input_ids.device:
             temperatures = torch.tensor(
                 [r.sampling_params.temperature for r in reqs],
                 dtype=torch.float,
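`with batch.input_ids.device:` works because `torch.device` objects are context managers from PyTorch 2.0 onward: factory calls inside the block allocate on that device by default, which removes another hard-coded CUDA assumption. A minimal sketch:

import torch

t = torch.randn(3)  # CPU here, but the same code works if t lives on a GPU
with t.device:      # torch.device has been a context manager since PyTorch 2.0
    # Factory functions inside the block allocate on t.device by default.
    x = torch.tensor([0.7, 1.0, 1.3], dtype=torch.float)
print(x.device == t.device)  # True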
@@ -594,7 +594,9 @@ def set_weight_attrs(
 
 
 def broadcast_pyobj(
-    data: List[Any], rank: int, dist_group: torch.distributed.ProcessGroup
+    data: List[Any],
+    rank: int,
+    dist_group: Optional[torch.distributed.ProcessGroup] = None,
 ):
     """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
 
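Giving `dist_group` a `None` default lets callers fall back to the default process group. As a sketch only (not the project's implementation), a pickle-then-broadcast helper like this typically serializes on rank 0, broadcasts the byte length, then the payload:

import pickle
from typing import Any, List, Optional

import torch
import torch.distributed as dist

def broadcast_pyobj_sketch(
    data: List[Any],
    rank: int,
    dist_group: Optional[dist.ProcessGroup] = None,  # None -> default group
) -> List[Any]:
    """Broadcast a picklable object list from rank 0.

    Assumes a CPU-capable backend such as gloo has been initialized.
    """
    if rank == 0:
        buf = bytearray(pickle.dumps(data))  # writable buffer for frombuffer
        payload = torch.frombuffer(buf, dtype=torch.uint8)
        length = torch.tensor([payload.numel()], dtype=torch.long)
        dist.broadcast(length, src=0, group=dist_group)
        dist.broadcast(payload, src=0, group=dist_group)
        return data
    length = torch.empty(1, dtype=torch.long)
    dist.broadcast(length, src=0, group=dist_group)
    payload = torch.empty(length.item(), dtype=torch.uint8)
    dist.broadcast(payload, src=0, group=dist_group)
    return pickle.loads(payload.numpy().tobytes())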
@@ -26,7 +26,7 @@ class TestTritonAttnBackend(unittest.TestCase):
         )
 
         if is_in_ci():
-            assert output_throughput > 154, f"{output_throughput=}"
+            assert output_throughput > 153, f"{output_throughput=}"
 
     def test_mmlu(self):
         model = DEFAULT_MODEL_NAME_FOR_TEST