Commit 9c82c32b authored by anton-l's avatar anton-l
Browse files

make style

parent 1a099e5e
...@@ -8,6 +8,9 @@ import PIL.Image ...@@ -8,6 +8,9 @@ import PIL.Image
from accelerate import Accelerator from accelerate import Accelerator
from datasets import load_dataset from datasets import load_dataset
from diffusers import DDPM, DDPMScheduler, UNetModel from diffusers import DDPM, DDPMScheduler, UNetModel
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.modeling_utils import unwrap_model
from diffusers.utils import logging
from torchvision.transforms import ( from torchvision.transforms import (
CenterCrop, CenterCrop,
Compose, Compose,
...@@ -19,10 +22,7 @@ from torchvision.transforms import ( ...@@ -19,10 +22,7 @@ from torchvision.transforms import (
) )
from tqdm.auto import tqdm from tqdm.auto import tqdm
from transformers import get_linear_schedule_with_warmup from transformers import get_linear_schedule_with_warmup
from diffusers.modeling_utils import unwrap_model
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
from typing import Optional
from .utils import logging
from huggingface_hub import HfFolder, Repository, whoami
import yaml
import os import os
from pathlib import Path
import shutil import shutil
from pathlib import Path
from typing import Optional
import yaml
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
from huggingface_hub import HfFolder, Repository, whoami
from .utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -68,17 +70,21 @@ def init_git_repo(args, at_init: bool = False): ...@@ -68,17 +70,21 @@ def init_git_repo(args, at_init: bool = False):
repo.git_pull() repo.git_pull()
# By default, ignore the checkpoint folders # By default, ignore the checkpoint folders
if ( if not os.path.exists(os.path.join(args.output_dir, ".gitignore")) and args.hub_strategy != "all_checkpoints":
not os.path.exists(os.path.join(args.output_dir, ".gitignore"))
and args.hub_strategy != "all_checkpoints"
):
with open(os.path.join(args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer: with open(os.path.join(args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
writer.writelines(["checkpoint-*/"]) writer.writelines(["checkpoint-*/"])
return repo return repo
def push_to_hub(args, pipeline: DiffusionPipeline, repo: Repository, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str: def push_to_hub(
args,
pipeline: DiffusionPipeline,
repo: Repository,
commit_message: Optional[str] = "End of training",
blocking: bool = True,
**kwargs,
) -> str:
""" """
Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*. Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.
Parameters: Parameters:
...@@ -108,18 +114,19 @@ def push_to_hub(args, pipeline: DiffusionPipeline, repo: Repository, commit_mess ...@@ -108,18 +114,19 @@ def push_to_hub(args, pipeline: DiffusionPipeline, repo: Repository, commit_mess
return return
# Cancel any async push in progress if blocking=True. The commits will all be pushed together. # Cancel any async push in progress if blocking=True. The commits will all be pushed together.
if blocking and len(repo.command_queue) > 0 and repo.command_queue[-1] is not None and not repo.command_queue[-1].is_done: if (
blocking
and len(repo.command_queue) > 0
and repo.command_queue[-1] is not None
and not repo.command_queue[-1].is_done
):
repo.command_queue[-1]._process.kill() repo.command_queue[-1]._process.kill()
git_head_commit_url = repo.push_to_hub( git_head_commit_url = repo.push_to_hub(commit_message=commit_message, blocking=blocking, auto_lfs_prune=True)
commit_message=commit_message, blocking=blocking, auto_lfs_prune=True
)
# push separately the model card to be independent from the rest of the model # push separately the model card to be independent from the rest of the model
create_model_card(args, model_name=model_name) create_model_card(args, model_name=model_name)
try: try:
repo.push_to_hub( repo.push_to_hub(commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True)
commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True
)
except EnvironmentError as exc: except EnvironmentError as exc:
logger.error(f"Error pushing update to the model card. Please read logs and retry.\n${exc}") logger.error(f"Error pushing update to the model card. Please read logs and retry.\n${exc}")
...@@ -133,10 +140,7 @@ def create_model_card(args, model_name): ...@@ -133,10 +140,7 @@ def create_model_card(args, model_name):
# TODO: replace this placeholder model card generation # TODO: replace this placeholder model card generation
model_card = "" model_card = ""
metadata = { metadata = {"license": "apache-2.0", "tags": ["pytorch", "diffusers"]}
"license": "apache-2.0",
"tags": ["pytorch", "diffusers"]
}
metadata = yaml.dump(metadata, sort_keys=False) metadata = yaml.dump(metadata, sort_keys=False)
if len(metadata) > 0: if len(metadata) > 0:
model_card = f"---\n{metadata}---\n" model_card = f"---\n{metadata}---\n"
......
...@@ -5,6 +5,7 @@ import math ...@@ -5,6 +5,7 @@ import math
import torch import torch
import torch.nn as nn import torch.nn as nn
try: try:
import einops import einops
from einops.layers.torch import Rearrange from einops.layers.torch import Rearrange
...@@ -103,7 +104,7 @@ class ResidualTemporalBlock(nn.Module): ...@@ -103,7 +104,7 @@ class ResidualTemporalBlock(nn.Module):
return out + self.residual_conv(x) return out + self.residual_conv(x)
class TemporalUNet(ModelMixin, ConfigMixin): #(nn.Module): class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module):
def __init__( def __init__(
self, self,
horizon, horizon,
...@@ -118,7 +119,6 @@ class TemporalUNet(ModelMixin, ConfigMixin): #(nn.Module): ...@@ -118,7 +119,6 @@ class TemporalUNet(ModelMixin, ConfigMixin): #(nn.Module):
in_out = list(zip(dims[:-1], dims[1:])) in_out = list(zip(dims[:-1], dims[1:]))
# print(f'[ models/temporal ] Channel dimensions: {in_out}') # print(f'[ models/temporal ] Channel dimensions: {in_out}')
time_dim = dim time_dim = dim
self.time_mlp = nn.Sequential( self.time_mlp = nn.Sequential(
SinusoidalPosEmb(dim), SinusoidalPosEmb(dim),
......
...@@ -137,8 +137,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -137,8 +137,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
return pred_prev_sample return pred_prev_sample
def forward_step(self, original_sample, noise, t): def forward_step(self, original_sample, noise, t):
sqrt_alpha_prod = self.alpha_prod_t[t] ** 0.5 sqrt_alpha_prod = self.alphas_cumprod[t] ** 0.5
sqrt_one_minus_alpha_prod = (1 - self.alpha_prod_t[t]) ** 0.5 sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[t]) ** 0.5
noisy_sample = sqrt_alpha_prod * original_sample + sqrt_one_minus_alpha_prod * noise noisy_sample = sqrt_alpha_prod * original_sample + sqrt_one_minus_alpha_prod * noise
return noisy_sample return noisy_sample
......
...@@ -33,9 +33,9 @@ from diffusers import ( ...@@ -33,9 +33,9 @@ from diffusers import (
GLIDESuperResUNetModel, GLIDESuperResUNetModel,
LatentDiffusion, LatentDiffusion,
PNDMScheduler, PNDMScheduler,
UNetModel,
UNetLDMModel,
UNetGradTTSModel, UNetGradTTSModel,
UNetLDMModel,
UNetModel,
) )
from diffusers.configuration_utils import ConfigMixin from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import DiffusionPipeline from diffusers.pipeline_utils import DiffusionPipeline
...@@ -342,6 +342,7 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase): ...@@ -342,6 +342,7 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
# fmt: on # fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase): class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
model_class = UNetLDMModel model_class = UNetLDMModel
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment