Unverified Commit c1dc2ae6 authored by Tolga Cangöz's avatar Tolga Cangöz Committed by GitHub
Browse files

Fix multi-gpu case for `train_cm_ct_unconditional.py` (#8653)

* Fix multi-gpu case

* Prefer previously created `unwrap_model()` function

For `torch.compile()` generalizability

* `chore: update unwrap_model() function to use accelerator.unwrap_model()`
parent e15a8e7f
......@@ -1195,7 +1195,7 @@ def main(args):
# Resolve the c parameter for the Pseudo-Huber loss
if args.huber_c is None:
args.huber_c = 0.00054 * args.resolution * math.sqrt(unet.config.in_channels)
args.huber_c = 0.00054 * args.resolution * math.sqrt(unwrap_model(unet).config.in_channels)
# Get current number of discretization steps N according to our discretization curriculum
current_discretization_steps = get_discretization_steps(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment