"...en/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "30e5e81d58eb9c3979c07e6626bae89c1df8c0e1"
Unverified Commit 5ce4814a authored by Saurabh Misra's avatar Saurabh Misra Committed by GitHub
Browse files

Speed up method `AutoencoderKLWan.clear_cache` by 886% (#11665)

* 

️ Speed up method `AutoencoderKLWan.clear_cache` by 886%

**Key optimizations:**
- Compute the number of `WanCausalConv3d` modules in each model (`encoder`/`decoder`) **only once during initialization**, store in `self._cached_conv_counts`. This removes unnecessary repeated tree traversals at every `clear_cache` call, which was the main bottleneck (from profiling).
- The internal helper `_count_conv3d_fast` is optimized via a generator expression with `sum` for efficiency.

All comments from the original code are preserved, except for updated or removed local docstrings/comments relevant to changed lines.  
**Function signatures and outputs remain unchanged.**

* Apply style fixes

* Apply suggestions from code review
Co-authored-by: default avatarAryan <contact.aryanvs@gmail.com>

* Apply style fixes

---------
Co-authored-by: default avatarcodeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com>
Co-authored-by: default avatargithub-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: default avatarAryan <aryan@huggingface.co>
Co-authored-by: default avatarAryan <contact.aryanvs@gmail.com>
Co-authored-by: default avatarAseem Saxena <aseem.bits@gmail.com>
parent 1bc6f3dc
...@@ -749,6 +749,16 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin): ...@@ -749,6 +749,16 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.tile_sample_stride_height = 192 self.tile_sample_stride_height = 192
self.tile_sample_stride_width = 192 self.tile_sample_stride_width = 192
# Precompute and cache conv counts for encoder and decoder for clear_cache speedup
self._cached_conv_counts = {
"decoder": sum(isinstance(m, WanCausalConv3d) for m in self.decoder.modules())
if self.decoder is not None
else 0,
"encoder": sum(isinstance(m, WanCausalConv3d) for m in self.encoder.modules())
if self.encoder is not None
else 0,
}
def enable_tiling( def enable_tiling(
self, self,
tile_sample_min_height: Optional[int] = None, tile_sample_min_height: Optional[int] = None,
...@@ -801,18 +811,12 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin): ...@@ -801,18 +811,12 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.use_slicing = False self.use_slicing = False
def clear_cache(self):
    """Reset the causal-conv feature caches for both decoder and encoder.

    Uses the conv counts precomputed in ``self._cached_conv_counts`` at
    initialization time, so no module-tree traversal happens here — that
    repeated traversal was the bottleneck this method previously had.
    """
    # Decoder-side cache: one feature slot per WanCausalConv3d module.
    self._conv_num = self._cached_conv_counts["decoder"]
    self._conv_idx = [0]
    self._feat_map = [None] * self._conv_num
    # Encoder-side cache, mirrored structure for the encode path.
    self._enc_conv_num = self._cached_conv_counts["encoder"]
    self._enc_conv_idx = [0]
    self._enc_feat_map = [None] * self._enc_conv_num
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment