Unverified Commit 6d5ef87e authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[DDPM] Make DDPM work (#88)

* up

* finish

* uP
parent e7fe901e
......@@ -50,6 +50,8 @@ from diffusers.testing_utils import floats_tensor, slow, torch_device
from diffusers.training_utils import EMAModel
# 1. LDM
def test_output_pretrained_ldm_dummy():
model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", ldm=True)
model.eval()
......@@ -86,9 +88,38 @@ def test_output_pretrained_ldm():
import ipdb; ipdb.set_trace()
# To see the how the final model should look like
test_output_pretrained_ldm_dummy()
test_output_pretrained_ldm()
# => this is the architecture in which the model should be saved in the new format
# -> verify new repo with the following tests (in `test_modeling_utils.py`)
# - test_ldm_uncond (in PipelineTesterMixin)
# - test_output_pretrained ( in UNetLDMModelTests)
#test_output_pretrained_ldm_dummy()
#test_output_pretrained_ldm()
# 2. DDPM
def get_model(model_id):
model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy", ldm=True)
noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
time_step = torch.tensor([10] * noise.shape[0])
with torch.no_grad():
output = model(noise, time_step)
print(model)
# Repos to convert and port to google (part of https://github.com/hojonathanho/diffusion)
# - fusing/ddpm_dummy
# - fusing/ddpm-cifar10
# - https://huggingface.co/fusing/ddpm-lsun-church-ema
# - https://huggingface.co/fusing/ddpm-lsun-bedroom-ema
# - https://huggingface.co/fusing/ddpm-celeba-hq
# tests to make sure to pass
# - test_ddim_cifar10, test_ddim_lsun, test_ddpm_cifar10, test_ddim_cifar10 (in PipelineTesterMixin)
# - test_output_pretrained ( in UNetModelTests)
# e.g.
get_model("fusing/ddpm-cifar10")
......@@ -492,15 +492,17 @@ class ModelMixin(torch.nn.Module):
)
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
if False:
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
" with another architecture (e.g. initializing a BertForSequenceClassification model from a"
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
" or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
" BertForPreTraining model).\n- This IS NOT expected if you are initializing"
f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
" (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
" identical (initializing a BertForSequenceClassification model from a"
" BertForSequenceClassification model)."
)
else:
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
......@@ -513,9 +515,9 @@ class ModelMixin(torch.nn.Module):
elif len(mismatched_keys) == 0:
logger.info(
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
" training."
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
" without further training."
)
if len(mismatched_keys) > 0:
mismatched_warning = "\n".join(
......@@ -527,8 +529,8 @@ class ModelMixin(torch.nn.Module):
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
" to use it for predictions and inference."
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
" able to use it for predictions and inference."
)
return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
......
......@@ -166,188 +166,6 @@ class AttentionBlock(nn.Module):
return result
class AttentionBlockNew_2(nn.Module):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
def __init__(
self,
channels,
num_head_channels=1,
num_groups=32,
encoder_channels=None,
rescale_output_factor=1.0,
eps=1e-5,
):
super().__init__()
self.channels = channels
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
self.qkv = nn.Conv1d(channels, channels * 3, 1)
self.n_heads = channels // num_head_channels
self.num_head_size = num_head_channels
self.rescale_output_factor = rescale_output_factor
if encoder_channels is not None:
self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)
self.proj = zero_module(nn.Conv1d(channels, channels, 1))
# ------------------------- new -----------------------
num_heads = self.n_heads
self.channels = channels
if num_head_channels is None:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
# define q,k,v as linear layers
self.query = nn.Linear(channels, channels)
self.key = nn.Linear(channels, channels)
self.value = nn.Linear(channels, channels)
self.rescale_output_factor = rescale_output_factor
self.proj_attn = zero_module(nn.Linear(channels, channels, 1))
# ------------------------- new -----------------------
def set_weight(self, attn_layer):
self.norm.weight.data = attn_layer.norm.weight.data
self.norm.bias.data = attn_layer.norm.bias.data
self.qkv.weight.data = attn_layer.qkv.weight.data
self.qkv.bias.data = attn_layer.qkv.bias.data
self.proj.weight.data = attn_layer.proj.weight.data
self.proj.bias.data = attn_layer.proj.bias.data
if hasattr(attn_layer, "q"):
module = attn_layer
qkv_weight = torch.cat([module.q.weight.data, module.k.weight.data, module.v.weight.data], dim=0)[
:, :, :, 0
]
qkv_bias = torch.cat([module.q.bias.data, module.k.bias.data, module.v.bias.data], dim=0)
self.qkv.weight.data = qkv_weight
self.qkv.bias.data = qkv_bias
proj_out = zero_module(nn.Conv1d(self.channels, self.channels, 1))
proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
proj_out.bias.data = module.proj_out.bias.data
self.proj = proj_out
self.set_weights_2(attn_layer)
def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
new_projection_shape = projection.size()[:-1] + (self.n_heads, self.num_head_size)
# move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
return new_projection
def set_weights_2(self, attn_layer):
self.group_norm.weight.data = attn_layer.norm.weight.data
self.group_norm.bias.data = attn_layer.norm.bias.data
qkv_weight = attn_layer.qkv.weight.data.reshape(self.n_heads, 3 * self.channels // self.n_heads, self.channels)
qkv_bias = attn_layer.qkv.bias.data.reshape(self.n_heads, 3 * self.channels // self.n_heads)
q_w, k_w, v_w = qkv_weight.split(self.channels // self.n_heads, dim=1)
q_b, k_b, v_b = qkv_bias.split(self.channels // self.n_heads, dim=1)
self.query.weight.data = q_w.reshape(-1, self.channels)
self.key.weight.data = k_w.reshape(-1, self.channels)
self.value.weight.data = v_w.reshape(-1, self.channels)
self.query.bias.data = q_b.reshape(-1)
self.key.bias.data = k_b.reshape(-1)
self.value.bias.data = v_b.reshape(-1)
self.proj_attn.weight.data = attn_layer.proj.weight.data[:, :, 0]
self.proj_attn.bias.data = attn_layer.proj.bias.data
def forward_2(self, hidden_states):
residual = hidden_states
batch, channel, height, width = hidden_states.shape
# norm
hidden_states = self.group_norm(hidden_states)
hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
# proj to q, k, v
query_proj = self.query(hidden_states)
key_proj = self.key(hidden_states)
value_proj = self.value(hidden_states)
# transpose
query_states = self.transpose_for_scores(query_proj)
key_states = self.transpose_for_scores(key_proj)
value_states = self.transpose_for_scores(value_proj)
# get scores
attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.channels // self.n_heads)
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# compute attention output
context_states = torch.matmul(attention_probs, value_states)
context_states = context_states.permute(0, 2, 1, 3).contiguous()
new_context_states_shape = context_states.size()[:-2] + (self.channels,)
context_states = context_states.view(new_context_states_shape)
# compute next hidden_states
hidden_states = self.proj_attn(context_states)
hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
# res connect and rescale
hidden_states = (hidden_states + residual) / self.rescale_output_factor
return hidden_states
def forward(self, x, encoder_out=None):
b, c, *spatial = x.shape
hid_states = self.norm(x).view(b, c, -1)
qkv = self.qkv(hid_states)
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
if encoder_out is not None:
encoder_kv = self.encoder_kv(encoder_out)
assert encoder_kv.shape[1] == self.n_heads * ch * 2
ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
k = torch.cat([ek, k], dim=-1)
v = torch.cat([ev, v], dim=-1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
a = torch.einsum("bts,bcs->bct", weight, v)
h = a.reshape(bs, -1, length)
h = self.proj(h)
h = h.reshape(b, c, *spatial)
result = x + h
result = result / self.rescale_output_factor
result_2 = self.forward_2(x)
print((result - result_2).abs().sum())
return result_2
class AttentionBlockNew(nn.Module):
"""
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
......@@ -387,7 +205,7 @@ class AttentionBlockNew(nn.Module):
self.proj_attn = zero_module(nn.Linear(channels, channels, 1))
def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
new_projection_shape = projection.size()[:-1] + (self.num_heads, self.num_head_size)
new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
# move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
return new_projection
......@@ -434,6 +252,18 @@ class AttentionBlockNew(nn.Module):
self.group_norm.weight.data = attn_layer.norm.weight.data
self.group_norm.bias.data = attn_layer.norm.bias.data
if hasattr(attn_layer, "q"):
self.query.weight.data = attn_layer.q.weight.data[:, :, 0, 0]
self.key.weight.data = attn_layer.k.weight.data[:, :, 0, 0]
self.value.weight.data = attn_layer.v.weight.data[:, :, 0, 0]
self.query.bias.data = attn_layer.q.bias.data
self.key.bias.data = attn_layer.k.bias.data
self.value.bias.data = attn_layer.v.bias.data
self.proj_attn.weight.data = attn_layer.proj_out.weight.data[:, :, 0, 0]
self.proj_attn.bias.data = attn_layer.proj_out.bias.data
else:
qkv_weight = attn_layer.qkv.weight.data.reshape(
self.num_heads, 3 * self.channels // self.num_heads, self.channels
)
......
......@@ -87,12 +87,21 @@ class Downsample2D(nn.Module):
self.conv = conv
def forward(self, x):
# print("use_conv", self.use_conv)
# print("padding", self.padding)
assert x.shape[1] == self.channels
if self.use_conv and self.padding == 0:
pad = (0, 1, 0, 1)
x = F.pad(x, pad, mode="constant", value=0)
return self.conv(x)
# print("x", x.abs().sum())
self.hey = x
assert x.shape[1] == self.channels
x = self.conv(x)
self.yas = x
# print("x", x.abs().sum())
return x
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
......
......@@ -177,9 +177,7 @@ class UNetModel(ModelMixin, ConfigMixin):
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
print("hs", hs[-1].abs().sum())
h = self.mid_new(hs[-1], temb)
print("h", h.abs().sum())
# upsampling
for i_level in reversed(range(self.num_resolutions)):
......
......@@ -51,6 +51,7 @@ def get_down_block(
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
downsample_padding=downsample_padding,
attn_num_head_channels=attn_num_head_channels,
)
......@@ -186,6 +187,7 @@ class UNetResAttnDownBlock2D(nn.Module):
attn_num_head_channels=1,
attention_type="default",
output_scale_factor=1.0,
downsample_padding=1,
add_downsample=True,
):
super().__init__()
......@@ -224,7 +226,11 @@ class UNetResAttnDownBlock2D(nn.Module):
if add_downsample:
self.downsamplers = nn.ModuleList(
[Downsample2D(in_channels, use_conv=True, out_channels=out_channels, padding=1, name="op")]
[
Downsample2D(
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
else:
self.downsamplers = None
......
......@@ -94,25 +94,6 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
):
super().__init__()
# DELETE if statements if not necessary anymore
# DDPM
if ddpm:
out_channels = out_ch
image_size = resolution
block_channels = [x * ch for x in ch_mult]
conv_resample = resamp_with_conv
flip_sin_to_cos = False
downscale_freq_shift = 1
resnet_eps = 1e-6
block_channels = (32, 64)
down_blocks = (
"UNetResDownBlock2D",
"UNetResAttnDownBlock2D",
)
up_blocks = ("UNetResUpBlock2D", "UNetResAttnUpBlock2D")
downsample_padding = 0
num_head_channels = 64
# register all __init__ params with self.register
self.register_to_config(
image_size=image_size,
......@@ -250,6 +231,10 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
out_channels,
)
if ddpm:
out_channels = out_ch
image_size = resolution
block_channels = [x * ch for x in ch_mult]
conv_resample = resamp_with_conv
self.init_for_ddpm(
ch_mult,
ch,
......@@ -290,13 +275,11 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
# append to tuple
down_block_res_samples += res_samples
print("sample", sample.abs().sum())
# 4. mid block
if self.config.ddpm:
sample = self.mid_new_2(sample, emb)
else:
sample = self.mid(sample, emb)
print("sample", sample.abs().sum())
# 5. up blocks
for upsample_block in self.upsample_blocks:
......@@ -373,8 +356,10 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
elif self.config.ddpm:
# =============== SET WEIGHTS ===============
# =============== TIME ======================
self.time_embed[0] = self.temb.dense[0]
self.time_embed[2] = self.temb.dense[1]
self.time_embedding.linear_1.weight.data = self.temb.dense[0].weight.data
self.time_embedding.linear_1.bias.data = self.temb.dense[0].bias.data
self.time_embedding.linear_2.weight.data = self.temb.dense[1].weight.data
self.time_embedding.linear_2.bias.data = self.temb.dense[1].bias.data
for i, block in enumerate(self.down):
if hasattr(block, "downsample"):
......@@ -391,6 +376,23 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
self.mid_new_2.resnets[1].set_weight(self.mid.block_2)
self.mid_new_2.attentions[0].set_weight(self.mid.attn_1)
for i, block in enumerate(self.up):
k = len(self.up) - 1 - i
if hasattr(block, "upsample"):
self.upsample_blocks[k].upsamplers[0].conv.weight.data = block.upsample.conv.weight.data
self.upsample_blocks[k].upsamplers[0].conv.bias.data = block.upsample.conv.bias.data
if hasattr(block, "block") and len(block.block) > 0:
for j in range(self.num_res_blocks + 1):
self.upsample_blocks[k].resnets[j].set_weight(block.block[j])
if hasattr(block, "attn") and len(block.attn) > 0:
for j in range(self.num_res_blocks + 1):
self.upsample_blocks[k].attentions[j].set_weight(block.attn[j])
self.conv_norm_out.weight.data = self.norm_out.weight.data
self.conv_norm_out.bias.data = self.norm_out.bias.data
self.remove_ddpm()
def init_for_ddpm(
self,
ch_mult,
......@@ -685,3 +687,10 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
del self.middle_block
del self.output_blocks
del self.out
def remove_ddpm(self):
del self.temb
del self.down
del self.mid_new
del self.up
del self.norm_out
......@@ -40,7 +40,6 @@ from diffusers import (
ScoreSdeVpPipeline,
ScoreSdeVpScheduler,
UNetLDMModel,
UNetModel,
UNetUnconditionalModel,
VQModel,
)
......@@ -209,7 +208,7 @@ class ModelTesterMixin:
class UnetModelTests(ModelTesterMixin, unittest.TestCase):
model_class = UNetModel
model_class = UNetUnconditionalModel
@property
def dummy_input(self):
......@@ -234,15 +233,24 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
init_dict = {
"ch": 32,
"ch_mult": (1, 2),
"block_channels": (32, 64),
"down_blocks": ("UNetResDownBlock2D", "UNetResAttnDownBlock2D"),
"up_blocks": ("UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
"num_head_channels": None,
"out_channels": 3,
"in_channels": 3,
"num_res_blocks": 2,
"attn_resolutions": (16,),
"resolution": 32,
"image_size": 32,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_from_pretrained_hub(self):
model, loading_info = UNetModel.from_pretrained("fusing/ddpm_dummy", output_loading_info=True)
model, loading_info = UNetUnconditionalModel.from_pretrained(
"fusing/ddpm_dummy", output_loading_info=True, ddpm=True
)
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
......@@ -252,27 +260,6 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
assert image is not None, "Make sure output is not None"
def test_output_pretrained(self):
model = UNetModel.from_pretrained("fusing/ddpm_dummy")
model.eval()
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
noise = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution)
time_step = torch.tensor([10])
with torch.no_grad():
output = model(noise, time_step)
output_slice = output[0, -1, -3:, -3:].flatten()
# fmt: off
expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053])
# fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
print("Original success!!!")
model = UNetUnconditionalModel.from_pretrained("fusing/ddpm_dummy", ddpm=True)
model.eval()
......@@ -849,7 +836,9 @@ class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase):
class PipelineTesterMixin(unittest.TestCase):
def test_from_pretrained_save_pretrained(self):
# 1. Load models
model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32)
model = UNetUnconditionalModel(
ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32, ddpm=True
)
schedular = DDPMScheduler(timesteps=10)
ddpm = DDPMPipeline(model, schedular)
......@@ -888,7 +877,7 @@ class PipelineTesterMixin(unittest.TestCase):
def test_ddpm_cifar10(self):
model_id = "fusing/ddpm-cifar10"
unet = UNetModel.from_pretrained(model_id)
unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
noise_scheduler = DDPMScheduler.from_config(model_id)
noise_scheduler = noise_scheduler.set_format("pt")
......@@ -901,7 +890,28 @@ class PipelineTesterMixin(unittest.TestCase):
assert image.shape == (1, 3, 32, 32)
expected_slice = torch.tensor(
[-0.5712, -0.6215, -0.5953, -0.5438, -0.4775, -0.4539, -0.5172, -0.4872, -0.5105]
[-0.1601, -0.2823, -0.6123, -0.2305, -0.3236, -0.4706, -0.1691, -0.2836, -0.3231]
)
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
@slow
def test_ddim_lsun(self):
model_id = "fusing/ddpm-lsun-bedroom-ema"
unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
noise_scheduler = DDIMScheduler.from_config(model_id)
noise_scheduler = noise_scheduler.set_format("pt")
ddpm = DDIMPipeline(unet=unet, noise_scheduler=noise_scheduler)
generator = torch.manual_seed(0)
image = ddpm(generator=generator)
image_slice = image[0, -1, -3:, -3:].cpu()
assert image.shape == (1, 3, 256, 256)
expected_slice = torch.tensor(
[-0.9879, -0.9598, -0.9312, -0.9953, -0.9963, -0.9995, -0.9957, -1.0000, -0.9863]
)
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
......@@ -909,7 +919,7 @@ class PipelineTesterMixin(unittest.TestCase):
def test_ddim_cifar10(self):
model_id = "fusing/ddpm-cifar10"
unet = UNetModel.from_pretrained(model_id)
unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
noise_scheduler = DDIMScheduler(tensor_format="pt")
ddim = DDIMPipeline(unet=unet, noise_scheduler=noise_scheduler)
......@@ -929,7 +939,7 @@ class PipelineTesterMixin(unittest.TestCase):
def test_pndm_cifar10(self):
model_id = "fusing/ddpm-cifar10"
unet = UNetModel.from_pretrained(model_id)
unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
noise_scheduler = PNDMScheduler(tensor_format="pt")
pndm = PNDMPipeline(unet=unet, noise_scheduler=noise_scheduler)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment