Unverified Commit 4b27c4a4 authored by Sayak Paul, committed by GitHub

[feat] implement `record_stream` when using CUDA streams during group offloading (#11081)



* implement record_stream for better performance.

* fix

* style.

* merge #11097

* Update src/diffusers/hooks/group_offloading.py
Co-authored-by: Aryan <aryan@huggingface.co>

* fixes

* docstring.

* remaining todos in low_cpu_mem_usage

* tests

* updates to docs.

---------
Co-authored-by: Aryan <aryan@huggingface.co>
parent 5d49b3e8
...@@ -178,6 +178,9 @@ pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch
# We can utilize the enable_group_offload method for Diffusers model implementations
pipe.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True)
# Uncomment the following to also allow recording the current streams.
# pipe.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True, record_stream=True)
# For any other model implementations, the apply_group_offloading function can be used
apply_group_offloading(pipe.text_encoder, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=2)
apply_group_offloading(pipe.vae, onload_device=onload_device, offload_type="leaf_level")
...@@ -205,6 +208,7 @@ Group offloading (for CUDA devices with support for asynchronous data transfer s
- The `use_stream` parameter can be used with CUDA devices to enable prefetching layers for onload. It defaults to `False`. Layer prefetching allows overlapping computation and data transfer of model weights, which drastically reduces the overall execution time compared to other offloading methods. However, it can increase the CPU RAM usage significantly. Ensure that the available CPU RAM is at least twice the size of the model when setting `use_stream=True`. You can find more information about CUDA streams [here](https://pytorch.org/docs/stable/generated/torch.cuda.Stream.html)
- If specifying `use_stream=True` on VAEs with tiling enabled, make sure to do a dummy forward pass (possibly with dummy inputs) before the actual inference to avoid device-mismatch errors. This may not work on all implementations. Please open an issue if you encounter any problems.
- The parameter `low_cpu_mem_usage` can be set to `True` to reduce CPU memory usage when using streams for group offloading. This is useful when the CPU memory is the bottleneck, but it may counteract the benefits of using streams and increase the overall execution time. The CPU memory savings come from creating pinned tensors on-the-fly instead of pre-pinning them. This parameter is better suited for `leaf_level` offloading.
- When using `use_stream=True`, users can additionally specify `record_stream=True` to get better speedups at the expense of slightly increased memory usage. Refer to the [official PyTorch docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) to learn more about this. A combined usage sketch is shown below.
For more information about available parameters and an explanation of how group offloading works, refer to [`~hooks.group_offloading.apply_group_offloading`].
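A minimal, illustrative sketch of how these options combine (the model, devices, and dtype are placeholders mirroring the snippet above; the `record_stream=True` flag is the option added in this change):

```python
import torch
from diffusers import CogVideoXPipeline
from diffusers.hooks import apply_group_offloading

onload_device = torch.device("cuda")
offload_device = torch.device("cpu")

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

# Diffusers model implementations expose `enable_group_offload` directly.
# `record_stream=True` only takes effect together with `use_stream=True`.
pipe.transformer.enable_group_offload(
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="leaf_level",
    use_stream=True,
    record_stream=True,
)

# Other components can be wrapped with `apply_group_offloading`.
apply_group_offloading(pipe.text_encoder, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=2)
apply_group_offloading(pipe.vae, onload_device=onload_device, offload_type="leaf_level")

# Run the pipeline as usual; weight groups are onloaded/offloaded around each forward pass.
```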
...
...@@ -56,6 +56,7 @@ class ModuleGroup:
buffers: Optional[List[torch.Tensor]] = None,
non_blocking: bool = False,
stream: Optional[torch.cuda.Stream] = None,
record_stream: Optional[bool] = False,
low_cpu_mem_usage=False,
onload_self: bool = True,
) -> None:
...@@ -68,11 +69,14 @@ class ModuleGroup:
self.buffers = buffers or []
self.non_blocking = non_blocking or stream is not None
self.stream = stream
self.record_stream = record_stream
self.onload_self = onload_self
self.low_cpu_mem_usage = low_cpu_mem_usage
self.cpu_param_dict = self._init_cpu_param_dict()
if self.stream is None and self.record_stream:
raise ValueError("`record_stream` cannot be True when `stream` is None.")
def _init_cpu_param_dict(self):
cpu_param_dict = {}
if self.stream is None:
...@@ -112,6 +116,8 @@ class ModuleGroup:
def onload_(self):
r"""Onloads the group of modules to the onload_device."""
context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
current_stream = torch.cuda.current_stream() if self.record_stream else None
if self.stream is not None:
# Wait for previous Host->Device transfer to complete
self.stream.synchronize()
...@@ -122,14 +128,22 @@ class ModuleGroup:
for group_module in self.modules:
for param in group_module.parameters():
param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
param.data.record_stream(current_stream)
for buffer in group_module.buffers():
buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
buffer.data.record_stream(current_stream)
for param in self.parameters:
param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
param.data.record_stream(current_stream)
for buffer in self.buffers:
buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
buffer.data.record_stream(current_stream)
else:
for group_module in self.modules:
...@@ -143,11 +157,14 @@ class ModuleGroup:
for buffer in self.buffers:
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
buffer.data.record_stream(current_stream)
def offload_(self):
r"""Offloads the group of modules to the offload_device."""
if self.stream is not None:
if not self.record_stream:
torch.cuda.current_stream().synchronize()
for group_module in self.modules:
for param in group_module.parameters():
param.data = self.cpu_param_dict[param]
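For context, here is a standalone sketch (plain PyTorch, not the diffusers implementation) of the `record_stream` pattern that the change above relies on: a tensor copied to the GPU on a side stream is marked as in use by the current stream, so the caching allocator will not recycle its memory prematurely and the explicit `torch.cuda.current_stream().synchronize()` before offloading can be skipped.

```python
import torch

transfer_stream = torch.cuda.Stream()
cpu_weight = torch.randn(1024, 1024).pin_memory()  # pinned memory allows async H2D copies

# Copy the weight to the GPU on the side (transfer) stream.
with torch.cuda.stream(transfer_stream):
    gpu_weight = cpu_weight.to("cuda", non_blocking=True)

# Order the default (compute) stream after the transfer before using the weight.
torch.cuda.current_stream().wait_stream(transfer_stream)

# Tell the allocator the tensor is also used by the current stream, so its memory is
# not reused until that stream's pending work on it has completed.
gpu_weight.record_stream(torch.cuda.current_stream())

out = gpu_weight @ gpu_weight  # compute runs on the current stream
```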
...@@ -331,6 +348,7 @@ def apply_group_offloading(
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
use_stream: bool = False,
record_stream: bool = False,
low_cpu_mem_usage: bool = False,
) -> None:
r"""
...@@ -378,6 +396,10 @@ def apply_group_offloading(
use_stream (`bool`, defaults to `False`):
If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
overlapping computation and data transfer.
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
[PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) for more
details.
low_cpu_mem_usage (`bool`, defaults to `False`):
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
...@@ -417,11 +439,24 @@ def apply_group_offloading(
raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
_apply_group_offloading_block_level(
module=module,
num_blocks_per_group=num_blocks_per_group,
offload_device=offload_device,
onload_device=onload_device,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
)
elif offload_type == "leaf_level":
_apply_group_offloading_leaf_level(
module=module,
offload_device=offload_device,
onload_device=onload_device,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
)
else:
raise ValueError(f"Unsupported offload_type: {offload_type}")
...@@ -434,6 +469,7 @@ def _apply_group_offloading_block_level(
onload_device: torch.device,
non_blocking: bool,
stream: Optional[torch.cuda.Stream] = None,
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
) -> None:
r"""
...@@ -453,6 +489,14 @@ def _apply_group_offloading_block_level(
stream (`torch.cuda.Stream`, *optional*):
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
for overlapping computation and data transfer.
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
[PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) for more
details.
low_cpu_mem_usage (`bool`, defaults to `False`):
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
the CPU memory is a bottleneck but may counteract the benefits of using streams.
""" """
# Create module groups for ModuleList and Sequential blocks # Create module groups for ModuleList and Sequential blocks
...@@ -475,6 +519,7 @@ def _apply_group_offloading_block_level(
onload_leader=current_modules[0],
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=stream is None,
)
...@@ -512,6 +557,7 @@ def _apply_group_offloading_block_level(
buffers=buffers,
non_blocking=False,
stream=None,
record_stream=False,
onload_self=True,
)
next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
...@@ -524,6 +570,7 @@ def _apply_group_offloading_leaf_level(
onload_device: torch.device,
non_blocking: bool,
stream: Optional[torch.cuda.Stream] = None,
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
) -> None:
r"""
...@@ -545,6 +592,14 @@ def _apply_group_offloading_leaf_level(
stream (`torch.cuda.Stream`, *optional*):
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
for overlapping computation and data transfer.
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
[PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) for more
details.
low_cpu_mem_usage (`bool`, defaults to `False`):
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
the CPU memory is a bottleneck but may counteract the benefits of using streams.
""" """
# Create module groups for leaf modules and apply group offloading hooks # Create module groups for leaf modules and apply group offloading hooks
...@@ -560,6 +615,7 @@ def _apply_group_offloading_leaf_level(
onload_leader=submodule,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=True,
)
...@@ -605,6 +661,7 @@ def _apply_group_offloading_leaf_level(
buffers=buffers,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=True,
)
...@@ -624,6 +681,7 @@ def _apply_group_offloading_leaf_level(
buffers=None,
non_blocking=False,
stream=None,
record_stream=False,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=True,
)
...
...@@ -546,6 +546,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
use_stream: bool = False,
record_stream: bool = False,
low_cpu_mem_usage=False,
) -> None:
r"""
...@@ -594,6 +595,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
num_blocks_per_group,
non_blocking,
use_stream,
record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
)
...
...@@ -1525,8 +1525,9 @@ class ModelTesterMixin:
or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE
)
@parameterized.expand([False, True])
@require_torch_gpu
def test_group_offloading(self, record_stream):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
...@@ -1566,7 +1567,9 @@ class ModelTesterMixin:
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.enable_group_offload(
torch_device, offload_type="leaf_level", use_stream=True, record_stream=record_stream
)
output_with_group_offloading4 = run_forward(model)
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5))
...