test_scheduler.py 2.21 KB
Newer Older
Patrick von Platen's avatar
Patrick von Platen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


Patrick von Platen's avatar
Patrick von Platen committed
17
import torch
Patrick von Platen's avatar
Patrick von Platen committed
18
import numpy as np
Patrick von Platen's avatar
Patrick von Platen committed
19
20
import unittest
import tempfile
Patrick von Platen's avatar
Patrick von Platen committed
21

Patrick von Platen's avatar
Patrick von Platen committed
22
from diffusers import GaussianDDPMScheduler, DDIMScheduler
Patrick von Platen's avatar
Patrick von Platen committed
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52


# Keep CUDA matmuls in full fp32 instead of TF32 — presumably so the numeric
# comparisons in these tests are not perturbed by reduced-precision math.
_cuda_matmul_backend = torch.backends.cuda.matmul
_cuda_matmul_backend.allow_tf32 = False
del _cuda_matmul_backend


class SchedulerCommonTest(unittest.TestCase):
    """Tests shared by every scheduler implementation.

    Concrete subclasses set ``scheduler_class`` and override
    ``get_scheduler_config``. On the base class itself there is no scheduler
    to exercise, so the shared tests skip instead of erroring.
    """

    # Scheduler class under test; concrete subclasses must set this.
    scheduler_class = None

    @property
    def dummy_image(self):
        """Return a random float array of shape (4, 3, 8, 8) with values in [0, 1)."""
        batch_size = 4
        num_channels = 3
        height = 8
        width = 8

        image = np.random.rand(batch_size, num_channels, height, width)

        return image

    def get_scheduler_config(self):
        """Return the config used to build ``scheduler_class``; subclasses must override.

        The shared tests call the returned value with no arguments and pass
        the result to ``scheduler_class``.
        """
        raise NotImplementedError

    def dummy_model(self):
        """Return a toy model function computing ``(image + residual) * t / (t + 1)``."""

        def model(image, residual, t, *args):
            return (image + residual) * t / (t + 1)

        return model

    def _build_scheduler(self):
        """Instantiate the scheduler under test, skipping when none is configured."""
        if self.scheduler_class is None:
            # unittest also collects this base class; with no scheduler the
            # shared tests could only fail with NotImplementedError.
            self.skipTest("`scheduler_class` is not set; SchedulerCommonTest is an abstract base.")
        scheduler_config = self.get_scheduler_config()
        # NOTE(review): the config value is *called* to produce the
        # constructor argument — confirm subclasses return a callable here
        # rather than a plain mapping of keyword arguments.
        return self.scheduler_class(scheduler_config())

    def test_from_pretrained_save_pretrained(self):
        """A scheduler reloaded from its saved config must produce the same output."""
        image = self.dummy_image
        residual = 0.1 * image

        scheduler = self._build_scheduler()

        with tempfile.TemporaryDirectory() as tmpdirname:
            scheduler.save_pretrained(tmpdirname)
            # NOTE(review): saved via `save_pretrained` but reloaded via
            # `from_config`; confirm both use the same on-disk format.
            new_scheduler = self.scheduler_class.from_config(tmpdirname)

        output = scheduler(residual, image, 1)
        new_output = new_scheduler(residual, image, 1)

        # Sum of absolute element-wise differences; works for numpy arrays
        # and for tensor-like outputs that support `-`, `np.abs` and `.sum()`.
        total_abs_diff = float(np.abs(output - new_output).sum())
        self.assertLess(
            total_abs_diff,
            1e-5,
            "reloaded scheduler output differs from the saved scheduler output",
        )

    def test_step(self):
        """Scheduler outputs at t=0 and t=1 keep the shape of the input image."""
        scheduler = self._build_scheduler()

        image = self.dummy_image
        residual = 0.1 * image

        output_0 = scheduler(residual, image, 0)
        output_1 = scheduler(residual, image, 1)

        self.assertEqual(output_0.shape, image.shape)
        self.assertEqual(output_0.shape, output_1.shape)