Unverified Commit 9918d13e authored by MQY's avatar MQY Committed by GitHub
Browse files

fix(training_utils): wrap device in list for DiffusionPipeline (#12178)



- Modify offload_models function to handle DiffusionPipeline correctly
- Ensure compatibility with both single and multiple module inputs
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent e8246604
...@@ -339,7 +339,8 @@ def offload_models( ...@@ -339,7 +339,8 @@ def offload_models(
original_devices = [next(m.parameters()).device for m in modules] original_devices = [next(m.parameters()).device for m in modules]
else: else:
assert len(modules) == 1 assert len(modules) == 1
original_devices = modules[0].device # For DiffusionPipeline, wrap the device in a list to make it iterable
original_devices = [modules[0].device]
# move to target device # move to target device
for m in modules: for m in modules:
m.to(device) m.to(device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment