Unverified commit c3779233, authored by yhyang201 and committed by GitHub
Browse files

[feat] Reduce GPU memory overhead by using weakref (#9673)

parent f84b57c8
from __future__ import annotations
import abc
from typing import TYPE_CHECKING, Set, Type
import weakref
from typing import TYPE_CHECKING, Optional, Set, Type
import torch
......@@ -17,7 +18,7 @@ class BatchedPenalizerOrchestrator:
penalizers: Set[Type["_BatchedPenalizer"]],
):
self.vocab_size = vocab_size
self.batch = batch
self._batch_ref = weakref.ref(batch)
self.device = batch.device
self.penalizers = {Penalizer: Penalizer(self) for Penalizer in penalizers}
......@@ -27,6 +28,17 @@ class BatchedPenalizerOrchestrator:
is_required |= pen_is_required
self.is_required = is_required
@property
def batch(self) -> ScheduleBatch | None:
    """Return the batch currently tracked by this orchestrator, if any.

    The batch is not stored directly: ``self._batch_ref`` is a zero-argument
    callable (a ``weakref.ref`` or a stand-in that returns ``None``), and
    this getter returns whatever that callable yields. A ``weakref.ref``
    yields ``None`` once its referent has been garbage collected.
    """
    resolve = self._batch_ref
    return resolve()

@batch.setter
def batch(self, value: Optional[ScheduleBatch]):
    """Point the orchestrator at ``value`` without keeping it alive.

    Args:
        value: The batch to track, or ``None`` to clear the association.
    """
    if value is not None:
        # Hold the batch weakly so this object never extends its lifetime.
        self._batch_ref = weakref.ref(value)
        return
    # weakref.ref(None) is not allowed, so store a callable with the same
    # "call me to get the target" shape that always answers None.
    self._batch_ref = lambda: None
def reqs(self):
    """Return the ``reqs`` attribute of the batch this orchestrator tracks.

    ``self.batch`` may be ``None`` (it is exposed as ``ScheduleBatch | None``
    and resolved through a weak reference), so it is checked before use.

    Returns:
        Whatever ``self.batch.reqs`` holds.

    Raises:
        AttributeError: If ``self.batch`` is ``None``, i.e. the batch was
            cleared or has already been garbage collected.
    """
    current = self.batch
    if current is None:
        # Keep the exception type that the bare attribute access on None
        # would have produced, but name the real cause in the message.
        raise AttributeError(
            f"{type(self).__name__}.batch is None (cleared or already "
            "garbage collected); cannot access 'reqs'"
        )
    return current.reqs
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment