"src/vscode:/vscode.git/clone" did not exist on "3552279a23e5ba60d599f4aa8e5b90555b7ce559"
Unverified commit f1088e0f, authored by Lianmin Zheng, committed by GitHub
Browse files

Fix memory leak during abort (#1674)

parent 175afed3
...@@ -17,7 +17,7 @@ import json ...@@ -17,7 +17,7 @@ import json
import multiprocessing import multiprocessing
import os import os
import time import time
from typing import Optional, Tuple from typing import Tuple
import numpy as np import numpy as np
import requests import requests
......
...@@ -775,7 +775,7 @@ class Scheduler: ...@@ -775,7 +775,7 @@ class Scheduler:
else: else:
self.tree_cache.cache_unfinished_req(req) self.tree_cache.cache_unfinished_req(req)
self.stream_output(batch) self.stream_output(batch.reqs)
def process_batch_result_decode(self, batch: ScheduleBatch, result): def process_batch_result_decode(self, batch: ScheduleBatch, result):
logits_output, next_token_ids = result logits_output, next_token_ids = result
...@@ -815,7 +815,7 @@ class Scheduler: ...@@ -815,7 +815,7 @@ class Scheduler:
if req.top_logprobs_num > 0: if req.top_logprobs_num > 0:
req.output_top_logprobs.append(logits_output.output_top_logprobs[i]) req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
self.stream_output(batch) self.stream_output(batch.reqs)
self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30) self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0: if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
...@@ -894,7 +894,7 @@ class Scheduler: ...@@ -894,7 +894,7 @@ class Scheduler:
return num_input_logprobs return num_input_logprobs
def stream_output(self, batch: ScheduleBatch): def stream_output(self, reqs: List[Req]):
output_rids = [] output_rids = []
output_meta_info = [] output_meta_info = []
output_finished_reason: List[BaseFinishReason] = [] output_finished_reason: List[BaseFinishReason] = []
...@@ -911,7 +911,7 @@ class Scheduler: ...@@ -911,7 +911,7 @@ class Scheduler:
is_stream_iter = self.decode_forward_ct % self.stream_interval == 0 is_stream_iter = self.decode_forward_ct % self.stream_interval == 0
for req in batch.reqs: for req in reqs:
if req.finished() or ( if req.finished() or (
req.stream and (is_stream_iter or len(req.output_ids) == 1) req.stream and (is_stream_iter or len(req.output_ids) == 1)
): ):
...@@ -1025,8 +1025,9 @@ class Scheduler: ...@@ -1025,8 +1025,9 @@ class Scheduler:
# Delete requests in the running batch # Delete requests in the running batch
if self.running_batch: if self.running_batch:
for req in self.running_batch.reqs: for req in self.running_batch.reqs:
if req.rid == recv_req.rid: if req.rid == recv_req.rid and not req.finished():
req.finished_reason = FINISH_ABORT() req.finished_reason = FINISH_ABORT()
self.tree_cache.cache_finished_req(req)
break break
def update_weights(self, recv_req: UpdateWeightReqInput): def update_weights(self, recv_req: UpdateWeightReqInput):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment