Unverified Commit c3779233 authored by yhyang201, committed by GitHub

[feat] Reduce GPU memory overhead by using weakref (#9673)

parent f84b57c8
 from __future__ import annotations
 
 import abc
-from typing import TYPE_CHECKING, Set, Type
+import weakref
+from typing import TYPE_CHECKING, Optional, Set, Type
 
 import torch
@@ -17,7 +18,7 @@ class BatchedPenalizerOrchestrator:
         penalizers: Set[Type["_BatchedPenalizer"]],
     ):
         self.vocab_size = vocab_size
-        self.batch = batch
+        self._batch_ref = weakref.ref(batch)
         self.device = batch.device
         self.penalizers = {Penalizer: Penalizer(self) for Penalizer in penalizers}
@@ -27,6 +28,17 @@ class BatchedPenalizerOrchestrator:
             is_required |= pen_is_required
         self.is_required = is_required
 
+    @property
+    def batch(self) -> ScheduleBatch | None:
+        return self._batch_ref()
+
+    @batch.setter
+    def batch(self, value: Optional[ScheduleBatch]):
+        if value is None:
+            self._batch_ref = lambda: None
+        else:
+            self._batch_ref = weakref.ref(value)
+
     def reqs(self):
         return self.batch.reqs
...
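For context, below is a minimal, self-contained sketch of the pattern this commit applies: the orchestrator stores only a weak reference to the batch, so the batch (and the GPU tensors it owns) can be freed as soon as the scheduler drops its own strong reference, rather than staying alive through the orchestrator. Class and attribute names follow the diff; the ScheduleBatch stand-in, its fields, and the driver code at the bottom are illustrative assumptions, not sglang's actual implementation.

# Minimal sketch of the weakref pattern, assuming CPython reference counting.
import weakref
from typing import List, Optional


class ScheduleBatch:
    # Hypothetical stand-in for sglang's ScheduleBatch; the real one owns
    # GPU tensors, which is what the weak reference keeps from being pinned.
    def __init__(self, device: str = "cuda"):
        self.device = device
        self.reqs: List[object] = []


class BatchedPenalizerOrchestrator:
    def __init__(self, vocab_size: int, batch: ScheduleBatch):
        self.vocab_size = vocab_size
        # Store a weak reference instead of a strong one: the orchestrator
        # no longer keeps the batch alive after the scheduler releases it.
        self._batch_ref = weakref.ref(batch)
        self.device = batch.device

    @property
    def batch(self) -> Optional[ScheduleBatch]:
        # Dereferencing returns None once the batch has been collected.
        return self._batch_ref()

    @batch.setter
    def batch(self, value: Optional[ScheduleBatch]):
        # Keep `orchestrator.batch = ...` assignable (including None)
        # without ever holding a strong reference.
        self._batch_ref = (lambda: None) if value is None else weakref.ref(value)


# Illustrative usage: once the last strong reference to the batch is gone,
# the orchestrator observes None instead of keeping the batch (and its GPU
# memory) alive. Immediate collection relies on CPython's reference counting.
batch = ScheduleBatch()
orch = BatchedPenalizerOrchestrator(vocab_size=32000, batch=batch)
assert orch.batch is batch
del batch
assert orch.batch is None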