test_base_ddpm.py 492 Bytes
Newer Older
limm's avatar
limm committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmgen.models.diffusions import UniformTimeStepSampler


def test_uniform_sampler():
    """Check that ``UniformTimeStepSampler`` returns one timestep per sample.

    The sampler is built over 10 timesteps and asked for 2 samples. The
    result must be a 1-D tensor of length 2 whose values lie in
    ``[0, 10)``. The check runs once with normal call syntax and once with
    an explicit ``__call__`` lookup.
    """
    num_timesteps = 10
    num_samples = 2
    sampler = UniformTimeStepSampler(num_timesteps)

    # Both call forms should dispatch to the same sampling logic.
    for draw in (sampler, sampler.__call__):
        timesteps = draw(num_samples)
        assert timesteps.shape == torch.Size([num_samples])
        assert timesteps.max() < num_timesteps and timesteps.min() >= 0