Unverified Commit cab4064c authored by Wentao Ye's avatar Wentao Ye Committed by GitHub
Browse files

[Bug] Fix workspace manager `_current_workspaces` size (#38853)


Signed-off-by: default avataryewentao256 <zhyanwentao@126.com>
parent 062f1a2d
...@@ -31,7 +31,7 @@ _manager: "WorkspaceManager | None" = None ...@@ -31,7 +31,7 @@ _manager: "WorkspaceManager | None" = None
class WorkspaceManager: class WorkspaceManager:
"""Manager for workspace allocation. """Manager for workspace allocation.
Manages workspace buffers for DBO (Dual Batch Overlap) execution. Manages one workspace buffer per active ubatch slot.
Can be locked to prevent further growth during execution. Can be locked to prevent further growth during execution.
""" """
...@@ -39,7 +39,9 @@ class WorkspaceManager: ...@@ -39,7 +39,9 @@ class WorkspaceManager:
self._device = device self._device = device
# Cache num ubatches at init based on configuration (default to 1) # Cache num ubatches at init based on configuration (default to 1)
self._num_ubatches = num_ubatches if num_ubatches is not None else 1 self._num_ubatches = num_ubatches if num_ubatches is not None else 1
self._current_workspaces: list[torch.Tensor | None] = [None, None] self._current_workspaces: list[torch.Tensor | None] = [
None
] * self._num_ubatches
self._locked: bool = False self._locked: bool = False
@staticmethod @staticmethod
...@@ -224,7 +226,7 @@ def init_workspace_manager( ...@@ -224,7 +226,7 @@ def init_workspace_manager(
Args: Args:
device: The device to allocate workspace on. device: The device to allocate workspace on.
num_ubatches: Number of micro-batches. Defaults to 1. num_ubatches: Number of workspace ubatch slots. Defaults to 1.
""" """
global _manager global _manager
if _manager is not None: if _manager is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment