Commit 9c82c32b authored by anton-l

make style

parent 1a099e5e
@@ -8,6 +8,9 @@ import PIL.Image
 from accelerate import Accelerator
 from datasets import load_dataset
 from diffusers import DDPM, DDPMScheduler, UNetModel
+from diffusers.hub_utils import init_git_repo, push_to_hub
+from diffusers.modeling_utils import unwrap_model
+from diffusers.utils import logging
 from torchvision.transforms import (
     CenterCrop,
     Compose,
@@ -19,10 +22,7 @@ from torchvision.transforms import (
 )
 from tqdm.auto import tqdm
 from transformers import get_linear_schedule_with_warmup
-from diffusers.modeling_utils import unwrap_model
-from diffusers.hub_utils import init_git_repo, push_to_hub
-from diffusers.utils import logging

 logger = logging.get_logger(__name__)
...
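Note: the import shuffle above is what isort (run by `make style`) produces: standard-library imports first, then third-party packages, then local imports, with plain `import x` statements ahead of `from x import y` within each group and everything alphabetized. A minimal sketch of the grouping convention (an illustrative file, not the exact script):

    import os                      # 1) standard library

    import torch                   # 2) third-party: plain imports first,
    from diffusers import DDPM     #    then from-imports, alphabetized
    from tqdm.auto import tqdm

    from .utils import logging     # 3) local imports from the package itself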
-from typing import Optional
-from .utils import logging
-from huggingface_hub import HfFolder, Repository, whoami
-import yaml
 import os
-from pathlib import Path
 import shutil
+from pathlib import Path
+from typing import Optional
+
+import yaml
 from diffusers import DiffusionPipeline
+from huggingface_hub import HfFolder, Repository, whoami
+
+from .utils import logging

 logger = logging.get_logger(__name__)
@@ -68,17 +70,21 @@ def init_git_repo(args, at_init: bool = False):
         repo.git_pull()

     # By default, ignore the checkpoint folders
-    if (
-        not os.path.exists(os.path.join(args.output_dir, ".gitignore"))
-        and args.hub_strategy != "all_checkpoints"
-    ):
+    if not os.path.exists(os.path.join(args.output_dir, ".gitignore")) and args.hub_strategy != "all_checkpoints":
         with open(os.path.join(args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
             writer.writelines(["checkpoint-*/"])

     return repo


-def push_to_hub(args, pipeline: DiffusionPipeline, repo: Repository, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
+def push_to_hub(
+    args,
+    pipeline: DiffusionPipeline,
+    repo: Repository,
+    commit_message: Optional[str] = "End of training",
+    blocking: bool = True,
+    **kwargs,
+) -> str:
     """
     Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.

     Parameters:
@@ -108,18 +114,19 @@ def push_to_hub(args, pipeline: DiffusionPipeline, repo: Repository, commit_mess
         return

     # Cancel any async push in progress if blocking=True. The commits will all be pushed together.
-    if blocking and len(repo.command_queue) > 0 and repo.command_queue[-1] is not None and not repo.command_queue[-1].is_done:
+    if (
+        blocking
+        and len(repo.command_queue) > 0
+        and repo.command_queue[-1] is not None
+        and not repo.command_queue[-1].is_done
+    ):
         repo.command_queue[-1]._process.kill()

-    git_head_commit_url = repo.push_to_hub(
-        commit_message=commit_message, blocking=blocking, auto_lfs_prune=True
-    )
+    git_head_commit_url = repo.push_to_hub(commit_message=commit_message, blocking=blocking, auto_lfs_prune=True)

     # push separately the model card to be independent from the rest of the model
     create_model_card(args, model_name=model_name)
     try:
-        repo.push_to_hub(
-            commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True
-        )
+        repo.push_to_hub(commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True)
     except EnvironmentError as exc:
         logger.error(f"Error pushing update to the model card. Please read logs and retry.\n${exc}")
@@ -133,10 +140,7 @@ def create_model_card(args, model_name):
     # TODO: replace this placeholder model card generation
     model_card = ""

-    metadata = {
-        "license": "apache-2.0",
-        "tags": ["pytorch", "diffusers"]
-    }
+    metadata = {"license": "apache-2.0", "tags": ["pytorch", "diffusers"]}
     metadata = yaml.dump(metadata, sort_keys=False)
     if len(metadata) > 0:
         model_card = f"---\n{metadata}---\n"
...
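Note: the `.gitignore` condition collapsed onto one line above is self-contained enough to sketch in isolation. A hedged standalone version (`write_checkpoint_gitignore` is a hypothetical helper name, not the library's API):

    import os

    def write_checkpoint_gitignore(output_dir: str, hub_strategy: str) -> None:
        # Mirrors the condition in init_git_repo: do nothing if a .gitignore
        # already exists, or if every checkpoint should be pushed to the Hub.
        path = os.path.join(output_dir, ".gitignore")
        if not os.path.exists(path) and hub_strategy != "all_checkpoints":
            with open(path, "w", encoding="utf-8") as writer:
                writer.writelines(["checkpoint-*/"])  # ignore checkpoint folders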
@@ -585,4 +585,4 @@ def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
     if hasattr(model, "module"):
         return unwrap_model(model.module)
     else:
-        return model
\ No newline at end of file
+        return model
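Note: the recursion above strips any wrapper that exposes a `.module` attribute, the convention used by `torch.nn.parallel.DistributedDataParallel` and similar wrappers. A small self-contained check, with `Wrapper` as a stand-in class invented for illustration:

    import torch.nn as nn

    from diffusers.modeling_utils import unwrap_model

    class Wrapper(nn.Module):  # stand-in for DDP-style wrappers exposing .module
        def __init__(self, module):
            super().__init__()
            self.module = module

    inner = nn.Linear(4, 4)
    assert unwrap_model(Wrapper(Wrapper(inner))) is inner  # unwraps recursively
    assert unwrap_model(inner) is inner                    # plain modules pass through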
@@ -20,4 +20,4 @@ from .unet import UNetModel
 from .unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel
 from .unet_grad_tts import UNetGradTTSModel
 from .unet_ldm import UNetLDMModel
-from .unet_rl import TemporalUNet
\ No newline at end of file
+from .unet_rl import TemporalUNet
@@ -5,6 +5,7 @@ import math
 import torch
 import torch.nn as nn

+
 try:
     import einops
     from einops.layers.torch import Rearrange
@@ -103,7 +104,7 @@ class ResidualTemporalBlock(nn.Module):
         return out + self.residual_conv(x)


-class TemporalUNet(ModelMixin, ConfigMixin): #(nn.Module):
+class TemporalUNet(ModelMixin, ConfigMixin):  # (nn.Module):
     def __init__(
         self,
         horizon,
@@ -118,7 +119,6 @@ class TemporalUNet(ModelMixin, ConfigMixin): #(nn.Module):
         in_out = list(zip(dims[:-1], dims[1:]))
         # print(f'[ models/temporal ] Channel dimensions: {in_out}')

         time_dim = dim
-
         self.time_mlp = nn.Sequential(
             SinusoidalPosEmb(dim),
...
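Note: the `SinusoidalPosEmb(dim)` at the head of `time_mlp` above is the usual sinusoidal timestep embedding. A common implementation looks roughly like this (a sketch of the standard technique, not necessarily the module's exact code):

    import math

    import torch
    import torch.nn as nn

    class SinusoidalPosEmb(nn.Module):
        def __init__(self, dim):
            super().__init__()
            self.dim = dim

        def forward(self, x):
            # x: (batch,) tensor of timesteps -> (batch, dim) embedding
            half_dim = self.dim // 2
            freqs = torch.exp(torch.arange(half_dim, device=x.device) * -(math.log(10000) / (half_dim - 1)))
            angles = x[:, None].float() * freqs[None, :]
            return torch.cat((angles.sin(), angles.cos()), dim=-1)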
@@ -137,8 +137,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         return pred_prev_sample

     def forward_step(self, original_sample, noise, t):
-        sqrt_alpha_prod = self.alpha_prod_t[t] ** 0.5
-        sqrt_one_minus_alpha_prod = (1 - self.alpha_prod_t[t]) ** 0.5
+        sqrt_alpha_prod = self.alphas_cumprod[t] ** 0.5
+        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[t]) ** 0.5
         noisy_sample = sqrt_alpha_prod * original_sample + sqrt_one_minus_alpha_prod * noise
         return noisy_sample
...
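Note: the rename from `alpha_prod_t` to `alphas_cumprod` does not change the math. `forward_step` is the standard DDPM forward-noising rule x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, where alpha_bar_t is the cumulative product of the per-step alphas. A standalone sketch (the linear beta schedule here is assumed for illustration, not taken from the diff):

    import torch

    betas = torch.linspace(1e-4, 0.02, 1000)            # assumed linear beta schedule
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)  # cumulative product alpha_bar_t

    def forward_step(original_sample, noise, t):
        sqrt_alpha_prod = alphas_cumprod[t] ** 0.5
        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[t]) ** 0.5
        return sqrt_alpha_prod * original_sample + sqrt_one_minus_alpha_prod * noise

    x0 = torch.randn(2, 3, 32, 32)                       # clean samples
    x_t = forward_step(x0, torch.randn_like(x0), t=500)  # heavily noised at t=500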
@@ -33,9 +33,9 @@ from diffusers import (
     GLIDESuperResUNetModel,
     LatentDiffusion,
     PNDMScheduler,
-    UNetModel,
-    UNetLDMModel,
     UNetGradTTSModel,
+    UNetLDMModel,
+    UNetModel,
 )
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.pipeline_utils import DiffusionPipeline
@@ -342,6 +342,7 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
         # fmt: on
         self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))

+
 class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
     model_class = UNetLDMModel
@@ -378,7 +379,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
-
+
     def test_from_pretrained_hub(self):
         model, loading_info = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy", output_loading_info=True)
         self.assertIsNotNone(model)
@@ -446,7 +447,7 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
-
+
     def test_from_pretrained_hub(self):
         model, loading_info = UNetGradTTSModel.from_pretrained("fusing/unet-grad-tts-dummy", output_loading_info=True)
         self.assertIsNotNone(model)
@@ -464,7 +465,7 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
         torch.manual_seed(0)
         if torch.cuda.is_available():
             torch.cuda.manual_seed_all(0)
-
+
         num_features = model.config.n_feats
         seq_len = 16
         noise = torch.randn((1, num_features, seq_len))
...
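Note: the seeding pattern in the test hunk above seeds both the CPU RNG and, when available, every CUDA device's RNG, so the dummy inputs are reproducible regardless of where the test runs. A minimal sketch of the same idiom:

    import torch

    torch.manual_seed(0)               # CPU RNG
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)  # all CUDA device RNGs

    noise = torch.randn((1, 8, 16))    # deterministic given the fixed seed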