Unverified commit f1088e0f, authored by Lianmin Zheng and committed by GitHub
Browse files

Fix memory leak during abort (#1674)

parent 175afed3
......@@ -17,7 +17,7 @@ import json
import multiprocessing
import os
import time
from typing import Optional, Tuple
from typing import Tuple
import numpy as np
import requests
......
......@@ -775,7 +775,7 @@ class Scheduler:
else:
self.tree_cache.cache_unfinished_req(req)
self.stream_output(batch)
self.stream_output(batch.reqs)
def process_batch_result_decode(self, batch: ScheduleBatch, result):
logits_output, next_token_ids = result
......@@ -815,7 +815,7 @@ class Scheduler:
if req.top_logprobs_num > 0:
req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
self.stream_output(batch)
self.stream_output(batch.reqs)
self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
......@@ -894,7 +894,7 @@ class Scheduler:
return num_input_logprobs
def stream_output(self, batch: ScheduleBatch):
def stream_output(self, reqs: List[Req]):
output_rids = []
output_meta_info = []
output_finished_reason: List[BaseFinishReason] = []
......@@ -911,7 +911,7 @@ class Scheduler:
is_stream_iter = self.decode_forward_ct % self.stream_interval == 0
for req in batch.reqs:
for req in reqs:
if req.finished() or (
req.stream and (is_stream_iter or len(req.output_ids) == 1)
):
......@@ -1025,8 +1025,9 @@ class Scheduler:
# Delete requests in the running batch
if self.running_batch:
for req in self.running_batch.reqs:
if req.rid == recv_req.rid:
if req.rid == recv_req.rid and not req.finished():
req.finished_reason = FINISH_ABORT()
self.tree_cache.cache_finished_req(req)
break
def update_weights(self, recv_req: UpdateWeightReqInput):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment