Unverified commit 9244f27f, authored by Lianmin Zheng and committed by GitHub

[Minor] Improve the style and fix flaky tests (#1584)

parent 2422de51
@@ -747,7 +747,9 @@ class ScheduleBatch:
             return
 
         self.reqs = [self.reqs[i] for i in unfinished_indices]
-        new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
+        new_indices = torch.tensor(
+            unfinished_indices, dtype=torch.int32, device=self.seq_lens.device
+        )
         self.req_pool_indices = self.req_pool_indices[new_indices]
         self.seq_lens = self.seq_lens[new_indices]
         self.out_cache_loc = None
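Note on this hunk: deriving the target device from an existing tensor (`self.seq_lens.device`) instead of hard-coding `"cuda"` keeps the batch-filtering code working on any backend. A minimal sketch of the pattern, with standalone tensors standing in for the batch attributes:

```python
import torch

# Stand-ins for the batch attributes; the real tensors may live on CUDA.
seq_lens = torch.tensor([5, 9, 3])
unfinished_indices = [0, 2]

# Create the index tensor wherever seq_lens already lives, so indexing
# seq_lens[new_indices] never triggers a cross-device error.
new_indices = torch.tensor(
    unfinished_indices, dtype=torch.int32, device=seq_lens.device
)
assert new_indices.device == seq_lens.device
```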
@@ -218,6 +218,7 @@ class PrefillAdder:
         if not insert_sort:
             self.req_states.append((tokens_left, tokens_occupied))
         else:
+            i = 0
             for i in range(len(self.req_states)):
                 if tokens_left <= self.req_states[i][0]:
                     break
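The added `i = 0` is not cosmetic: when `self.req_states` is empty, the loop body never runs, `i` is never bound, and any later use of `i` (presumably an insert at position `i`) raises a `NameError` (`UnboundLocalError` inside a method). A minimal sketch of the edge case, using a hypothetical standalone `states` list:

```python
# Hypothetical standalone version of the insertion-sort branch.
states = []          # empty list: the loop below never binds `i`
tokens_left, tokens_occupied = 7, 0

i = 0                # default keeps the insert below well-defined
for i in range(len(states)):
    if tokens_left <= states[i][0]:
        break
states.insert(i, (tokens_left, tokens_occupied))  # assumed follow-up use of `i`
print(states)        # [(7, 0)]
```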
@@ -144,7 +144,7 @@ class Scheduler:
         )
         self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
 
-        # Get token and memory info from the tp worker
+        # Get token and memory info from the model worker
         (
             self.max_total_num_tokens,
             self.max_prefill_tokens,
@@ -976,7 +976,7 @@ def run_scheduler_process(
     port_args: PortArgs,
     gpu_id: int,
     tp_rank: int,
-    pipe_writer: multiprocessing.connection.Connection,
+    pipe_writer,
 ):
     configure_logger(server_args, prefix=f" TP{tp_rank}")
     suppress_other_loggers()
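On dropping the `pipe_writer` annotation: `multiprocessing.connection` is a lazily imported submodule, so with only `import multiprocessing` in scope, evaluating the annotation at definition time can fail with an `AttributeError` unless the submodule is imported explicitly. Plausibly that is the motivation here; a sketch of the gotcha and the explicit-import fix, with a hypothetical signature:

```python
# The annotation is evaluated when the `def` executes (assuming no
# `from __future__ import annotations`), so the attribute chain
# multiprocessing.connection.Connection must resolve at that moment.
import multiprocessing.connection  # explicit submodule import makes it safe


def run_scheduler_process_sketch(
    pipe_writer: multiprocessing.connection.Connection,  # hypothetical signature
) -> None:
    pipe_writer.send("ready")
```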
@@ -31,10 +31,13 @@ class ReqToTokenPool:
         self.size = size
         self.max_context_len = max_context_len
         self.device = device
-        self.free_slots = list(range(size))
         self.req_to_token = torch.empty(
             (size, max_context_len), dtype=torch.int32, device=device
         )
+        self.free_slots = list(range(size))
 
     def available_size(self):
         return len(self.free_slots)
 
     def alloc(self, need_size: int) -> List[int]:
         if need_size > len(self.free_slots):
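For context, `ReqToTokenPool` manages request slots with a plain free list. A self-contained sketch of that allocator pattern (method names follow the diff; the `free` method and the `None`-on-failure behavior are assumptions):

```python
from typing import List, Optional


class SlotPool:
    """Free-list allocator sketch: slot ids are ints, freed slots are reused."""

    def __init__(self, size: int):
        self.free_slots = list(range(size))

    def available_size(self) -> int:
        return len(self.free_slots)

    def alloc(self, need_size: int) -> Optional[List[int]]:
        if need_size > len(self.free_slots):
            return None  # not enough free slots; caller must handle this
        selected = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return selected

    def free(self, slots: List[int]) -> None:
        self.free_slots.extend(slots)  # assumed API: return slots to the pool
```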
@@ -40,7 +40,7 @@ class SamplingBatchInfo:
     @classmethod
     def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
         reqs = batch.reqs
-        with torch.device("cuda"):
+        with batch.input_ids.device:
             temperatures = torch.tensor(
                 [r.sampling_params.temperature for r in reqs],
                 dtype=torch.float,
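`torch.device` objects double as context managers in PyTorch 2.0 and later: tensors constructed inside the `with` block without an explicit `device=` argument are placed on that device. Using `batch.input_ids.device` therefore inherits whatever device the batch is on instead of pinning `"cuda"`. A quick illustration on CPU so the snippet runs anywhere:

```python
import torch

input_ids = torch.ones(3, dtype=torch.long)  # stand-in for batch.input_ids
with input_ids.device:                       # requires PyTorch >= 2.0
    temperatures = torch.tensor([0.7, 1.0, 0.2], dtype=torch.float)
assert temperatures.device == input_ids.device
```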
@@ -594,7 +594,9 @@ def set_weight_attrs(
 
 def broadcast_pyobj(
-    data: List[Any], rank: int, dist_group: torch.distributed.ProcessGroup
+    data: List[Any],
+    rank: int,
+    dist_group: Optional[torch.distributed.ProcessGroup] = None,
 ):
     """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
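Making `dist_group` optional lets callers fall back to the default process group. For reference, a hedged sketch of how such a pickle-based object broadcast is commonly implemented (not necessarily this repo's function body): rank 0 serializes the objects, broadcasts the byte count, then the bytes; other ranks reverse the steps.

```python
import pickle
from typing import Any, List, Optional

import torch
import torch.distributed as dist


def broadcast_pyobj_sketch(
    data: List[Any],
    rank: int,
    dist_group: Optional[dist.ProcessGroup] = None,  # None -> default group
) -> List[Any]:
    """Broadcast a list of Python objects from rank 0 via pickled bytes.

    Assumes the group supports CPU tensors (e.g. a gloo group).
    """
    if rank == 0:
        payload = pickle.dumps(data)
        size = torch.tensor([len(payload)], dtype=torch.long)
        dist.broadcast(size, src=0, group=dist_group)
        buf = torch.frombuffer(bytearray(payload), dtype=torch.uint8)
        dist.broadcast(buf, src=0, group=dist_group)
        return data
    size = torch.tensor([0], dtype=torch.long)
    dist.broadcast(size, src=0, group=dist_group)
    buf = torch.empty(int(size.item()), dtype=torch.uint8)
    dist.broadcast(buf, src=0, group=dist_group)
    return pickle.loads(buf.numpy().tobytes())
```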
@@ -26,7 +26,7 @@ class TestTritonAttnBackend(unittest.TestCase):
         )
 
         if is_in_ci():
-            assert output_throughput > 154, f"{output_throughput=}"
+            assert output_throughput > 153, f"{output_throughput=}"
 
     def test_mmlu(self):
         model = DEFAULT_MODEL_NAME_FOR_TEST