Commit fdf57325 authored by pythongosssss's avatar pythongosssss
Browse files

Merge remote-tracking branch 'origin/master' into tiled-progress

parents 27df7410 93c64afa
import pygit2
from datetime import datetime
import sys
def pull(repo, remote_name='origin', branch='master'):
for remote in repo.remotes:
if remote.name == remote_name:
remote.fetch()
remote_master_id = repo.lookup_reference('refs/remotes/origin/%s' % (branch)).target
merge_result, _ = repo.merge_analysis(remote_master_id)
# Up to date, do nothing
if merge_result & pygit2.GIT_MERGE_ANALYSIS_UP_TO_DATE:
return
# We can just fastforward
elif merge_result & pygit2.GIT_MERGE_ANALYSIS_FASTFORWARD:
repo.checkout_tree(repo.get(remote_master_id))
try:
master_ref = repo.lookup_reference('refs/heads/%s' % (branch))
master_ref.set_target(remote_master_id)
except KeyError:
repo.create_branch(branch, repo.get(remote_master_id))
repo.head.set_target(remote_master_id)
elif merge_result & pygit2.GIT_MERGE_ANALYSIS_NORMAL:
repo.merge(remote_master_id)
if repo.index.conflicts is not None:
for conflict in repo.index.conflicts:
print('Conflicts found in:', conflict[0].path)
raise AssertionError('Conflicts, ahhhhh!!')
user = repo.default_signature
tree = repo.index.write_tree()
commit = repo.create_commit('HEAD',
user,
user,
'Merge!',
tree,
[repo.head.target, remote_master_id])
# We need to do this or git CLI will think we are still merging.
repo.state_cleanup()
else:
raise AssertionError('Unknown merge analysis result')
repo = pygit2.Repository(str(sys.argv[1]))
ident = pygit2.Signature('comfyui', 'comfy@ui')
try:
print("stashing current changes")
repo.stash(ident)
except KeyError:
print("nothing to stash")
backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
print("creating backup branch: {}".format(backup_branch_name))
repo.branches.local.create(backup_branch_name, repo.head.peel())
print("checking out master branch")
branch = repo.lookup_branch('master')
ref = repo.lookup_reference(branch.name)
repo.checkout(ref)
print("pulling latest changes")
pull(repo)
print("Done!")
..\python_embeded\python.exe .\update.py ..\ComfyUI\
pause
..\python_embeded\python.exe .\update.py ..\ComfyUI\ ..\python_embeded\python.exe .\update.py ..\ComfyUI\
..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118 -r ../ComfyUI/requirements.txt pygit2 ..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2
pause pause
HOW TO RUN:
if you have a NVIDIA gpu:
run_nvidia_gpu.bat
To run it in slow CPU mode:
run_cpu.bat
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
You can download the stable diffusion 1.5 one from: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt
RECOMMENDED WAY TO UPDATE:
To update the ComfyUI code: update\update_comfyui.bat
To update ComfyUI with the python dependencies:
update\update_comfyui_and_python_dependencies.bat
.\python_embeded\python.exe -s ComfyUI\main.py --cpu --windows-standalone-build
pause
...@@ -17,7 +17,7 @@ jobs: ...@@ -17,7 +17,7 @@ jobs:
- shell: bash - shell: bash
run: | run: |
python -m pip wheel --no-cache-dir torch torchvision torchaudio xformers==0.0.19.dev516 --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir python -m pip wheel --no-cache-dir torch torchvision torchaudio xformers --extra-index-url https://download.pytorch.org/whl/cu118 -r requirements.txt pygit2 -w ./temp_wheel_dir
python -m pip install --no-cache-dir ./temp_wheel_dir/* python -m pip install --no-cache-dir ./temp_wheel_dir/*
echo installed basic echo installed basic
ls -lah temp_wheel_dir ls -lah temp_wheel_dir
......
...@@ -19,21 +19,21 @@ jobs: ...@@ -19,21 +19,21 @@ jobs:
fetch-depth: 0 fetch-depth: 0
- uses: actions/setup-python@v4 - uses: actions/setup-python@v4
with: with:
python-version: '3.10.9' python-version: '3.11.3'
- shell: bash - shell: bash
run: | run: |
cd .. cd ..
cp -r ComfyUI ComfyUI_copy cp -r ComfyUI ComfyUI_copy
curl https://www.python.org/ftp/python/3.10.9/python-3.10.9-embed-amd64.zip -o python_embeded.zip curl https://www.python.org/ftp/python/3.11.3/python-3.11.3-embed-amd64.zip -o python_embeded.zip
unzip python_embeded.zip -d python_embeded unzip python_embeded.zip -d python_embeded
cd python_embeded cd python_embeded
echo 'import site' >> ./python310._pth echo 'import site' >> ./python311._pth
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
./python.exe get-pip.py ./python.exe get-pip.py
python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu118 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu121 -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
ls ../temp_wheel_dir ls ../temp_wheel_dir
./python.exe -s -m pip install --pre ../temp_wheel_dir/* ./python.exe -s -m pip install --pre ../temp_wheel_dir/*
sed -i '1i../ComfyUI' ./python310._pth sed -i '1i../ComfyUI' ./python311._pth
cd .. cd ..
...@@ -46,6 +46,8 @@ jobs: ...@@ -46,6 +46,8 @@ jobs:
mkdir update mkdir update
cp -r ComfyUI/.ci/update_windows/* ./update/ cp -r ComfyUI/.ci/update_windows/* ./update/
cp -r ComfyUI/.ci/windows_base_files/* ./ cp -r ComfyUI/.ci/windows_base_files/* ./
cp -r ComfyUI/.ci/nightly/update_windows/* ./update/
cp -r ComfyUI/.ci/nightly/windows_base_files/* ./
cd .. cd ..
......
...@@ -7,6 +7,8 @@ A powerful and modular stable diffusion GUI and backend. ...@@ -7,6 +7,8 @@ A powerful and modular stable diffusion GUI and backend.
This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and see what ComfyUI can do you can check out: This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and see what ComfyUI can do you can check out:
### [ComfyUI Examples](https://comfyanonymous.github.io/ComfyUI_examples/) ### [ComfyUI Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
### [Installing ComfyUI](#installing)
## Features ## Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything. - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
- Fully supports SD1.x and SD2.x - Fully supports SD1.x and SD2.x
...@@ -17,6 +19,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin ...@@ -17,6 +19,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
- Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models. - Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models.
- Embeddings/Textual inversion - Embeddings/Textual inversion
- [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/) - [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
- [Hypernetworks](https://comfyanonymous.github.io/ComfyUI_examples/hypernetworks/)
- Loading full workflows (with seeds) from generated PNG files. - Loading full workflows (with seeds) from generated PNG files.
- Saving/Loading workflows as Json files. - Saving/Loading workflows as Json files.
- Nodes interface can be used to create complex workflows like one for [Hires fix](https://comfyanonymous.github.io/ComfyUI_examples/2_pass_txt2img/) or much more advanced ones. - Nodes interface can be used to create complex workflows like one for [Hires fix](https://comfyanonymous.github.io/ComfyUI_examples/2_pass_txt2img/) or much more advanced ones.
......
...@@ -10,6 +10,7 @@ parser.add_argument("--output-directory", type=str, default=None, help="Set the ...@@ -10,6 +10,7 @@ parser.add_argument("--output-directory", type=str, default=None, help="Set the
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.") parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.") parser.add_argument("--dont-upcast-attention", action="store_true", help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).") parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
attn_group = parser.add_mutually_exclusive_group() attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.") attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.")
......
...@@ -712,7 +712,7 @@ class UniPC: ...@@ -712,7 +712,7 @@ class UniPC:
def sample(self, x, timesteps, t_start=None, t_end=None, order=3, skip_type='time_uniform', def sample(self, x, timesteps, t_start=None, t_end=None, order=3, skip_type='time_uniform',
method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver', method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
atol=0.0078, rtol=0.05, corrector=False, atol=0.0078, rtol=0.05, corrector=False, callback=None, disable_pbar=False
): ):
t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
t_T = self.noise_schedule.T if t_start is None else t_start t_T = self.noise_schedule.T if t_start is None else t_start
...@@ -723,7 +723,7 @@ class UniPC: ...@@ -723,7 +723,7 @@ class UniPC:
# timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) # timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
assert timesteps.shape[0] - 1 == steps assert timesteps.shape[0] - 1 == steps
# with torch.no_grad(): # with torch.no_grad():
for step_index in trange(steps): for step_index in trange(steps, disable=disable_pbar):
if self.noise_mask is not None: if self.noise_mask is not None:
x = x * self.noise_mask + (1. - self.noise_mask) * (self.masked_image * self.noise_schedule.marginal_alpha(timesteps[step_index]) + self.noise * self.noise_schedule.marginal_std(timesteps[step_index])) x = x * self.noise_mask + (1. - self.noise_mask) * (self.masked_image * self.noise_schedule.marginal_alpha(timesteps[step_index]) + self.noise * self.noise_schedule.marginal_std(timesteps[step_index]))
if step_index == 0: if step_index == 0:
...@@ -766,6 +766,8 @@ class UniPC: ...@@ -766,6 +766,8 @@ class UniPC:
if model_x is None: if model_x is None:
model_x = self.model_fn(x, vec_t) model_x = self.model_fn(x, vec_t)
model_prev_list[-1] = model_x model_prev_list[-1] = model_x
if callback is not None:
callback(step_index, model_prev_list[-1], x)
else: else:
raise NotImplementedError() raise NotImplementedError()
if denoise_to_zero: if denoise_to_zero:
...@@ -833,7 +835,7 @@ def expand_dims(v, dims): ...@@ -833,7 +835,7 @@ def expand_dims(v, dims):
def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=None, noise_mask=None, variant='bh1'): def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
to_zero = False to_zero = False
if sigmas[-1] == 0: if sigmas[-1] == 0:
timesteps = torch.nn.functional.interpolate(sigmas[None,None,:-1], size=(len(sigmas),), mode='linear')[0][0] timesteps = torch.nn.functional.interpolate(sigmas[None,None,:-1], size=(len(sigmas),), mode='linear')[0][0]
...@@ -877,7 +879,7 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex ...@@ -877,7 +879,7 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
order = min(3, len(timesteps) - 1) order = min(3, len(timesteps) - 1)
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant) uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant)
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True) x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
if not to_zero: if not to_zero:
x /= ns.marginal_alpha(timesteps[-1]) x /= ns.marginal_alpha(timesteps[-1])
return x return x
...@@ -81,6 +81,7 @@ class DDIMSampler(object): ...@@ -81,6 +81,7 @@ class DDIMSampler(object):
extra_args=None, extra_args=None,
to_zero=True, to_zero=True,
end_step=None, end_step=None,
disable_pbar=False,
**kwargs **kwargs
): ):
self.make_schedule_timesteps(ddim_timesteps=ddim_timesteps, ddim_eta=eta, verbose=verbose) self.make_schedule_timesteps(ddim_timesteps=ddim_timesteps, ddim_eta=eta, verbose=verbose)
...@@ -103,7 +104,8 @@ class DDIMSampler(object): ...@@ -103,7 +104,8 @@ class DDIMSampler(object):
denoise_function=denoise_function, denoise_function=denoise_function,
extra_args=extra_args, extra_args=extra_args,
to_zero=to_zero, to_zero=to_zero,
end_step=end_step end_step=end_step,
disable_pbar=disable_pbar
) )
return samples, intermediates return samples, intermediates
...@@ -185,7 +187,7 @@ class DDIMSampler(object): ...@@ -185,7 +187,7 @@ class DDIMSampler(object):
mask=None, x0=None, img_callback=None, log_every_t=100, mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None, unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None): ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None, disable_pbar=False):
device = self.model.betas.device device = self.model.betas.device
b = shape[0] b = shape[0]
if x_T is None: if x_T is None:
...@@ -204,7 +206,7 @@ class DDIMSampler(object): ...@@ -204,7 +206,7 @@ class DDIMSampler(object):
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
# print(f"Running DDIM Sampling with {total_steps} timesteps") # print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range[:end_step], desc='DDIM Sampler', total=end_step) iterator = tqdm(time_range[:end_step], desc='DDIM Sampler', total=end_step, disable=disable_pbar)
for i, step in enumerate(iterator): for i, step in enumerate(iterator):
index = total_steps - i - 1 index = total_steps - i - 1
......
...@@ -76,12 +76,14 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock): ...@@ -76,12 +76,14 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
support it as an extra input. support it as an extra input.
""" """
def forward(self, x, emb, context=None, transformer_options={}): def forward(self, x, emb, context=None, transformer_options={}, output_shape=None):
for layer in self: for layer in self:
if isinstance(layer, TimestepBlock): if isinstance(layer, TimestepBlock):
x = layer(x, emb) x = layer(x, emb)
elif isinstance(layer, SpatialTransformer): elif isinstance(layer, SpatialTransformer):
x = layer(x, context, transformer_options) x = layer(x, context, transformer_options)
elif isinstance(layer, Upsample):
x = layer(x, output_shape=output_shape)
else: else:
x = layer(x) x = layer(x)
return x return x
...@@ -105,14 +107,20 @@ class Upsample(nn.Module): ...@@ -105,14 +107,20 @@ class Upsample(nn.Module):
if use_conv: if use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
def forward(self, x): def forward(self, x, output_shape=None):
assert x.shape[1] == self.channels assert x.shape[1] == self.channels
if self.dims == 3: if self.dims == 3:
x = F.interpolate( shape = [x.shape[2], x.shape[3] * 2, x.shape[4] * 2]
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" if output_shape is not None:
) shape[1] = output_shape[3]
shape[2] = output_shape[4]
else: else:
x = F.interpolate(x, scale_factor=2, mode="nearest") shape = [x.shape[2] * 2, x.shape[3] * 2]
if output_shape is not None:
shape[0] = output_shape[2]
shape[1] = output_shape[3]
x = F.interpolate(x, size=shape, mode="nearest")
if self.use_conv: if self.use_conv:
x = self.conv(x) x = self.conv(x)
return x return x
...@@ -813,9 +821,14 @@ class UNetModel(nn.Module): ...@@ -813,9 +821,14 @@ class UNetModel(nn.Module):
ctrl = control['output'].pop() ctrl = control['output'].pop()
if ctrl is not None: if ctrl is not None:
hsp += ctrl hsp += ctrl
h = th.cat([h, hsp], dim=1) h = th.cat([h, hsp], dim=1)
del hsp del hsp
h = module(h, emb, context, transformer_options) if len(hs) > 0:
output_shape = hs[-1].shape
else:
output_shape = None
h = module(h, emb, context, transformer_options, output_shape)
h = h.type(x.dtype) h = h.type(x.dtype)
if self.predict_codebook_ids: if self.predict_codebook_ids:
return self.id_predictor(h) return self.id_predictor(h)
......
...@@ -20,8 +20,23 @@ total_vram_available_mb = -1 ...@@ -20,8 +20,23 @@ total_vram_available_mb = -1
accelerate_enabled = False accelerate_enabled = False
xpu_available = False xpu_available = False
directml_enabled = False
if args.directml is not None:
import torch_directml
directml_enabled = True
device_index = args.directml
if device_index < 0:
directml_device = torch_directml.device()
else:
directml_device = torch_directml.device(device_index)
print("Using directml with device:", torch_directml.device_name(device_index))
# torch_directml.disable_tiled_resources(True)
try: try:
import torch import torch
if directml_enabled:
total_vram = 4097 #TODO
else:
try: try:
import intel_extension_for_pytorch as ipex import intel_extension_for_pytorch as ipex
if torch.xpu.is_available(): if torch.xpu.is_available():
...@@ -217,6 +232,10 @@ def unload_if_low_vram(model): ...@@ -217,6 +232,10 @@ def unload_if_low_vram(model):
def get_torch_device(): def get_torch_device():
global xpu_available global xpu_available
global directml_enabled
if directml_enabled:
global directml_device
return directml_device
if vram_state == VRAMState.MPS: if vram_state == VRAMState.MPS:
return torch.device("mps") return torch.device("mps")
if vram_state == VRAMState.CPU: if vram_state == VRAMState.CPU:
...@@ -234,8 +253,14 @@ def get_autocast_device(dev): ...@@ -234,8 +253,14 @@ def get_autocast_device(dev):
def xformers_enabled(): def xformers_enabled():
global xpu_available
global directml_enabled
if vram_state == VRAMState.CPU: if vram_state == VRAMState.CPU:
return False return False
if xpu_available:
return False
if directml_enabled:
return False
return XFORMERS_IS_AVAILABLE return XFORMERS_IS_AVAILABLE
...@@ -251,6 +276,7 @@ def pytorch_attention_enabled(): ...@@ -251,6 +276,7 @@ def pytorch_attention_enabled():
def get_free_memory(dev=None, torch_free_too=False): def get_free_memory(dev=None, torch_free_too=False):
global xpu_available global xpu_available
global directml_enabled
if dev is None: if dev is None:
dev = get_torch_device() dev = get_torch_device()
...@@ -258,7 +284,10 @@ def get_free_memory(dev=None, torch_free_too=False): ...@@ -258,7 +284,10 @@ def get_free_memory(dev=None, torch_free_too=False):
mem_free_total = psutil.virtual_memory().available mem_free_total = psutil.virtual_memory().available
mem_free_torch = mem_free_total mem_free_torch = mem_free_total
else: else:
if xpu_available: if directml_enabled:
mem_free_total = 1024 * 1024 * 1024 #TODO
mem_free_torch = mem_free_total
elif xpu_available:
mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev) mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
mem_free_torch = mem_free_total mem_free_torch = mem_free_total
else: else:
...@@ -293,9 +322,14 @@ def mps_mode(): ...@@ -293,9 +322,14 @@ def mps_mode():
def should_use_fp16(): def should_use_fp16():
global xpu_available global xpu_available
global directml_enabled
if FORCE_FP32: if FORCE_FP32:
return False return False
if directml_enabled:
return False
if cpu_mode() or mps_mode() or xpu_available: if cpu_mode() or mps_mode() or xpu_available:
return False #TODO ? return False #TODO ?
......
import torch
import comfy.model_management
import comfy.samplers
import math
def prepare_noise(latent_image, seed, skip=0):
"""
creates random noise given a latent image and a seed.
optional arg skip can be used to skip and discard x number of noise generations for a given seed
"""
generator = torch.manual_seed(seed)
for _ in range(skip):
noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
return noise
def prepare_mask(noise_mask, shape, device):
"""ensures noise mask is of proper dimensions"""
noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
noise_mask = noise_mask.round()
noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
if noise_mask.shape[0] < shape[0]:
noise_mask = noise_mask.repeat(math.ceil(shape[0] / noise_mask.shape[0]), 1, 1, 1)[:shape[0]]
noise_mask = noise_mask.to(device)
return noise_mask
def broadcast_cond(cond, batch, device):
"""broadcasts conditioning to the batch size"""
copy = []
for p in cond:
t = p[0]
if t.shape[0] < batch:
t = torch.cat([t] * batch)
t = t.to(device)
copy += [[t] + p[1:]]
return copy
def get_models_from_cond(cond, model_type):
models = []
for c in cond:
if model_type in c[1]:
models += [c[1][model_type]]
return models
def load_additional_models(positive, negative):
"""loads additional models in positive and negative conditioning"""
control_nets = get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")
gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen")
gligen = [x[1] for x in gligen]
models = control_nets + gligen
comfy.model_management.load_controlnet_gpu(models)
return models
def cleanup_additional_models(models):
"""cleanup additional models that were loaded"""
for m in models:
m.cleanup()
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False):
device = comfy.model_management.get_torch_device()
if noise_mask is not None:
noise_mask = prepare_mask(noise_mask, noise.shape, device)
real_model = None
comfy.model_management.load_model_gpu(model)
real_model = model.model
noise = noise.to(device)
latent_image = latent_image.to(device)
positive_copy = broadcast_cond(positive, noise.shape[0], device)
negative_copy = broadcast_cond(negative, noise.shape[0], device)
models = load_additional_models(positive, negative)
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar)
samples = samples.cpu()
cleanup_additional_models(models)
return samples
...@@ -23,8 +23,22 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con ...@@ -23,8 +23,22 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
adm_cond = cond[1]['adm_encoded'] adm_cond = cond[1]['adm_encoded']
input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
mult = torch.ones_like(input_x) * strength if 'mask' in cond[1]:
# Scale the mask to the size of the input
# The mask should have been resized as we began the sampling process
mask_strength = 1.0
if "mask_strength" in cond[1]:
mask_strength = cond[1]["mask_strength"]
mask = cond[1]['mask']
assert(mask.shape[1] == x_in.shape[2])
assert(mask.shape[2] == x_in.shape[3])
mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength
mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
else:
mask = torch.ones_like(input_x)
mult = mask * strength
if 'mask' not in cond[1]:
rr = 8 rr = 8
if area[2] != 0: if area[2] != 0:
for t in range(rr): for t in range(rr):
...@@ -38,6 +52,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con ...@@ -38,6 +52,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
if (area[1] + area[3]) < x_in.shape[3]: if (area[1] + area[3]) < x_in.shape[3]:
for t in range(rr): for t in range(rr):
mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1)) mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1))
conditionning = {} conditionning = {}
conditionning['c_crossattn'] = cond[0] conditionning['c_crossattn'] = cond[0]
if cond_concat_in is not None and len(cond_concat_in) > 0: if cond_concat_in is not None and len(cond_concat_in) > 0:
...@@ -301,6 +316,71 @@ def blank_inpaint_image_like(latent_image): ...@@ -301,6 +316,71 @@ def blank_inpaint_image_like(latent_image):
blank_image[:,3] *= 0.1380 blank_image[:,3] *= 0.1380
return blank_image return blank_image
def get_mask_aabb(masks):
if masks.numel() == 0:
return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
b = masks.shape[0]
bounding_boxes = torch.zeros((b, 4), device=masks.device, dtype=torch.int)
is_empty = torch.zeros((b), device=masks.device, dtype=torch.bool)
for i in range(b):
mask = masks[i]
if mask.numel() == 0:
continue
if torch.max(mask != 0) == False:
is_empty[i] = True
continue
y, x = torch.where(mask)
bounding_boxes[i, 0] = torch.min(x)
bounding_boxes[i, 1] = torch.min(y)
bounding_boxes[i, 2] = torch.max(x)
bounding_boxes[i, 3] = torch.max(y)
return bounding_boxes, is_empty
def resolve_cond_masks(conditions, h, w, device):
# We need to decide on an area outside the sampling loop in order to properly generate opposite areas of equal sizes.
# While we're doing this, we can also resolve the mask device and scaling for performance reasons
for i in range(len(conditions)):
c = conditions[i]
if 'mask' in c[1]:
mask = c[1]['mask']
mask = mask.to(device=device)
modified = c[1].copy()
if len(mask.shape) == 2:
mask = mask.unsqueeze(0)
if mask.shape[2] != h or mask.shape[3] != w:
mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(h, w), mode='bilinear', align_corners=False).squeeze(1)
if modified.get("set_area_to_bounds", False):
bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0)
boxes, is_empty = get_mask_aabb(bounds)
if is_empty[0]:
# Use the minimum possible size for efficiency reasons. (Since the mask is all-0, this becomes a noop anyway)
modified['area'] = (8, 8, 0, 0)
else:
box = boxes[0]
H, W, Y, X = (box[3] - box[1] + 1, box[2] - box[0] + 1, box[1], box[0])
# Make sure the height and width are divisible by 8
if X % 8 != 0:
newx = X // 8 * 8
W = W + (X - newx)
X = newx
if Y % 8 != 0:
newy = Y // 8 * 8
H = H + (Y - newy)
Y = newy
if H % 8 != 0:
H = H + (8 - (H % 8))
if W % 8 != 0:
W = W + (8 - (W % 8))
area = (int(H), int(W), int(Y), int(X))
modified['area'] = area
modified['mask'] = mask
conditions[i] = [c[0], modified]
def create_cond_with_same_area_if_none(conds, c): def create_cond_with_same_area_if_none(conds, c):
if 'area' not in c[1]: if 'area' not in c[1]:
return return
...@@ -429,7 +509,7 @@ class KSampler: ...@@ -429,7 +509,7 @@ class KSampler:
self.denoise = denoise self.denoise = denoise
self.model_options = model_options self.model_options = model_options
def _calculate_sigmas(self, steps): def calculate_sigmas(self, steps):
sigmas = None sigmas = None
discard_penultimate_sigma = False discard_penultimate_sigma = False
...@@ -438,13 +518,13 @@ class KSampler: ...@@ -438,13 +518,13 @@ class KSampler:
discard_penultimate_sigma = True discard_penultimate_sigma = True
if self.scheduler == "karras": if self.scheduler == "karras":
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max, device=self.device) sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
elif self.scheduler == "normal": elif self.scheduler == "normal":
sigmas = self.model_wrap.get_sigmas(steps).to(self.device) sigmas = self.model_wrap.get_sigmas(steps)
elif self.scheduler == "simple": elif self.scheduler == "simple":
sigmas = simple_scheduler(self.model_wrap, steps).to(self.device) sigmas = simple_scheduler(self.model_wrap, steps)
elif self.scheduler == "ddim_uniform": elif self.scheduler == "ddim_uniform":
sigmas = ddim_scheduler(self.model_wrap, steps).to(self.device) sigmas = ddim_scheduler(self.model_wrap, steps)
else: else:
print("error invalid scheduler", self.scheduler) print("error invalid scheduler", self.scheduler)
...@@ -455,14 +535,14 @@ class KSampler: ...@@ -455,14 +535,14 @@ class KSampler:
def set_steps(self, steps, denoise=None): def set_steps(self, steps, denoise=None):
self.steps = steps self.steps = steps
if denoise is None or denoise > 0.9999: if denoise is None or denoise > 0.9999:
self.sigmas = self._calculate_sigmas(steps) self.sigmas = self.calculate_sigmas(steps).to(self.device)
else: else:
new_steps = int(steps/denoise) new_steps = int(steps/denoise)
sigmas = self._calculate_sigmas(new_steps) sigmas = self.calculate_sigmas(new_steps).to(self.device)
self.sigmas = sigmas[-(steps + 1):] self.sigmas = sigmas[-(steps + 1):]
def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False):
def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None): if sigmas is None:
sigmas = self.sigmas sigmas = self.sigmas
sigma_min = self.sigma_min sigma_min = self.sigma_min
...@@ -483,6 +563,10 @@ class KSampler: ...@@ -483,6 +563,10 @@ class KSampler:
positive = positive[:] positive = positive[:]
negative = negative[:] negative = negative[:]
resolve_cond_masks(positive, noise.shape[2], noise.shape[3], self.device)
resolve_cond_masks(negative, noise.shape[2], noise.shape[3], self.device)
#make sure each cond area has an opposite one with the same area #make sure each cond area has an opposite one with the same area
for c in positive: for c in positive:
create_cond_with_same_area_if_none(negative, c) create_cond_with_same_area_if_none(negative, c)
...@@ -526,9 +610,9 @@ class KSampler: ...@@ -526,9 +610,9 @@ class KSampler:
with precision_scope(model_management.get_autocast_device(self.device)): with precision_scope(model_management.get_autocast_device(self.device)):
if self.sampler == "uni_pc": if self.sampler == "uni_pc":
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask) samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
elif self.sampler == "uni_pc_bh2": elif self.sampler == "uni_pc_bh2":
samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, variant='bh2') samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
elif self.sampler == "ddim": elif self.sampler == "ddim":
timesteps = [] timesteps = []
for s in range(sigmas.shape[0]): for s in range(sigmas.shape[0]):
...@@ -536,6 +620,11 @@ class KSampler: ...@@ -536,6 +620,11 @@ class KSampler:
noise_mask = None noise_mask = None
if denoise_mask is not None: if denoise_mask is not None:
noise_mask = 1.0 - denoise_mask noise_mask = 1.0 - denoise_mask
ddim_callback = None
if callback is not None:
ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None)
sampler = DDIMSampler(self.model, device=self.device) sampler = DDIMSampler(self.model, device=self.device)
sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False) sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise) z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise)
...@@ -549,11 +638,13 @@ class KSampler: ...@@ -549,11 +638,13 @@ class KSampler:
eta=0.0, eta=0.0,
x_T=z_enc, x_T=z_enc,
x0=latent_image, x0=latent_image,
img_callback=ddim_callback,
denoise_function=sampling_function, denoise_function=sampling_function,
extra_args=extra_args, extra_args=extra_args,
mask=noise_mask, mask=noise_mask,
to_zero=sigmas[-1]==0, to_zero=sigmas[-1]==0,
end_step=sigmas.shape[0] - 1) end_step=sigmas.shape[0] - 1,
disable_pbar=disable_pbar)
else: else:
extra_args["denoise_mask"] = denoise_mask extra_args["denoise_mask"] = denoise_mask
...@@ -562,13 +653,17 @@ class KSampler: ...@@ -562,13 +653,17 @@ class KSampler:
noise = noise * sigmas[0] noise = noise * sigmas[0]
k_callback = None
if callback is not None:
k_callback = lambda x: callback(x["i"], x["denoised"], x["x"])
if latent_image is not None: if latent_image is not None:
noise += latent_image noise += latent_image
if self.sampler == "dpm_fast": if self.sampler == "dpm_fast":
samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args=extra_args) samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
elif self.sampler == "dpm_adaptive": elif self.sampler == "dpm_adaptive":
samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args) samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
else: else:
samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args) samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
return samples.to(torch.float32) return samples.to(torch.float32)
...@@ -112,6 +112,8 @@ def load_lora(path, to_load): ...@@ -112,6 +112,8 @@ def load_lora(path, to_load):
loaded_keys.add(A_name) loaded_keys.add(A_name)
loaded_keys.add(B_name) loaded_keys.add(B_name)
######## loha
hada_w1_a_name = "{}.hada_w1_a".format(x) hada_w1_a_name = "{}.hada_w1_a".format(x)
hada_w1_b_name = "{}.hada_w1_b".format(x) hada_w1_b_name = "{}.hada_w1_b".format(x)
hada_w2_a_name = "{}.hada_w2_a".format(x) hada_w2_a_name = "{}.hada_w2_a".format(x)
...@@ -133,6 +135,54 @@ def load_lora(path, to_load): ...@@ -133,6 +135,54 @@ def load_lora(path, to_load):
loaded_keys.add(hada_w2_a_name) loaded_keys.add(hada_w2_a_name)
loaded_keys.add(hada_w2_b_name) loaded_keys.add(hada_w2_b_name)
######## lokr
lokr_w1_name = "{}.lokr_w1".format(x)
lokr_w2_name = "{}.lokr_w2".format(x)
lokr_w1_a_name = "{}.lokr_w1_a".format(x)
lokr_w1_b_name = "{}.lokr_w1_b".format(x)
lokr_t2_name = "{}.lokr_t2".format(x)
lokr_w2_a_name = "{}.lokr_w2_a".format(x)
lokr_w2_b_name = "{}.lokr_w2_b".format(x)
lokr_w1 = None
if lokr_w1_name in lora.keys():
lokr_w1 = lora[lokr_w1_name]
loaded_keys.add(lokr_w1_name)
lokr_w2 = None
if lokr_w2_name in lora.keys():
lokr_w2 = lora[lokr_w2_name]
loaded_keys.add(lokr_w2_name)
lokr_w1_a = None
if lokr_w1_a_name in lora.keys():
lokr_w1_a = lora[lokr_w1_a_name]
loaded_keys.add(lokr_w1_a_name)
lokr_w1_b = None
if lokr_w1_b_name in lora.keys():
lokr_w1_b = lora[lokr_w1_b_name]
loaded_keys.add(lokr_w1_b_name)
lokr_w2_a = None
if lokr_w2_a_name in lora.keys():
lokr_w2_a = lora[lokr_w2_a_name]
loaded_keys.add(lokr_w2_a_name)
lokr_w2_b = None
if lokr_w2_b_name in lora.keys():
lokr_w2_b = lora[lokr_w2_b_name]
loaded_keys.add(lokr_w2_b_name)
lokr_t2 = None
if lokr_t2_name in lora.keys():
lokr_t2 = lora[lokr_t2_name]
loaded_keys.add(lokr_t2_name)
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2)
for x in lora.keys(): for x in lora.keys():
if x not in loaded_keys: if x not in loaded_keys:
print("lora key not loaded", x) print("lora key not loaded", x)
...@@ -316,6 +366,33 @@ class ModelPatcher: ...@@ -316,6 +366,33 @@ class ModelPatcher:
final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]] final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1) mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1)
weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device) weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
elif len(v) == 8: #lokr
w1 = v[0]
w2 = v[1]
w1_a = v[3]
w1_b = v[4]
w2_a = v[5]
w2_b = v[6]
t2 = v[7]
dim = None
if w1 is None:
dim = w1_b.shape[0]
w1 = torch.mm(w1_a.float(), w1_b.float())
if w2 is None:
dim = w2_b.shape[0]
if t2 is None:
w2 = torch.mm(w2_a.float(), w2_b.float())
else:
w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2_b.float(), w2_a.float())
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
if v[2] is not None and dim is not None:
alpha *= v[2] / dim
weight += alpha * torch.kron(w1.float(), w2.float()).reshape(weight.shape).type(weight.dtype).to(weight.device)
else: #loha else: #loha
w1a = v[0] w1a = v[0]
w1b = v[1] w1b = v[1]
......
...@@ -94,3 +94,26 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am ...@@ -94,3 +94,26 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am
output[b:b+1] = out/out_div output[b:b+1] = out/out_div
return output return output
PROGRESS_BAR_HOOK = None
def set_progress_bar_global_hook(function):
global PROGRESS_BAR_HOOK
PROGRESS_BAR_HOOK = function
class ProgressBar:
def __init__(self, total):
global PROGRESS_BAR_HOOK
self.total = total
self.current = 0
self.hook = PROGRESS_BAR_HOOK
def update_absolute(self, value):
if value > self.total:
value = self.total
self.current = value
if self.hook is not None:
self.hook(self.current, self.total)
def update(self, value):
self.update_absolute(self.current + value)
...@@ -4,7 +4,10 @@ ...@@ -4,7 +4,10 @@
from __future__ import annotations from __future__ import annotations
from collections import OrderedDict from collections import OrderedDict
from typing import Literal try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
import torch import torch
import torch.nn as nn import torch.nn as nn
......
...@@ -10,7 +10,17 @@ def load_hypernetwork_patch(path, strength): ...@@ -10,7 +10,17 @@ def load_hypernetwork_patch(path, strength):
activate_output = sd.get('activate_output', False) activate_output = sd.get('activate_output', False)
last_layer_dropout = sd.get('last_layer_dropout', False) last_layer_dropout = sd.get('last_layer_dropout', False)
if activation_func != 'linear' or is_layer_norm != False or use_dropout != False or activate_output != False or last_layer_dropout != False: valid_activation = {
"linear": torch.nn.Identity,
"relu": torch.nn.ReLU,
"leakyrelu": torch.nn.LeakyReLU,
"elu": torch.nn.ELU,
"swish": torch.nn.Hardswish,
"tanh": torch.nn.Tanh,
"sigmoid": torch.nn.Sigmoid,
}
if activation_func not in valid_activation:
print("Unsupported Hypernetwork format, if you report it I might implement it.", path, " ", activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout) print("Unsupported Hypernetwork format, if you report it I might implement it.", path, " ", activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout)
return None return None
...@@ -28,15 +38,27 @@ def load_hypernetwork_patch(path, strength): ...@@ -28,15 +38,27 @@ def load_hypernetwork_patch(path, strength):
keys = attn_weights.keys() keys = attn_weights.keys()
linears = filter(lambda a: a.endswith(".weight"), keys) linears = filter(lambda a: a.endswith(".weight"), keys)
linears = sorted(list(map(lambda a: a[:-len(".weight")], linears))) linears = list(map(lambda a: a[:-len(".weight")], linears))
layers = [] layers = []
for lin_name in linears: for i in range(len(linears)):
lin_name = linears[i]
last_layer = (i == (len(linears) - 1))
penultimate_layer = (i == (len(linears) - 2))
lin_weight = attn_weights['{}.weight'.format(lin_name)] lin_weight = attn_weights['{}.weight'.format(lin_name)]
lin_bias = attn_weights['{}.bias'.format(lin_name)] lin_bias = attn_weights['{}.bias'.format(lin_name)]
layer = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0]) layer = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0])
layer.load_state_dict({"weight": lin_weight, "bias": lin_bias}) layer.load_state_dict({"weight": lin_weight, "bias": lin_bias})
layers += [layer] layers.append(layer)
if activation_func != "linear":
if (not last_layer) or (activate_output):
layers.append(valid_activation[activation_func]())
if is_layer_norm:
layers.append(torch.nn.LayerNorm(lin_weight.shape[0]))
if use_dropout:
if (not last_layer) and (not penultimate_layer or last_layer_dropout):
layers.append(torch.nn.Dropout(p=0.3))
output.append(torch.nn.Sequential(*layers)) output.append(torch.nn.Sequential(*layers))
out[dim] = torch.nn.ModuleList(output) out[dim] = torch.nn.ModuleList(output)
...@@ -71,7 +93,7 @@ class HypernetworkLoader: ...@@ -71,7 +93,7 @@ class HypernetworkLoader:
RETURN_TYPES = ("MODEL",) RETURN_TYPES = ("MODEL",)
FUNCTION = "load_hypernetwork" FUNCTION = "load_hypernetwork"
CATEGORY = "_for_testing" CATEGORY = "loaders"
def load_hypernetwork(self, model, hypernetwork_name, strength): def load_hypernetwork(self, model, hypernetwork_name, strength):
hypernetwork_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name) hypernetwork_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name)
......
...@@ -40,15 +40,13 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da ...@@ -40,15 +40,13 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da
input_data_all[x] = unique_id input_data_all[x] = unique_id
return input_data_all return input_data_all
def recursive_execute(server, prompt, outputs, current_item, extra_data={}): def recursive_execute(server, prompt, outputs, current_item, extra_data, executed):
unique_id = current_item unique_id = current_item
inputs = prompt[unique_id]['inputs'] inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type'] class_type = prompt[unique_id]['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type] class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if unique_id in outputs: if unique_id in outputs:
return [] return
executed = []
for x in inputs: for x in inputs:
input_data = inputs[x] input_data = inputs[x]
...@@ -57,7 +55,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data={}): ...@@ -57,7 +55,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data={}):
input_unique_id = input_data[0] input_unique_id = input_data[0]
output_index = input_data[1] output_index = input_data[1]
if input_unique_id not in outputs: if input_unique_id not in outputs:
executed += recursive_execute(server, prompt, outputs, input_unique_id, extra_data) recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed)
input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
if server.client_id is not None: if server.client_id is not None:
...@@ -72,7 +70,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data={}): ...@@ -72,7 +70,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data={}):
server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id) server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id)
if "result" in outputs[unique_id]: if "result" in outputs[unique_id]:
outputs[unique_id] = outputs[unique_id]["result"] outputs[unique_id] = outputs[unique_id]["result"]
return executed + [unique_id] executed.add(unique_id)
def recursive_will_execute(prompt, outputs, current_item): def recursive_will_execute(prompt, outputs, current_item):
unique_id = current_item unique_id = current_item
...@@ -99,21 +97,25 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item ...@@ -99,21 +97,25 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
is_changed_old = '' is_changed_old = ''
is_changed = '' is_changed = ''
to_delete = False
if hasattr(class_def, 'IS_CHANGED'): if hasattr(class_def, 'IS_CHANGED'):
if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]: if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]:
is_changed_old = old_prompt[unique_id]['is_changed'] is_changed_old = old_prompt[unique_id]['is_changed']
if 'is_changed' not in prompt[unique_id]: if 'is_changed' not in prompt[unique_id]:
input_data_all = get_input_data(inputs, class_def, unique_id, outputs) input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
if input_data_all is not None: if input_data_all is not None:
try:
is_changed = class_def.IS_CHANGED(**input_data_all) is_changed = class_def.IS_CHANGED(**input_data_all)
prompt[unique_id]['is_changed'] = is_changed prompt[unique_id]['is_changed'] = is_changed
except:
to_delete = True
else: else:
is_changed = prompt[unique_id]['is_changed'] is_changed = prompt[unique_id]['is_changed']
if unique_id not in outputs: if unique_id not in outputs:
return True return True
to_delete = False if not to_delete:
if is_changed != is_changed_old: if is_changed != is_changed_old:
to_delete = True to_delete = True
elif unique_id not in old_prompt: elif unique_id not in old_prompt:
...@@ -154,11 +156,20 @@ class PromptExecutor: ...@@ -154,11 +156,20 @@ class PromptExecutor:
self.server.client_id = None self.server.client_id = None
with torch.inference_mode(): with torch.inference_mode():
#delete cached outputs if nodes don't exist for them
to_delete = []
for o in self.outputs:
if o not in prompt:
to_delete += [o]
for o in to_delete:
d = self.outputs.pop(o)
del d
for x in prompt: for x in prompt:
recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x) recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)
current_outputs = set(self.outputs.keys()) current_outputs = set(self.outputs.keys())
executed = [] executed = set()
try: try:
to_execute = [] to_execute = []
for x in prompt: for x in prompt:
...@@ -181,12 +192,12 @@ class PromptExecutor: ...@@ -181,12 +192,12 @@ class PromptExecutor:
except: except:
valid = False valid = False
if valid: if valid:
executed += recursive_execute(self.server, prompt, self.outputs, x, extra_data) recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed)
except Exception as e: except Exception as e:
print(traceback.format_exc()) print(traceback.format_exc())
to_delete = [] to_delete = []
for o in self.outputs: for o in self.outputs:
if o not in current_outputs: if (o not in current_outputs) and (o not in executed):
to_delete += [o] to_delete += [o]
if o in self.old_prompt: if o in self.old_prompt:
d = self.old_prompt.pop(o) d = self.old_prompt.pop(o)
...@@ -194,11 +205,9 @@ class PromptExecutor: ...@@ -194,11 +205,9 @@ class PromptExecutor:
for o in to_delete: for o in to_delete:
d = self.outputs.pop(o) d = self.outputs.pop(o)
del d del d
else: finally:
executed = set(executed)
for x in executed: for x in executed:
self.old_prompt[x] = copy.deepcopy(prompt[x]) self.old_prompt[x] = copy.deepcopy(prompt[x])
finally:
self.server.last_node_id = None self.server.last_node_id = None
if self.server.client_id is not None: if self.server.client_id is not None:
self.server.send_sync("executing", { "node": None }, self.server.client_id) self.server.send_sync("executing", { "node": None }, self.server.client_id)
...@@ -249,6 +258,12 @@ def validate_inputs(prompt, item): ...@@ -249,6 +258,12 @@ def validate_inputs(prompt, item):
if "max" in info[1] and val > info[1]["max"]: if "max" in info[1] and val > info[1]["max"]:
return (False, "Value bigger than max. {}, {}".format(class_type, x)) return (False, "Value bigger than max. {}, {}".format(class_type, x))
if hasattr(obj_class, "VALIDATE_INPUTS"):
input_data_all = get_input_data(inputs, obj_class, unique_id)
ret = obj_class.VALIDATE_INPUTS(**input_data_all)
if ret != True:
return (False, "{}, {}".format(class_type, ret))
else:
if isinstance(type_input, list): if isinstance(type_input, list):
if val not in type_input: if val not in type_input:
return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input)) return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input))
...@@ -273,7 +288,8 @@ def validate_prompt(prompt): ...@@ -273,7 +288,8 @@ def validate_prompt(prompt):
m = validate_inputs(prompt, o) m = validate_inputs(prompt, o)
valid = m[0] valid = m[0]
reason = m[1] reason = m[1]
except: except Exception as e:
print(traceback.format_exc())
valid = False valid = False
reason = "Parsing error" reason = "Parsing error"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment