Commit 30b93f8e authored by 王敏's avatar 王敏
Browse files

[fix]解决mtp保存draft prob在例如pd分离场景下的OOM问题

parent 8364249c
......@@ -54,14 +54,22 @@ class DraftProbs(ABC): # type: ignore[call-arg]
# The request id list.
_req_ids: list[str] = []
count = 0
req_id_to_count: dict[str, int] = {}
prune_threshould = 100
def __init__(self, draft_probs, req_ids):
assert len(req_ids) == len(draft_probs)
self.draft_probs = draft_probs
self._req_ids = req_ids
for req_id in req_ids:
self.req_id_to_count[req_id] = self.count
def update(self,
draft_probs: torch.Tensor,
tmp_req_ids: list[str]):
self.count += 1
diff_req_ids = [item for item in self._req_ids if item not in tmp_req_ids]
index = [self._req_ids.index(req_id) for req_id in diff_req_ids]
index_tensor = async_tensor_h2d(
......@@ -71,12 +79,21 @@ class DraftProbs(ABC): # type: ignore[call-arg]
pin_memory=True)
self.draft_probs = self.draft_probs[index_tensor]
self.draft_probs = torch.cat([self.draft_probs, draft_probs])
self._req_ids = diff_req_ids
self._req_ids.extend(tmp_req_ids)
for req_id in tmp_req_ids:
self.req_id_to_count[req_id] = self.count
assert len(self._req_ids) == len(self.draft_probs)
def prune(self, req_ids: list[str]):
if self.count % self.prune_threshould == 0:
for req_id, last_count in self.req_id_to_count.items():
if self.count - last_count >= self.prune_threshould:
req_ids.append(req_id)
self.req_id_to_count = {k: v for k, v in self.req_id_to_count.items() if k not in req_ids}
new_req_ids = [req_id for req_id in self._req_ids if req_id not in req_ids]
if new_req_ids != self._req_ids:
# Batch contents changed - prune removed sequences.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment