Unverified Commit ef2ea33c authored by Will Berman, committed by GitHub

VQ-diffusion (#658)



* Changes for VQ-diffusion VQVAE

Allow specifying the embedding dimension of `VQModel`:
by default, `VQModel` sets the embedding dimension to the number of
latent channels. The VQ-diffusion VQVAE uses a smaller embedding
dimension (128) than its number of latent channels (256).

Add AttnDownEncoderBlock2D and AttnUpDecoderBlock2D to the down and up
UNet block helpers; VQ-diffusion's VQVAE uses these two block types.
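
A minimal sketch of the new arguments (values are illustrative, loosely following the dummy config in the tests rather than the real ithq checkpoint):

```python
import torch
from diffusers import VQModel

# Toy VQVAE whose codebook embedding dimension is decoupled from the number of
# latent channels, as VQ-diffusion's VQVAE requires, and whose encoder/decoder
# use the newly exposed attention block types.
vqvae = VQModel(
    in_channels=3,
    out_channels=3,
    block_out_channels=[32, 64],
    down_block_types=["DownEncoderBlock2D", "AttnDownEncoderBlock2D"],
    up_block_types=["AttnUpDecoderBlock2D", "UpDecoderBlock2D"],
    latent_channels=8,
    num_vq_embeddings=16,
    vq_embed_dim=4,  # smaller than latent_channels, mirroring 128 vs. 256 in the real VQVAE
)

with torch.no_grad():
    image = torch.randn(1, 3, 32, 32)
    reconstruction = vqvae(image).sample  # round trip through encoder, quantizer, decoder
print(reconstruction.shape)  # torch.Size([1, 3, 32, 32])
```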

* Changes for VQ-diffusion transformer

Modify attention.py so SpatialTransformer can be used for
VQ-diffusion's transformer.

SpatialTransformer:
- Can now operate over discrete inputs (classes of vector embeddings) as well as continuous.
- `in_channels` was made optional in the constructor, so the two call sites that previously passed it positionally now pass it as a keyword argument
- modified forward pass to take optional timestep embeddings

ImagePositionalEmbeddings:
- added to provide positional embeddings to discrete inputs for latent pixels

BasicTransformerBlock:
- norm layers were made configurable so that VQ-diffusion can use AdaLayerNorm with timestep embeddings
- modified forward pass to take optional timestep embeddings

CrossAttention:
- now may optionally take a bias parameter for its query, key, and value linear layers

FeedForward:
- Internal layers are now configurable

ApproximateGELU:
- Activation function in VQ-diffusion's feedforward layer

AdaLayerNorm:
- Norm layer modified to incorporate timestep embeddings
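
A rough sketch of how these pieces combine for discrete inputs; the configuration mirrors the dummy transformer in the fast tests below, the class ends up named `Transformer2DModel` in the diff, and the context tensor shape is an illustrative assumption:

```python
import torch
from diffusers import Transformer2DModel

num_embed = 12       # number of vector-embedding classes, including the masked class
height = width = 12  # size of the latent pixel grid

transformer = Transformer2DModel(
    num_attention_heads=1,
    attention_head_dim=height * width,
    num_vector_embeds=num_embed,        # discrete-input mode (classes of vector embeddings)
    num_embeds_ada_norm=12,             # AdaLayerNorm conditioned on the timestep
    attention_bias=True,                # bias on the query/key/value projections
    activation_fn="geglu-approximate",  # ApproximateGELU feed-forward
    cross_attention_dim=32,
    sample_size=width,
)

latent_pixels = torch.randint(0, num_embed, (1, height * width))
encoder_hidden_states = torch.randn(1, 77, 32)  # e.g. text-encoder hidden states (shape assumed)
timestep = torch.tensor(1, dtype=torch.long)

with torch.no_grad():
    out = transformer(latent_pixels, encoder_hidden_states, timestep=timestep).sample
# Log-probabilities over the unmasked classes for each latent pixel:
# out.shape == (1, num_embed - 1, height * width)
```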

* Add VQ-diffusion scheduler

* Add VQ-diffusion pipeline
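
End-to-end usage, as exercised by the integration test added in this PR (checkpoint name and call arguments are taken from that test; a sketch, not the only way to call the pipeline):

```python
import torch
from diffusers import VQDiffusionPipeline

pipeline = VQDiffusionPipeline.from_pretrained("microsoft/vq-diffusion-ithq")
pipeline = pipeline.to("cuda")

generator = torch.Generator(device="cuda").manual_seed(0)
output = pipeline(
    "teddy bear playing in the pool",
    truncation_rate=0.86,  # truncation sampling: keep only the most probable classes at each step
    num_images_per_prompt=1,
    generator=generator,
)
output.images[0].save("teddy_bear_pool.png")
```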

* Add VQ-diffusion convert script to diffusers

* Add VQ-diffusion dummy objects

* Add VQ-diffusion markdown docs

* Add VQ-diffusion tests

* some renaming

* some fixes

* more renaming

* correct

* fix typo

* correct weights

* finalize

* fix tests

* Apply suggestions from code review
Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>

* Apply suggestions from code review
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* finish

* finish

* up
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
parent 269109db
@@ -34,6 +34,21 @@ class AutoencoderKL(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class Transformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class UNet1DModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -257,6 +272,21 @@ class ScoreSdeVePipeline(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class VQDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class DDIMScheduler(metaclass=DummyObject):
_backends = ["torch"]
@@ -407,6 +437,21 @@ class ScoreSdeVeScheduler(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class VQDiffusionScheduler(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class EMAModel(metaclass=DummyObject):
_backends = ["torch"]
...
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import numpy as np
import torch
from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionScheduler, VQModel
from diffusers.utils import load_image, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from ...test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False
class VQDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
@property
def num_embed(self):
return 12
@property
def num_embeds_ada_norm(self):
return 12
@property
def dummy_vqvae(self):
torch.manual_seed(0)
model = VQModel(
block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=3,
num_vq_embeddings=self.num_embed,
vq_embed_dim=3,
)
return model
@property
def dummy_tokenizer(self):
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
return tokenizer
@property
def dummy_text_encoder(self):
torch.manual_seed(0)
config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
)
return CLIPTextModel(config)
@property
def dummy_transformer(self):
torch.manual_seed(0)
height = 12
width = 12
model_kwargs = {
"attention_bias": True,
"cross_attention_dim": 32,
"attention_head_dim": height * width,
"num_attention_heads": 1,
"num_vector_embeds": self.num_embed,
"num_embeds_ada_norm": self.num_embeds_ada_norm,
"norm_num_groups": 32,
"sample_size": width,
"activation_fn": "geglu-approximate",
}
model = Transformer2DModel(**model_kwargs)
return model
def test_vq_diffusion(self):
device = "cpu"
vqvae = self.dummy_vqvae
text_encoder = self.dummy_text_encoder
tokenizer = self.dummy_tokenizer
transformer = self.dummy_transformer
scheduler = VQDiffusionScheduler(self.num_embed)
pipe = VQDiffusionPipeline(
vqvae=vqvae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=transformer, scheduler=scheduler
)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
prompt = "teddy bear playing in the pool"
generator = torch.Generator(device=device).manual_seed(0)
output = pipe([prompt], generator=generator, num_inference_steps=2, output_type="np")
image = output.images
generator = torch.Generator(device=device).manual_seed(0)
image_from_tuple = pipe(
[prompt], generator=generator, output_type="np", return_dict=False, num_inference_steps=2
)[0]
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 24, 24, 3)
expected_slice = np.array([0.6583, 0.6410, 0.5325, 0.5635, 0.5563, 0.4234, 0.6008, 0.5491, 0.4880])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
@slow
@require_torch_gpu
class VQDiffusionPipelineIntegrationTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_vq_diffusion(self):
expected_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/vq_diffusion/teddy_bear_pool.png"
)
expected_image = np.array(expected_image, dtype=np.float32) / 255.0
pipeline = VQDiffusionPipeline.from_pretrained("microsoft/vq-diffusion-ithq")
pipeline = pipeline.to(torch_device)
pipeline.set_progress_bar_config(disable=None)
generator = torch.Generator(device=torch_device).manual_seed(0)
output = pipeline(
"teddy bear playing in the pool",
truncation_rate=0.86,
num_images_per_prompt=1,
generator=generator,
output_type="np",
)
image = output.images[0]
assert image.shape == (256, 256, 3)
assert np.abs(expected_image - image).max() < 1e-2
@@ -18,8 +18,9 @@ import unittest
import numpy as np
import torch
from torch import nn
from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU, AttentionBlock, Transformer2DModel
from diffusers.models.embeddings import get_timestep_embedding
from diffusers.models.resnet import Downsample2D, Upsample2D
from diffusers.utils import torch_device
@@ -235,7 +236,7 @@ class AttentionBlockTests(unittest.TestCase):
num_head_channels=1,
rescale_output_factor=1.0,
eps=1e-6,
norm_num_groups=32,
).to(torch_device)
with torch.no_grad():
attention_scores = attentionBlock(sample)
@@ -259,7 +260,7 @@ class AttentionBlockTests(unittest.TestCase):
channels=512,
rescale_output_factor=1.0,
eps=1e-6,
norm_num_groups=32,
).to(torch_device)
with torch.no_grad():
attention_scores = attentionBlock(sample)
@@ -273,22 +274,22 @@ class AttentionBlockTests(unittest.TestCase):
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
class Transformer2DModelTests(unittest.TestCase):
def test_spatial_transformer_default(self):
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
sample = torch.randn(1, 32, 64, 64).to(torch_device)
spatial_transformer_block = Transformer2DModel(
in_channels=32,
num_attention_heads=1,
attention_head_dim=32,
dropout=0.0,
cross_attention_dim=None,
).to(torch_device)
with torch.no_grad():
attention_scores = spatial_transformer_block(sample).sample
assert attention_scores.shape == (1, 32, 64, 64)
output_slice = attention_scores[0, -1, -3:, -3:]
@@ -298,22 +299,22 @@ class SpatialTransformerTests(unittest.TestCase):
)
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
def test_spatial_transformer_cross_attention_dim(self):
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
sample = torch.randn(1, 64, 64, 64).to(torch_device)
spatial_transformer_block = Transformer2DModel(
in_channels=64,
num_attention_heads=2,
attention_head_dim=32,
dropout=0.0,
cross_attention_dim=64,
).to(torch_device)
with torch.no_grad():
context = torch.randn(1, 4, 64).to(torch_device)
attention_scores = spatial_transformer_block(sample, context).sample
assert attention_scores.shape == (1, 64, 64, 64)
output_slice = attention_scores[0, -1, -3:, -3:]
@@ -323,6 +324,44 @@ class SpatialTransformerTests(unittest.TestCase):
)
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
def test_spatial_transformer_timestep(self):
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
num_embeds_ada_norm = 5
sample = torch.randn(1, 64, 64, 64).to(torch_device)
spatial_transformer_block = Transformer2DModel(
in_channels=64,
num_attention_heads=2,
attention_head_dim=32,
dropout=0.0,
cross_attention_dim=64,
num_embeds_ada_norm=num_embeds_ada_norm,
).to(torch_device)
with torch.no_grad():
timestep_1 = torch.tensor(1, dtype=torch.long).to(torch_device)
timestep_2 = torch.tensor(2, dtype=torch.long).to(torch_device)
attention_scores_1 = spatial_transformer_block(sample, timestep=timestep_1).sample
attention_scores_2 = spatial_transformer_block(sample, timestep=timestep_2).sample
assert attention_scores_1.shape == (1, 64, 64, 64)
assert attention_scores_2.shape == (1, 64, 64, 64)
output_slice_1 = attention_scores_1[0, -1, -3:, -3:]
output_slice_2 = attention_scores_2[0, -1, -3:, -3:]
expected_slice_1 = torch.tensor(
[-0.1874, -0.9704, -1.4290, -1.3357, 1.5138, 0.3036, -0.0976, -1.1667, 0.1283], device=torch_device
)
expected_slice_2 = torch.tensor(
[-0.3493, -1.0924, -1.6161, -1.5016, 1.4245, 0.1367, -0.2526, -1.3109, -0.0547], device=torch_device
)
assert torch.allclose(output_slice_1.flatten(), expected_slice_1, atol=1e-3)
assert torch.allclose(output_slice_2.flatten(), expected_slice_2, atol=1e-3)
def test_spatial_transformer_dropout(self):
torch.manual_seed(0)
if torch.cuda.is_available():
@@ -330,18 +369,18 @@ class SpatialTransformerTests(unittest.TestCase):
sample = torch.randn(1, 32, 64, 64).to(torch_device)
spatial_transformer_block = (
Transformer2DModel(
in_channels=32,
num_attention_heads=2,
attention_head_dim=16,
dropout=0.3,
cross_attention_dim=None,
)
.to(torch_device)
.eval()
)
with torch.no_grad():
attention_scores = spatial_transformer_block(sample).sample
assert attention_scores.shape == (1, 32, 64, 64)
output_slice = attention_scores[0, -1, -3:, -3:]
@@ -350,3 +389,107 @@ class SpatialTransformerTests(unittest.TestCase):
[-1.2448, -0.0190, -0.9471, -1.5140, 0.7069, -1.0144, -2.1077, 0.9099, -1.0091], device=torch_device
)
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
@unittest.skipIf(torch_device == "mps", "MPS does not support float64")
def test_spatial_transformer_discrete(self):
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
num_embed = 5
sample = torch.randint(0, num_embed, (1, 32)).to(torch_device)
spatial_transformer_block = (
Transformer2DModel(
num_attention_heads=1,
attention_head_dim=32,
num_vector_embeds=num_embed,
sample_size=16,
)
.to(torch_device)
.eval()
)
with torch.no_grad():
attention_scores = spatial_transformer_block(sample).sample
assert attention_scores.shape == (1, num_embed - 1, 32)
output_slice = attention_scores[0, -2:, -3:]
expected_slice = torch.tensor([-0.8957, -1.8370, -1.3390, -0.9152, -0.5187, -1.1702], device=torch_device)
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
def test_spatial_transformer_default_norm_layers(self):
spatial_transformer_block = Transformer2DModel(num_attention_heads=1, attention_head_dim=32, in_channels=32)
assert spatial_transformer_block.transformer_blocks[0].norm1.__class__ == nn.LayerNorm
assert spatial_transformer_block.transformer_blocks[0].norm2.__class__ == nn.LayerNorm
assert spatial_transformer_block.transformer_blocks[0].norm3.__class__ == nn.LayerNorm
def test_spatial_transformer_ada_norm_layers(self):
spatial_transformer_block = Transformer2DModel(
num_attention_heads=1,
attention_head_dim=32,
in_channels=32,
num_embeds_ada_norm=5,
)
assert spatial_transformer_block.transformer_blocks[0].norm1.__class__ == AdaLayerNorm
assert spatial_transformer_block.transformer_blocks[0].norm2.__class__ == AdaLayerNorm
assert spatial_transformer_block.transformer_blocks[0].norm3.__class__ == nn.LayerNorm
def test_spatial_transformer_default_ff_layers(self):
spatial_transformer_block = Transformer2DModel(
num_attention_heads=1,
attention_head_dim=32,
in_channels=32,
)
assert spatial_transformer_block.transformer_blocks[0].ff.net[0].__class__ == GEGLU
assert spatial_transformer_block.transformer_blocks[0].ff.net[1].__class__ == nn.Dropout
assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == nn.Linear
dim = 32
inner_dim = 128
# First dimension change
assert spatial_transformer_block.transformer_blocks[0].ff.net[0].proj.in_features == dim
# NOTE: inner_dim * 2 because GEGLU
assert spatial_transformer_block.transformer_blocks[0].ff.net[0].proj.out_features == inner_dim * 2
# Second dimension change
assert spatial_transformer_block.transformer_blocks[0].ff.net[2].in_features == inner_dim
assert spatial_transformer_block.transformer_blocks[0].ff.net[2].out_features == dim
def test_spatial_transformer_geglu_approx_ff_layers(self):
spatial_transformer_block = Transformer2DModel(
num_attention_heads=1,
attention_head_dim=32,
in_channels=32,
activation_fn="geglu-approximate",
)
assert spatial_transformer_block.transformer_blocks[0].ff.net[0].__class__ == ApproximateGELU
assert spatial_transformer_block.transformer_blocks[0].ff.net[1].__class__ == nn.Dropout
assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == nn.Linear
dim = 32
inner_dim = 128
# First dimension change
assert spatial_transformer_block.transformer_blocks[0].ff.net[0].proj.in_features == dim
assert spatial_transformer_block.transformer_blocks[0].ff.net[0].proj.out_features == inner_dim
# Second dimension change
assert spatial_transformer_block.transformer_blocks[0].ff.net[2].in_features == inner_dim
assert spatial_transformer_block.transformer_blocks[0].ff.net[2].out_features == dim
def test_spatial_transformer_attention_bias(self):
spatial_transformer_block = Transformer2DModel(
num_attention_heads=1, attention_head_dim=32, in_channels=32, attention_bias=True
)
assert spatial_transformer_block.transformer_blocks[0].attn1.to_q.bias is not None
assert spatial_transformer_block.transformer_blocks[0].attn1.to_k.bias is not None
assert spatial_transformer_block.transformer_blocks[0].attn1.to_v.bias is not None
@@ -19,6 +19,7 @@ from typing import Dict, List, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from diffusers import (
DDIMScheduler,
@@ -29,6 +30,7 @@ from diffusers import (
LMSDiscreteScheduler,
PNDMScheduler,
ScoreSdeVeScheduler,
VQDiffusionScheduler,
)
from diffusers.utils import torch_device
@@ -85,12 +87,18 @@ class SchedulerCommonTest(unittest.TestCase):
if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
time_step = float(time_step)
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)
if scheduler_class == VQDiffusionScheduler:
num_vec_classes = scheduler_config["num_vec_classes"]
sample = self.dummy_sample(num_vec_classes)
model = self.dummy_model(num_vec_classes)
residual = model(sample, time_step)
else:
sample = self.dummy_sample
residual = 0.1 * sample
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname)
@@ -122,12 +130,18 @@ class SchedulerCommonTest(unittest.TestCase):
if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
time_step = float(time_step)
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
if scheduler_class == VQDiffusionScheduler:
num_vec_classes = scheduler_config["num_vec_classes"]
sample = self.dummy_sample(num_vec_classes)
model = self.dummy_model(num_vec_classes)
residual = model(sample, time_step)
else:
sample = self.dummy_sample
residual = 0.1 * sample
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname)
@@ -154,15 +168,21 @@ class SchedulerCommonTest(unittest.TestCase):
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
timestep = 1
if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
timestep = float(timestep)
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
if scheduler_class == VQDiffusionScheduler:
num_vec_classes = scheduler_config["num_vec_classes"]
sample = self.dummy_sample(num_vec_classes)
model = self.dummy_model(num_vec_classes)
residual = model(sample, timestep)
else:
sample = self.dummy_sample
residual = 0.1 * sample
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
@@ -200,6 +220,12 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
if scheduler_class == VQDiffusionScheduler:
num_vec_classes = scheduler_config["num_vec_classes"]
sample = self.dummy_sample(num_vec_classes)
model = self.dummy_model(num_vec_classes)
residual = model(sample, timestep_0)
else:
sample = self.dummy_sample
residual = 0.1 * sample
@@ -255,6 +281,12 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
if scheduler_class == VQDiffusionScheduler:
num_vec_classes = scheduler_config["num_vec_classes"]
sample = self.dummy_sample(num_vec_classes)
model = self.dummy_model(num_vec_classes)
residual = model(sample, timestep)
else:
sample = self.dummy_sample
residual = 0.1 * sample
@@ -284,19 +316,23 @@ class SchedulerCommonTest(unittest.TestCase):
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
if scheduler_class != VQDiffusionScheduler:
self.assertTrue(
hasattr(scheduler, "init_noise_sigma"),
f"{scheduler_class} does not implement a required attribute `init_noise_sigma`",
)
self.assertTrue(
hasattr(scheduler, "scale_model_input"),
f"{scheduler_class} does not implement a required class method `scale_model_input(sample,"
" timestep)`",
)
self.assertTrue(
hasattr(scheduler, "step"),
f"{scheduler_class} does not implement a required class method `step(...)`",
)
if scheduler_class != VQDiffusionScheduler:
sample = self.dummy_sample
scaled_sample = scheduler.scale_model_input(sample, 0.0)
self.assertEqual(sample.shape, scaled_sample.shape)
@@ -1238,3 +1274,53 @@ class IPNDMSchedulerTest(SchedulerCommonTest):
result_mean = torch.mean(torch.abs(sample))
assert abs(result_mean.item() - 2540529) < 10
class VQDiffusionSchedulerTest(SchedulerCommonTest):
scheduler_classes = (VQDiffusionScheduler,)
def get_scheduler_config(self, **kwargs):
config = {
"num_vec_classes": 4097,
"num_train_timesteps": 100,
}
config.update(**kwargs)
return config
def dummy_sample(self, num_vec_classes):
batch_size = 4
height = 8
width = 8
sample = torch.randint(0, num_vec_classes, (batch_size, height * width))
return sample
@property
def dummy_sample_deter(self):
assert False
def dummy_model(self, num_vec_classes):
def model(sample, t, *args):
batch_size, num_latent_pixels = sample.shape
logits = torch.rand((batch_size, num_vec_classes - 1, num_latent_pixels))
return_value = F.log_softmax(logits.double(), dim=1).float()
return return_value
return model
def test_timesteps(self):
for timesteps in [2, 5, 100, 1000]:
self.check_over_configs(num_train_timesteps=timesteps)
def test_num_vec_classes(self):
for num_vec_classes in [5, 100, 1000, 4000]:
self.check_over_configs(num_vec_classes=num_vec_classes)
def test_time_indices(self):
for t in [0, 50, 99]:
self.check_over_forward(time_step=t)
def test_add_noise_device(self):
pass