Unverified Commit 2c1ed50f authored by Dhruv Nair, committed by GitHub

Provide option to reduce CPU RAM usage in Group Offload (#11106)

* update

* update

* clean up
parent 15ad97f7
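
The flag added by this commit is threaded end to end, from the model-level API down to ModuleGroup. A minimal usage sketch, assuming the ModelMixin method wrapping apply_group_offloading in the second file is enable_group_offload (as in current diffusers); the checkpoint id is illustrative, and use_stream=True is what makes the flag relevant, since the CPU-side copies only exist on the stream path:

import torch
from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

# low_cpu_mem_usage=True keeps the CPU copies in pageable memory and pins them
# on the fly during onload, trading some transfer speed for lower CPU RAM.
transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
    low_cpu_mem_usage=True,
)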
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from contextlib import nullcontext
+from contextlib import contextmanager, nullcontext
 from typing import Dict, List, Optional, Set, Tuple
 
 import torch
@@ -56,7 +56,7 @@ class ModuleGroup:
         buffers: Optional[List[torch.Tensor]] = None,
         non_blocking: bool = False,
         stream: Optional[torch.cuda.Stream] = None,
-        cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
+        low_cpu_mem_usage: bool = False,
         onload_self: bool = True,
     ) -> None:
         self.modules = modules
@@ -64,15 +64,50 @@ class ModuleGroup:
         self.onload_device = onload_device
         self.offload_leader = offload_leader
         self.onload_leader = onload_leader
-        self.parameters = parameters
-        self.buffers = buffers
+        self.parameters = parameters or []
+        self.buffers = buffers or []
         self.non_blocking = non_blocking or stream is not None
         self.stream = stream
-        self.cpu_param_dict = cpu_param_dict
         self.onload_self = onload_self
+        self.low_cpu_mem_usage = low_cpu_mem_usage
+        self.cpu_param_dict = self._init_cpu_param_dict()
 
-        if self.stream is not None and self.cpu_param_dict is None:
-            raise ValueError("cpu_param_dict must be provided when using stream for data transfer.")
+    def _init_cpu_param_dict(self):
+        cpu_param_dict = {}
+        if self.stream is None:
+            return cpu_param_dict
+
+        for module in self.modules:
+            for param in module.parameters():
+                cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
+            for buffer in module.buffers():
+                cpu_param_dict[buffer] = (
+                    buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
+                )
+
+        for param in self.parameters:
+            cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
+
+        for buffer in self.buffers:
+            cpu_param_dict[buffer] = buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
+
+        return cpu_param_dict
+
+    @contextmanager
+    def _pinned_memory_tensors(self):
+        pinned_dict = {}
+        try:
+            for param, tensor in self.cpu_param_dict.items():
+                if not tensor.is_pinned():
+                    pinned_dict[param] = tensor.pin_memory()
+                else:
+                    pinned_dict[param] = tensor
+
+            yield pinned_dict
+
+        finally:
+            pinned_dict = None
 
     def onload_(self):
         r"""Onloads the group of modules to the onload_device."""
@@ -82,15 +117,30 @@ class ModuleGroup:
             self.stream.synchronize()
 
         with context:
-            for group_module in self.modules:
-                for param in group_module.parameters():
-                    param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
-                for buffer in group_module.buffers():
-                    buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
-            if self.parameters is not None:
-                for param in self.parameters:
-                    param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
-            if self.buffers is not None:
-                for buffer in self.buffers:
-                    buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
+            if self.stream is not None:
+                with self._pinned_memory_tensors() as pinned_memory:
+                    for group_module in self.modules:
+                        for param in group_module.parameters():
+                            param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
+                        for buffer in group_module.buffers():
+                            buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
+
+                    for param in self.parameters:
+                        param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
+
+                    for buffer in self.buffers:
+                        buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
+            else:
+                for group_module in self.modules:
+                    for param in group_module.parameters():
+                        param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
+                    for buffer in group_module.buffers():
+                        buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
+
+                for param in self.parameters:
+                    param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
+
+                for buffer in self.buffers:
+                    buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
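
The reason onload_ routes stream transfers through pinned tensors: Tensor.to(..., non_blocking=True) can only overlap with computation when the host source is page-locked; from pageable memory the copy degrades to a blocking one, which would defeat the prefetching stream. A minimal sketch of the overlap pattern, assuming a CUDA device is available (shapes arbitrary):

import torch

assert torch.cuda.is_available()
transfer_stream = torch.cuda.Stream()

host = torch.randn(4096, 4096).pin_memory()  # page-locked source
with torch.cuda.stream(transfer_stream):
    # Enqueued on the side stream; returns immediately because host is pinned.
    device_tensor = host.to("cuda", non_blocking=True)

# ... compute on the default stream while the copy is in flight ...
transfer_stream.synchronize()  # wait before consuming device_tensor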
@@ -101,21 +151,18 @@ class ModuleGroup:
             for group_module in self.modules:
                 for param in group_module.parameters():
                     param.data = self.cpu_param_dict[param]
-            if self.parameters is not None:
-                for param in self.parameters:
-                    param.data = self.cpu_param_dict[param]
-            if self.buffers is not None:
-                for buffer in self.buffers:
-                    buffer.data = self.cpu_param_dict[buffer]
+            for param in self.parameters:
+                param.data = self.cpu_param_dict[param]
+            for buffer in self.buffers:
+                buffer.data = self.cpu_param_dict[buffer]
+
         else:
             for group_module in self.modules:
                 group_module.to(self.offload_device, non_blocking=self.non_blocking)
-            if self.parameters is not None:
-                for param in self.parameters:
-                    param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
-            if self.buffers is not None:
-                for buffer in self.buffers:
-                    buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
+            for param in self.parameters:
+                param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
+            for buffer in self.buffers:
+                buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
 
 
 class GroupOffloadingHook(ModelHook):
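
Note what the stream branch of offload_ does not do: there is no device-to-CPU copy. The CPU tensor stored in cpu_param_dict at init time is still intact (onload_ created new device tensors instead of mutating it), so offloading reduces to re-pointing param.data at that copy and letting the device tensor be freed. A self-contained illustration of the swap, assuming a CUDA device:

import torch

param = torch.nn.Parameter(torch.randn(8, 8))
cpu_copy = param.data  # retained host tensor, as in cpu_param_dict

# Onload: param.data now references a fresh device tensor; cpu_copy is untouched.
param.data = cpu_copy.to("cuda")

# Offload: no copy back needed, just restore the reference; the device tensor
# becomes unreferenced and its memory can be reclaimed by the caching allocator.
param.data = cpu_copy
assert param.data.device.type == "cpu"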
@@ -284,6 +331,7 @@ def apply_group_offloading(
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
     use_stream: bool = False,
+    low_cpu_mem_usage: bool = False,
 ) -> None:
     r"""
     Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -365,10 +413,12 @@ def apply_group_offloading(
             raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
 
         _apply_group_offloading_block_level(
-            module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream
+            module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
         )
     elif offload_type == "leaf_level":
-        _apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream)
+        _apply_group_offloading_leaf_level(
+            module, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
+        )
     else:
         raise ValueError(f"Unsupported offload_type: {offload_type}")
@@ -380,6 +430,7 @@ def _apply_group_offloading_block_level(
     onload_device: torch.device,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
+    low_cpu_mem_usage: bool = False,
 ) -> None:
     r"""
     This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
@@ -400,11 +451,6 @@ def _apply_group_offloading_block_level(
             for overlapping computation and data transfer.
     """
 
-    # Create a pinned CPU parameter dict for async data transfer if streams are to be used
-    cpu_param_dict = None
-    if stream is not None:
-        cpu_param_dict = _get_pinned_cpu_param_dict(module)
-
     # Create module groups for ModuleList and Sequential blocks
     modules_with_group_offloading = set()
     unmatched_modules = []
@@ -425,7 +471,7 @@ def _apply_group_offloading_block_level(
             onload_leader=current_modules[0],
             non_blocking=non_blocking,
             stream=stream,
-            cpu_param_dict=cpu_param_dict,
+            low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=stream is None,
         )
         matched_module_groups.append(group)
@@ -462,7 +508,6 @@ def _apply_group_offloading_block_level(
         buffers=buffers,
         non_blocking=False,
         stream=None,
-        cpu_param_dict=None,
         onload_self=True,
     )
     next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
@@ -475,6 +520,7 @@ def _apply_group_offloading_leaf_level(
     onload_device: torch.device,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
+    low_cpu_mem_usage: bool = False,
 ) -> None:
     r"""
     This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
@@ -497,11 +543,6 @@ def _apply_group_offloading_leaf_level(
             for overlapping computation and data transfer.
     """
 
-    # Create a pinned CPU parameter dict for async data transfer if streams are to be used
-    cpu_param_dict = None
-    if stream is not None:
-        cpu_param_dict = _get_pinned_cpu_param_dict(module)
-
     # Create module groups for leaf modules and apply group offloading hooks
     modules_with_group_offloading = set()
     for name, submodule in module.named_modules():
@@ -515,7 +556,7 @@ def _apply_group_offloading_leaf_level(
             onload_leader=submodule,
             non_blocking=non_blocking,
             stream=stream,
-            cpu_param_dict=cpu_param_dict,
+            low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
         )
         _apply_group_offloading_hook(submodule, group, None)
@@ -560,7 +601,7 @@ def _apply_group_offloading_leaf_level(
             buffers=buffers,
             non_blocking=non_blocking,
             stream=stream,
-            cpu_param_dict=cpu_param_dict,
+            low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
         )
         _apply_group_offloading_hook(parent_module, group, None)
@@ -579,7 +620,7 @@ def _apply_group_offloading_leaf_level(
         buffers=None,
         non_blocking=False,
         stream=None,
-        cpu_param_dict=None,
+        low_cpu_mem_usage=low_cpu_mem_usage,
         onload_self=True,
     )
     _apply_lazy_group_offloading_hook(module, unmatched_group, None)
@@ -616,17 +657,6 @@ def _apply_lazy_group_offloading_hook(
     registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)
 
 
-def _get_pinned_cpu_param_dict(module: torch.nn.Module) -> Dict[torch.nn.Parameter, torch.Tensor]:
-    cpu_param_dict = {}
-    for param in module.parameters():
-        param.data = param.data.cpu().pin_memory()
-        cpu_param_dict[param] = param.data
-    for buffer in module.buffers():
-        buffer.data = buffer.data.cpu().pin_memory()
-        cpu_param_dict[buffer] = buffer.data
-    return cpu_param_dict
-
-
 def _gather_parameters_with_no_group_offloading_parent(
     module: torch.nn.Module, modules_with_group_offloading: Set[str]
 ) -> List[torch.nn.Parameter]:
@@ -546,6 +546,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         num_blocks_per_group: Optional[int] = None,
         non_blocking: bool = False,
         use_stream: bool = False,
+        low_cpu_mem_usage: bool = False,
     ) -> None:
         r"""
         Activates group offloading for the current model.
@@ -584,7 +585,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 f"open an issue at https://github.com/huggingface/diffusers/issues."
             )
         apply_group_offloading(
-            self, onload_device, offload_device, offload_type, num_blocks_per_group, non_blocking, use_stream
+            self,
+            onload_device,
+            offload_device,
+            offload_type,
+            num_blocks_per_group,
+            non_blocking,
+            use_stream,
+            low_cpu_mem_usage=low_cpu_mem_usage,
         )
 
     def save_pretrained(