utils_eval.py
import torch


def process_inference_batch(cfg_scale, batch, model, with_uncondition_extra=False):
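    """Prepare a batch for classifier-free-guidance sampling.

    Tensor entries in ``batch`` are moved to the model's device and dtype.
    Returns the conditioning dict, the unconditional dict (``None`` when
    ``cfg_scale == 1.0``), optionally an extra unconditional dict, and the
    reconstructed target images.
    """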
    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            batch[k] = v.to(model.device, dtype=model.dtype)
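    # Encode the batch into latents, conditioning, and reconstructed target images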
    z, cond, x_rec = model.get_batch_input(
        batch,
        random_drop_training_conditions=False,
        return_reconstructed_target_images=True,
    )
    # batch_size = x_rec.shape[0]
    # Get unconditioned embedding for classifier-free guidance sampling
    if cfg_scale != 1.0:
        uc = model.get_unconditional_dict_for_sampling(batch, cond, x_rec)
    else:
        uc = None

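    # Optionally also return the "extra" variant of the unconditional dict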
    if with_uncondition_extra:
        uc_extra = model.get_unconditional_dict_for_sampling(
            batch, cond, x_rec, is_extra=True
        )
        return cond, uc, uc_extra, x_rec
    else:
        return cond, uc, x_rec
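

# Usage sketch (assumes `model` and `batch` come from the surrounding
# evaluation pipeline; the guidance scale below is illustrative only):
#
#     cond, uc, x_rec = process_inference_batch(cfg_scale=7.5, batch=batch, model=model)
#     # uc is None when cfg_scale == 1.0, i.e. guidance is effectively disabled.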