Unverified commit 1096f88e, authored by Yue Wu, committed by GitHub

sampling bug fix in diffusers tutorial "basic_training.md" (#8223)

sampling bug fix in basic_training.md

In the diffusers basic training tutorial, passing generator=torch.manual_seed(config.seed) to the pipeline call inside the evaluate() function reseeds PyTorch's global random number generator. This rewinds the dataloader shuffling, so the model sees the same sequence of training examples after every evaluation call, which leads to overfitting. Using generator=torch.Generator(device='cpu').manual_seed(config.seed) avoids this by seeding a separate generator that is used only for sampling.
parent cef4a512
@@ -260,7 +260,7 @@ Then, you'll need a way to evaluate the model. For evaluation, you can use the [
 ...     # The default pipeline output type is `List[PIL.Image]`
 ...     images = pipeline(
 ...         batch_size=config.eval_batch_size,
-...         generator=torch.manual_seed(config.seed),
+...         generator=torch.Generator(device='cpu').manual_seed(config.seed), # Use a separate torch generator to avoid rewinding the random state of the main training loop
 ...     ).images
 ...     # Make a grid out of the images
...
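The effect described in the commit message can be reproduced without the tutorial's model or dataset. The sketch below is a minimal illustration, not the tutorial's code: `next_shuffle`, `evaluate_global`, and `evaluate_isolated` are hypothetical helpers, and `torch.randperm` stands in for the shuffling a `DataLoader(shuffle=True)` performs when it draws from the global RNG.

```python
import torch


def next_shuffle(n=8):
    # Hypothetical stand-in for dataloader shuffling: with no generator
    # passed, the permutation is drawn from torch's global default RNG.
    return torch.randperm(n).tolist()


def evaluate_global(seed):
    # Old behaviour: torch.manual_seed reseeds the global RNG and returns it.
    generator = torch.manual_seed(seed)
    return torch.randn(2, generator=generator)


def evaluate_isolated(seed):
    # Fixed behaviour: a private generator leaves the global RNG untouched.
    generator = torch.Generator(device="cpu").manual_seed(seed)
    return torch.randn(2, generator=generator)


torch.manual_seed(0)  # seed once at the start of "training"

# Global reseeding: the shuffle that follows each evaluation is identical.
evaluate_global(123)
order_a = next_shuffle()
evaluate_global(123)
order_b = next_shuffle()
print(order_a == order_b)  # True

# Isolated generator: evaluation does not change the global RNG state,
# yet the evaluation samples themselves stay reproducible.
state_before = torch.get_rng_state()
sample_1 = evaluate_isolated(123)
state_after = torch.get_rng_state()
sample_2 = evaluate_isolated(123)
print(torch.equal(state_before, state_after))  # True
print(torch.equal(sample_1, sample_2))         # True
```

The isolated generator keeps the property the tutorial wants from seeding (identical evaluation samples across epochs) while leaving the training loop's random stream alone.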