Commit 46d20d2d authored by Patrick von Platen's avatar Patrick von Platen
Browse files

fix random seed

parent 9c4cd06d
......@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import tempfile
import unittest
......@@ -30,6 +32,22 @@ global_rng = random.Random()
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
def get_random_generator(seed):
seed = 1234
random.seed(seed)
os.environ[PYTHONHASHSEED] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.enabled = False
generator = torch.Generator()
return generator
def parse_flag_from_env(key, default=False):
try:
value = os.environ[key]
......@@ -113,8 +131,7 @@ class SamplerTesterMixin(unittest.TestCase):
@slow
def test_sample(self):
generator = torch.Generator()
generator = generator.manual_seed(6694729458485568)
generator = get_random_generator(0)
# 1. Load models
scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church")
......@@ -163,8 +180,7 @@ class SamplerTesterMixin(unittest.TestCase):
def test_sample_fast(self):
# 1. Load models
generator = torch.Generator()
generator = generator.manual_seed(6694729458485568)
generator = get_random_generator(0)
scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church", timesteps=10)
model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)
......@@ -214,16 +230,14 @@ class PipelineTesterMixin(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
ddpm.save_pretrained(tmpdirname)
new_ddpm = DDPM.from_pretrained(tmpdirname)
generator = torch.Generator()
generator = generator.manual_seed(669472945848556)
generator = torch.manual_seed(0)
image = ddpm(generator=generator)
generator = generator.manual_seed(669472945848556)
generator = generator.manual_seed(0)
new_image = new_ddpm(generator=generator)
assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
@slow
def test_from_pretrained_hub(self):
......@@ -235,12 +249,10 @@ class PipelineTesterMixin(unittest.TestCase):
ddpm.noise_scheduler.num_timesteps = 10
ddpm_from_hub.noise_scheduler.num_timesteps = 10
generator = torch.Generator(device=torch_device)
generator = generator.manual_seed(669472945848556)
generator = torch.manual_seed(0)
image = ddpm(generator=generator)
generator = generator.manual_seed(669472945848556)
generator = generator.manual_seed(0)
new_image = ddpm_from_hub(generator=generator)
assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment