backend_debug.py 638 Bytes
Newer Older
yangzhong's avatar
yangzhong committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
import backend


class BackendDebug(backend.Backend):
    """Debug backend that returns random tensors instead of running a model.

    Every method is a stand-in: ``load`` does nothing and ``predict`` ignores
    prompt content, producing one ``torch.randn`` tensor per prompt.
    """

    def __init__(self, image_size=None, **kwargs):
        """Create the debug backend.

        Args:
            image_size: Shape passed to ``torch.randn`` for each generated
                image. Defaults to a fresh ``[3, 1024, 1024]`` list per
                instance. An explicitly passed value is stored unchanged.
            **kwargs: Accepted and ignored (presumably so this backend can be
                constructed with the same arguments as other backends --
                confirm against callers).
        """
        super().__init__()
        # None sentinel instead of a list default: a list literal in the
        # signature is created once and would be shared (and mutable) across
        # every default-constructed instance.
        self.image_size = [3, 1024, 1024] if image_size is None else image_size

    def version(self):
        """Return the installed torch version string."""
        return torch.__version__

    def name(self):
        """Return the fixed identifier of this backend."""
        return "debug-SUT"

    def image_format(self):
        """Return the layout string reported for generated images.

        NOTE(review): each tensor from ``predict`` has shape ``image_size``
        (3-D by default, with no batch axis). Confirm that "NCHW" is the
        intended description for callers.
        """
        return "NCHW"

    def load(self):
        """Nothing to load for the debug backend; return ``self`` for chaining."""
        return self

    def predict(self, prompts):
        """Return one random tensor of shape ``self.image_size`` per prompt.

        Args:
            prompts: Iterable of prompts. Only the number of items matters;
                their contents are ignored.

        Returns:
            list of ``torch.Tensor`` drawn from a standard normal distribution,
            in the same order and count as ``prompts``.
        """
        with torch.no_grad():
            return [torch.randn(self.image_size) for _ in prompts]