"vscode:/vscode.git/clone" did not exist on "85fde89d7f5259aaec9e42cc8116c3dab70a69a4"
Commit 220afe33 authored by comfyanonymous's avatar comfyanonymous
Browse files

Initial commit.

parents
__pycache__/
*.py[cod]
output/
models/checkpoints
models/vae
GNU GENERAL PUBLIC LICENSE
Version 3, 29 June 2007
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.
Preamble
The GNU General Public License is a free, copyleft license for
software and other kinds of works.
The licenses for most software and other practical works are designed
to take away your freedom to share and change the works. By contrast,
the GNU General Public License is intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users. We, the Free Software Foundation, use the
GNU General Public License for most of our software; it applies also to
any other work released this way by its authors. You can apply it to
your programs, too.
When we speak of free software, we are referring to freedom, not
price. Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
them if you wish), that you receive source code or can get it if you
want it, that you can change the software or use pieces of it in new
free programs, and that you know you can do these things.
To protect your rights, we need to prevent others from denying you
these rights or asking you to surrender the rights. Therefore, you have
certain responsibilities if you distribute copies of the software, or if
you modify it: responsibilities to respect the freedom of others.
For example, if you distribute copies of such a program, whether
gratis or for a fee, you must pass on to the recipients the same
freedoms that you received. You must make sure that they, too, receive
or can get the source code. And you must show them these terms so they
know their rights.
Developers that use the GNU GPL protect your rights with two steps:
(1) assert copyright on the software, and (2) offer you this License
giving you legal permission to copy, distribute and/or modify it.
For the developers' and authors' protection, the GPL clearly explains
that there is no warranty for this free software. For both users' and
authors' sake, the GPL requires that modified versions be marked as
changed, so that their problems will not be attributed erroneously to
authors of previous versions.
Some devices are designed to deny users access to install or run
modified versions of the software inside them, although the manufacturer
can do so. This is fundamentally incompatible with the aim of
protecting users' freedom to change the software. The systematic
pattern of such abuse occurs in the area of products for individuals to
use, which is precisely where it is most unacceptable. Therefore, we
have designed this version of the GPL to prohibit the practice for those
products. If such problems arise substantially in other domains, we
stand ready to extend this provision to those domains in future versions
of the GPL, as needed to protect the freedom of users.
Finally, every program is threatened constantly by software patents.
States should not allow patents to restrict development and use of
software on general-purpose computers, but in those that do, we wish to
avoid the special danger that patents applied to a free program could
make it effectively proprietary. To prevent this, the GPL assures that
patents cannot be used to render the program non-free.
The precise terms and conditions for copying, distribution and
modification follow.
TERMS AND CONDITIONS
0. Definitions.
"This License" refers to version 3 of the GNU General Public License.
"Copyright" also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.
"The Program" refers to any copyrightable work licensed under this
License. Each licensee is addressed as "you". "Licensees" and
"recipients" may be individuals or organizations.
To "modify" a work means to copy from or adapt all or part of the work
in a fashion requiring copyright permission, other than the making of an
exact copy. The resulting work is called a "modified version" of the
earlier work or a work "based on" the earlier work.
A "covered work" means either the unmodified Program or a work based
on the Program.
To "propagate" a work means to do anything with it that, without
permission, would make you directly or secondarily liable for
infringement under applicable copyright law, except executing it on a
computer or modifying a private copy. Propagation includes copying,
distribution (with or without modification), making available to the
public, and in some countries other activities as well.
To "convey" a work means any kind of propagation that enables other
parties to make or receive copies. Mere interaction with a user through
a computer network, with no transfer of a copy, is not conveying.
An interactive user interface displays "Appropriate Legal Notices"
to the extent that it includes a convenient and prominently visible
feature that (1) displays an appropriate copyright notice, and (2)
tells the user that there is no warranty for the work (except to the
extent that warranties are provided), that licensees may convey the
work under this License, and how to view a copy of this License. If
the interface presents a list of user commands or options, such as a
menu, a prominent item in the list meets this criterion.
1. Source Code.
The "source code" for a work means the preferred form of the work
for making modifications to it. "Object code" means any non-source
form of a work.
A "Standard Interface" means an interface that either is an official
standard defined by a recognized standards body, or, in the case of
interfaces specified for a particular programming language, one that
is widely used among developers working in that language.
The "System Libraries" of an executable work include anything, other
than the work as a whole, that (a) is included in the normal form of
packaging a Major Component, but which is not part of that Major
Component, and (b) serves only to enable use of the work with that
Major Component, or to implement a Standard Interface for which an
implementation is available to the public in source code form. A
"Major Component", in this context, means a major essential component
(kernel, window system, and so on) of the specific operating system
(if any) on which the executable work runs, or a compiler used to
produce the work, or an object code interpreter used to run it.
The "Corresponding Source" for a work in object code form means all
the source code needed to generate, install, and (for an executable
work) run the object code and to modify the work, including scripts to
control those activities. However, it does not include the work's
System Libraries, or general-purpose tools or generally available free
programs which are used unmodified in performing those activities but
which are not part of the work. For example, Corresponding Source
includes interface definition files associated with source files for
the work, and the source code for shared libraries and dynamically
linked subprograms that the work is specifically designed to require,
such as by intimate data communication or control flow between those
subprograms and other parts of the work.
The Corresponding Source need not include anything that users
can regenerate automatically from other parts of the Corresponding
Source.
The Corresponding Source for a work in source code form is that
same work.
2. Basic Permissions.
All rights granted under this License are granted for the term of
copyright on the Program, and are irrevocable provided the stated
conditions are met. This License explicitly affirms your unlimited
permission to run the unmodified Program. The output from running a
covered work is covered by this License only if the output, given its
content, constitutes a covered work. This License acknowledges your
rights of fair use or other equivalent, as provided by copyright law.
You may make, run and propagate covered works that you do not
convey, without conditions so long as your license otherwise remains
in force. You may convey covered works to others for the sole purpose
of having them make modifications exclusively for you, or provide you
with facilities for running those works, provided that you comply with
the terms of this License in conveying all material for which you do
not control copyright. Those thus making or running the covered works
for you must do so exclusively on your behalf, under your direction
and control, on terms that prohibit them from making any copies of
your copyrighted material outside their relationship with you.
Conveying under any other circumstances is permitted solely under
the conditions stated below. Sublicensing is not allowed; section 10
makes it unnecessary.
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
No covered work shall be deemed part of an effective technological
measure under any applicable law fulfilling obligations under article
11 of the WIPO copyright treaty adopted on 20 December 1996, or
similar laws prohibiting or restricting circumvention of such
measures.
When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention
is effected by exercising rights under this License with respect to
the covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's
users, your or third parties' legal rights to forbid circumvention of
technological measures.
4. Conveying Verbatim Copies.
You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice;
keep intact all notices stating that this License and any
non-permissive terms added in accord with section 7 apply to the code;
keep intact all notices of the absence of any warranty; and give all
recipients a copy of this License along with the Program.
You may charge any price or no price for each copy that you convey,
and you may offer support or warranty protection for a fee.
5. Conveying Modified Source Versions.
You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the
terms of section 4, provided that you also meet all of these conditions:
a) The work must carry prominent notices stating that you modified
it, and giving a relevant date.
b) The work must carry prominent notices stating that it is
released under this License and any conditions added under section
7. This requirement modifies the requirement in section 4 to
"keep intact all notices".
c) You must license the entire work, as a whole, under this
License to anyone who comes into possession of a copy. This
License will therefore apply, along with any applicable section 7
additional terms, to the whole of the work, and all its parts,
regardless of how they are packaged. This License gives no
permission to license the work in any other way, but it does not
invalidate such permission if you have separately received it.
d) If the work has interactive user interfaces, each must display
Appropriate Legal Notices; however, if the Program has interactive
interfaces that do not display Appropriate Legal Notices, your
work need not make them do so.
A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work,
and which are not combined with it such as to form a larger program,
in or on a volume of a storage or distribution medium, is called an
"aggregate" if the compilation and its resulting copyright are not
used to limit the access or legal rights of the compilation's users
beyond what the individual works permit. Inclusion of a covered work
in an aggregate does not cause this License to apply to the other
parts of the aggregate.
6. Conveying Non-Source Forms.
You may convey a covered work in object code form under the terms
of sections 4 and 5, provided that you also convey the
machine-readable Corresponding Source under the terms of this License,
in one of these ways:
a) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by the
Corresponding Source fixed on a durable physical medium
customarily used for software interchange.
b) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by a
written offer, valid for at least three years and valid for as
long as you offer spare parts or customer support for that product
model, to give anyone who possesses the object code either (1) a
copy of the Corresponding Source for all the software in the
product that is covered by this License, on a durable physical
medium customarily used for software interchange, for a price no
more than your reasonable cost of physically performing this
conveying of source, or (2) access to copy the
Corresponding Source from a network server at no charge.
c) Convey individual copies of the object code with a copy of the
written offer to provide the Corresponding Source. This
alternative is allowed only occasionally and noncommercially, and
only if you received the object code with such an offer, in accord
with subsection 6b.
d) Convey the object code by offering access from a designated
place (gratis or for a charge), and offer equivalent access to the
Corresponding Source in the same way through the same place at no
further charge. You need not require recipients to copy the
Corresponding Source along with the object code. If the place to
copy the object code is a network server, the Corresponding Source
may be on a different server (operated by you or a third party)
that supports equivalent copying facilities, provided you maintain
clear directions next to the object code saying where to find the
Corresponding Source. Regardless of what server hosts the
Corresponding Source, you remain obligated to ensure that it is
available for as long as needed to satisfy these requirements.
e) Convey the object code using peer-to-peer transmission, provided
you inform other peers where the object code and Corresponding
Source of the work are being offered to the general public at no
charge under subsection 6d.
A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be
included in conveying the object code work.
A "User Product" is either (1) a "consumer product", which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling. In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage. For a particular
product received by a particular user, "normally used" refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product. A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.
"Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source. The information must
suffice to ensure that the continued functioning of the modified object
code is in no case prevented or interfered with solely because
modification has been made.
If you convey an object code work under this section in, or with, or
specifically for use in, a User Product, and the conveying occurs as
part of a transaction in which the right of possession and use of the
User Product is transferred to the recipient in perpetuity or for a
fixed term (regardless of how the transaction is characterized), the
Corresponding Source conveyed under this section must be accompanied
by the Installation Information. But this requirement does not apply
if neither you nor any third party retains the ability to install
modified object code on the User Product (for example, the work has
been installed in ROM).
The requirement to provide Installation Information does not include a
requirement to continue to provide support service, warranty, or updates
for a work that has been modified or installed by the recipient, or for
the User Product in which it has been modified or installed. Access to a
network may be denied when the modification itself materially and
adversely affects the operation of the network or violates the rules and
protocols for communication across the network.
Corresponding Source conveyed, and Installation Information provided,
in accord with this section must be in a format that is publicly
documented (and with an implementation available to the public in
source code form), and must require no special password or key for
unpacking, reading or copying.
7. Additional Terms.
"Additional permissions" are terms that supplement the terms of this
License by making exceptions from one or more of its conditions.
Additional permissions that are applicable to the entire Program shall
be treated as though they were included in this License, to the extent
that they are valid under applicable law. If additional permissions
apply only to part of the Program, that part may be used separately
under those permissions, but the entire Program remains governed by
this License without regard to the additional permissions.
When you convey a copy of a covered work, you may at your option
remove any additional permissions from that copy, or from any part of
it. (Additional permissions may be written to require their own
removal in certain cases when you modify the work.) You may place
additional permissions on material, added by you to a covered work,
for which you have or can give appropriate copyright permission.
Notwithstanding any other provision of this License, for material you
add to a covered work, you may (if authorized by the copyright holders of
that material) supplement the terms of this License with terms:
a) Disclaiming warranty or limiting liability differently from the
terms of sections 15 and 16 of this License; or
b) Requiring preservation of specified reasonable legal notices or
author attributions in that material or in the Appropriate Legal
Notices displayed by works containing it; or
c) Prohibiting misrepresentation of the origin of that material, or
requiring that modified versions of such material be marked in
reasonable ways as different from the original version; or
d) Limiting the use for publicity purposes of names of licensors or
authors of the material; or
e) Declining to grant rights under trademark law for use of some
trade names, trademarks, or service marks; or
f) Requiring indemnification of licensors and authors of that
material by anyone who conveys the material (or modified versions of
it) with contractual assumptions of liability to the recipient, for
any liability that these contractual assumptions directly impose on
those licensors and authors.
All other non-permissive additional terms are considered "further
restrictions" within the meaning of section 10. If the Program as you
received it, or any part of it, contains a notice stating that it is
governed by this License along with a term that is a further
restriction, you may remove that term. If a license document contains
a further restriction but permits relicensing or conveying under this
License, you may add to a covered work material governed by the terms
of that license document, provided that the further restriction does
not survive such relicensing or conveying.
If you add terms to a covered work in accord with this section, you
must place, in the relevant source files, a statement of the
additional terms that apply to those files, or a notice indicating
where to find the applicable terms.
Additional terms, permissive or non-permissive, may be stated in the
form of a separately written license, or stated as exceptions;
the above requirements apply either way.
8. Termination.
You may not propagate or modify a covered work except as expressly
provided under this License. Any attempt otherwise to propagate or
modify it is void, and will automatically terminate your rights under
this License (including any patent licenses granted under the third
paragraph of section 11).
However, if you cease all violation of this License, then your
license from a particular copyright holder is reinstated (a)
provisionally, unless and until the copyright holder explicitly and
finally terminates your license, and (b) permanently, if the copyright
holder fails to notify you of the violation by some reasonable means
prior to 60 days after the cessation.
Moreover, your license from a particular copyright holder is
reinstated permanently if the copyright holder notifies you of the
violation by some reasonable means, this is the first time you have
received notice of violation of this License (for any work) from that
copyright holder, and you cure the violation prior to 30 days after
your receipt of the notice.
Termination of your rights under this section does not terminate the
licenses of parties who have received copies or rights from you under
this License. If your rights have been terminated and not permanently
reinstated, you do not qualify to receive new licenses for the same
material under section 10.
9. Acceptance Not Required for Having Copies.
You are not required to accept this License in order to receive or
run a copy of the Program. Ancillary propagation of a covered work
occurring solely as a consequence of using peer-to-peer transmission
to receive a copy likewise does not require acceptance. However,
nothing other than this License grants you permission to propagate or
modify any covered work. These actions infringe copyright if you do
not accept this License. Therefore, by modifying or propagating a
covered work, you indicate your acceptance of this License to do so.
10. Automatic Licensing of Downstream Recipients.
Each time you convey a covered work, the recipient automatically
receives a license from the original licensors, to run, modify and
propagate that work, subject to this License. You are not responsible
for enforcing compliance by third parties with this License.
An "entity transaction" is a transaction transferring control of an
organization, or substantially all assets of one, or subdividing an
organization, or merging organizations. If propagation of a covered
work results from an entity transaction, each party to that
transaction who receives a copy of the work also receives whatever
licenses to the work the party's predecessor in interest had or could
give under the previous paragraph, plus a right to possession of the
Corresponding Source of the work from the predecessor in interest, if
the predecessor has it or can get it with reasonable efforts.
You may not impose any further restrictions on the exercise of the
rights granted or affirmed under this License. For example, you may
not impose a license fee, royalty, or other charge for exercise of
rights granted under this License, and you may not initiate litigation
(including a cross-claim or counterclaim in a lawsuit) alleging that
any patent claim is infringed by making, using, selling, offering for
sale, or importing the Program or any portion of it.
11. Patents.
A "contributor" is a copyright holder who authorizes use under this
License of the Program or a work on which the Program is based. The
work thus licensed is called the contributor's "contributor version".
A contributor's "essential patent claims" are all patent claims
owned or controlled by the contributor, whether already acquired or
hereafter acquired, that would be infringed by some manner, permitted
by this License, of making, using, or selling its contributor version,
but do not include claims that would be infringed only as a
consequence of further modification of the contributor version. For
purposes of this definition, "control" includes the right to grant
patent sublicenses in a manner consistent with the requirements of
this License.
Each contributor grants you a non-exclusive, worldwide, royalty-free
patent license under the contributor's essential patent claims, to
make, use, sell, offer for sale, import and otherwise run, modify and
propagate the contents of its contributor version.
In the following three paragraphs, a "patent license" is any express
agreement or commitment, however denominated, not to enforce a patent
(such as an express permission to practice a patent or covenant not to
sue for patent infringement). To "grant" such a patent license to a
party means to make such an agreement or commitment not to enforce a
patent against the party.
If you convey a covered work, knowingly relying on a patent license,
and the Corresponding Source of the work is not available for anyone
to copy, free of charge and under the terms of this License, through a
publicly available network server or other readily accessible means,
then you must either (1) cause the Corresponding Source to be so
available, or (2) arrange to deprive yourself of the benefit of the
patent license for this particular work, or (3) arrange, in a manner
consistent with the requirements of this License, to extend the patent
license to downstream recipients. "Knowingly relying" means you have
actual knowledge that, but for the patent license, your conveying the
covered work in a country, or your recipient's use of the covered work
in a country, would infringe one or more identifiable patents in that
country that you have reason to believe are valid.
If, pursuant to or in connection with a single transaction or
arrangement, you convey, or propagate by procuring conveyance of, a
covered work, and grant a patent license to some of the parties
receiving the covered work authorizing them to use, propagate, modify
or convey a specific copy of the covered work, then the patent license
you grant is automatically extended to all recipients of the covered
work and works based on it.
A patent license is "discriminatory" if it does not include within
the scope of its coverage, prohibits the exercise of, or is
conditioned on the non-exercise of one or more of the rights that are
specifically granted under this License. You may not convey a covered
work if you are a party to an arrangement with a third party that is
in the business of distributing software, under which you make payment
to the third party based on the extent of your activity of conveying
the work, and under which the third party grants, to any of the
parties who would receive the covered work from you, a discriminatory
patent license (a) in connection with copies of the covered work
conveyed by you (or copies made from those copies), or (b) primarily
for and in connection with specific products or compilations that
contain the covered work, unless you entered into that arrangement,
or that patent license was granted, prior to 28 March 2007.
Nothing in this License shall be construed as excluding or limiting
any implied license or other defenses to infringement that may
otherwise be available to you under applicable patent law.
12. No Surrender of Others' Freedom.
If conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License. If you cannot convey a
covered work so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you may
not convey it at all. For example, if you agree to terms that obligate you
to collect a royalty for further conveying from those to whom you convey
the Program, the only way you could satisfy both those terms and this
License would be to refrain entirely from conveying the Program.
13. Use with the GNU Affero General Public License.
Notwithstanding any other provision of this License, you have
permission to link or combine any covered work with a work licensed
under version 3 of the GNU Affero General Public License into a single
combined work, and to convey the resulting work. The terms of this
License will continue to apply to the part which is the covered work,
but the special requirements of the GNU Affero General Public License,
section 13, concerning interaction through a network will apply to the
combination as such.
14. Revised Versions of this License.
The Free Software Foundation may publish revised and/or new versions of
the GNU General Public License from time to time. Such new versions will
be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.
Each version is given a distinguishing version number. If the
Program specifies that a certain numbered version of the GNU General
Public License "or any later version" applies to it, you have the
option of following the terms and conditions either of that numbered
version or of any later version published by the Free Software
Foundation. If the Program does not specify a version number of the
GNU General Public License, you may choose any version ever published
by the Free Software Foundation.
If the Program specifies that a proxy can decide which future
versions of the GNU General Public License can be used, that proxy's
public statement of acceptance of a version permanently authorizes you
to choose that version for the Program.
Later license versions may give you additional or different
permissions. However, no additional obligations are imposed on any
author or copyright holder as a result of your choosing to follow a
later version.
15. Disclaimer of Warranty.
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
16. Limitation of Liability.
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
SUCH DAMAGES.
17. Interpretation of Sections 15 and 16.
If the disclaimer of warranty and limitation of liability provided
above cannot be given local legal effect according to their terms,
reviewing courts shall apply local law that most closely approximates
an absolute waiver of all civil liability in connection with the
Program, unless a warranty or assumption of liability accompanies a
copy of the Program in return for a fee.
END OF TERMS AND CONDITIONS
How to Apply These Terms to Your New Programs
If you develop a new program, and you want it to be of the greatest
possible use to the public, the best way to achieve this is to make it
free software which everyone can redistribute and change under these terms.
To do so, attach the following notices to the program. It is safest
to attach them to the start of each source file to most effectively
state the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.
<one line to give the program's name and a brief idea of what it does.>
Copyright (C) <year> <name of author>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
Also add information on how to contact you by electronic and paper mail.
If the program does terminal interaction, make it output a short
notice like this when it starts in an interactive mode:
<program> Copyright (C) <year> <name of author>
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
This is free software, and you are welcome to redistribute it
under certain conditions; type `show c' for details.
The hypothetical commands `show w' and `show c' should show the appropriate
parts of the General Public License. Of course, your program's commands
might be different; for a GUI interface, you would use an "about box".
You should also get your employer (if you work as a programmer) or school,
if any, to sign a "copyright disclaimer" for the program, if necessary.
For more information on this, and how to apply and follow the GNU GPL, see
<https://www.gnu.org/licenses/>.
The GNU General Public License does not permit incorporating your program
into proprietary programs. If your program is a subroutine library, you
may consider it more useful to permit linking proprietary applications with
the library. If this is what you want to do, use the GNU Lesser General
Public License instead of this License. But first, please read
<https://www.gnu.org/licenses/why-not-lgpl.html>.
ComfyUI
=======
A powerful and modular stable diffusion GUI.
-----------
![ComfyUI Screenshot](comfyui_screenshot.png)
This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface.
# Installing
Git clone this repo.
Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
Put your VAE in: models/vae
At the time of writing this pytorch has issues with python versions higher than 3.10 so make sure your python/pip versions are 3.10.
### AMD
AMD users can install rocm and pytorch with pip if you don't have it already installed:
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.2"```
### NVIDIA
Nvidia users should install Xformers.
### Dependencies
Install the dependencies:
```pip install -r requirements.txt```
# Running
```python main.py```
# Notes
Only parts of the graph that have an output with all the correct inputs will be executed.
Only parts of the graph that change from each execution to the next will be executed, if you submit the same graph twice only the first will be executed. If you change the last part of the graph only the part you changed and the part that depends on it will be executed.
Dragging a generated png on the webpage or loading one will give you the full workflow including seeds that were used to create it.
You can use () to change emphasis of a word or phrase like: (good code:1.2) or (bad code:0.8). The default emphasis for () is 1.1. To use () characters in your actual prompt escape them like \\( or \\).
### Fedora
To get python 3.10 on fedora:
```dnf install python3.10```
Then you can:
```python3.10 -m ensurepip```
This will let you use: pip3.10 to install all the dependencies.
# QA
### Why did you make this?
I wanted to learn how Stable Diffusion worked in detail. I also wanted something clean and powerful that would let me experiment with SD without restrictions.
### Who is this for?
This is for anyone that wants to make complex workflows with SD or that wants to learn more how SD works. The interface follows closely how SD works and the code should be much more simple to understand than other SD UIs.
from functools import reduce
import math
import operator
import numpy as np
from skimage import transform
import torch
from torch import nn
def translate2d(tx, ty):
mat = [[1, 0, tx],
[0, 1, ty],
[0, 0, 1]]
return torch.tensor(mat, dtype=torch.float32)
def scale2d(sx, sy):
mat = [[sx, 0, 0],
[ 0, sy, 0],
[ 0, 0, 1]]
return torch.tensor(mat, dtype=torch.float32)
def rotate2d(theta):
mat = [[torch.cos(theta), torch.sin(-theta), 0],
[torch.sin(theta), torch.cos(theta), 0],
[ 0, 0, 1]]
return torch.tensor(mat, dtype=torch.float32)
class KarrasAugmentationPipeline:
def __init__(self, a_prob=0.12, a_scale=2**0.2, a_aniso=2**0.2, a_trans=1/8):
self.a_prob = a_prob
self.a_scale = a_scale
self.a_aniso = a_aniso
self.a_trans = a_trans
def __call__(self, image):
h, w = image.size
mats = [translate2d(h / 2 - 0.5, w / 2 - 0.5)]
# x-flip
a0 = torch.randint(2, []).float()
mats.append(scale2d(1 - 2 * a0, 1))
# y-flip
do = (torch.rand([]) < self.a_prob).float()
a1 = torch.randint(2, []).float() * do
mats.append(scale2d(1, 1 - 2 * a1))
# scaling
do = (torch.rand([]) < self.a_prob).float()
a2 = torch.randn([]) * do
mats.append(scale2d(self.a_scale ** a2, self.a_scale ** a2))
# rotation
do = (torch.rand([]) < self.a_prob).float()
a3 = (torch.rand([]) * 2 * math.pi - math.pi) * do
mats.append(rotate2d(-a3))
# anisotropy
do = (torch.rand([]) < self.a_prob).float()
a4 = (torch.rand([]) * 2 * math.pi - math.pi) * do
a5 = torch.randn([]) * do
mats.append(rotate2d(a4))
mats.append(scale2d(self.a_aniso ** a5, self.a_aniso ** -a5))
mats.append(rotate2d(-a4))
# translation
do = (torch.rand([]) < self.a_prob).float()
a6 = torch.randn([]) * do
a7 = torch.randn([]) * do
mats.append(translate2d(self.a_trans * w * a6, self.a_trans * h * a7))
# form the transformation matrix and conditioning vector
mats.append(translate2d(-h / 2 + 0.5, -w / 2 + 0.5))
mat = reduce(operator.matmul, mats)
cond = torch.stack([a0, a1, a2, a3.cos() - 1, a3.sin(), a5 * a4.cos(), a5 * a4.sin(), a6, a7])
# apply the transformation
image_orig = np.array(image, dtype=np.float32) / 255
if image_orig.ndim == 2:
image_orig = image_orig[..., None]
tf = transform.AffineTransform(mat.numpy())
image = transform.warp(image_orig, tf.inverse, order=3, mode='reflect', cval=0.5, clip=False, preserve_range=True)
image_orig = torch.as_tensor(image_orig).movedim(2, 0) * 2 - 1
image = torch.as_tensor(image).movedim(2, 0) * 2 - 1
return image, image_orig, cond
class KarrasAugmentWrapper(nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
def forward(self, input, sigma, aug_cond=None, mapping_cond=None, **kwargs):
if aug_cond is None:
aug_cond = input.new_zeros([input.shape[0], 9])
if mapping_cond is None:
mapping_cond = aug_cond
else:
mapping_cond = torch.cat([aug_cond, mapping_cond], dim=1)
return self.inner_model(input, sigma, mapping_cond=mapping_cond, **kwargs)
def set_skip_stages(self, skip_stages):
return self.inner_model.set_skip_stages(skip_stages)
def set_patch_size(self, patch_size):
return self.inner_model.set_patch_size(patch_size)
from functools import partial
import json
import math
import warnings
from jsonmerge import merge
from . import augmentation, layers, models, utils
def load_config(file):
defaults = {
'model': {
'sigma_data': 1.,
'patch_size': 1,
'dropout_rate': 0.,
'augment_wrapper': True,
'augment_prob': 0.,
'mapping_cond_dim': 0,
'unet_cond_dim': 0,
'cross_cond_dim': 0,
'cross_attn_depths': None,
'skip_stages': 0,
'has_variance': False,
},
'dataset': {
'type': 'imagefolder',
},
'optimizer': {
'type': 'adamw',
'lr': 1e-4,
'betas': [0.95, 0.999],
'eps': 1e-6,
'weight_decay': 1e-3,
},
'lr_sched': {
'type': 'inverse',
'inv_gamma': 20000.,
'power': 1.,
'warmup': 0.99,
},
'ema_sched': {
'type': 'inverse',
'power': 0.6667,
'max_value': 0.9999
},
}
config = json.load(file)
return merge(defaults, config)
def make_model(config):
config = config['model']
assert config['type'] == 'image_v1'
model = models.ImageDenoiserModelV1(
config['input_channels'],
config['mapping_out'],
config['depths'],
config['channels'],
config['self_attn_depths'],
config['cross_attn_depths'],
patch_size=config['patch_size'],
dropout_rate=config['dropout_rate'],
mapping_cond_dim=config['mapping_cond_dim'] + (9 if config['augment_wrapper'] else 0),
unet_cond_dim=config['unet_cond_dim'],
cross_cond_dim=config['cross_cond_dim'],
skip_stages=config['skip_stages'],
has_variance=config['has_variance'],
)
if config['augment_wrapper']:
model = augmentation.KarrasAugmentWrapper(model)
return model
def make_denoiser_wrapper(config):
config = config['model']
sigma_data = config.get('sigma_data', 1.)
has_variance = config.get('has_variance', False)
if not has_variance:
return partial(layers.Denoiser, sigma_data=sigma_data)
return partial(layers.DenoiserWithVariance, sigma_data=sigma_data)
def make_sample_density(config):
sd_config = config['sigma_sample_density']
sigma_data = config['sigma_data']
if sd_config['type'] == 'lognormal':
loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc']
scale = sd_config['std'] if 'std' in sd_config else sd_config['scale']
return partial(utils.rand_log_normal, loc=loc, scale=scale)
if sd_config['type'] == 'loglogistic':
loc = sd_config['loc'] if 'loc' in sd_config else math.log(sigma_data)
scale = sd_config['scale'] if 'scale' in sd_config else 0.5
min_value = sd_config['min_value'] if 'min_value' in sd_config else 0.
max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf')
return partial(utils.rand_log_logistic, loc=loc, scale=scale, min_value=min_value, max_value=max_value)
if sd_config['type'] == 'loguniform':
min_value = sd_config['min_value'] if 'min_value' in sd_config else config['sigma_min']
max_value = sd_config['max_value'] if 'max_value' in sd_config else config['sigma_max']
return partial(utils.rand_log_uniform, min_value=min_value, max_value=max_value)
if sd_config['type'] == 'v-diffusion':
min_value = sd_config['min_value'] if 'min_value' in sd_config else 0.
max_value = sd_config['max_value'] if 'max_value' in sd_config else float('inf')
return partial(utils.rand_v_diffusion, sigma_data=sigma_data, min_value=min_value, max_value=max_value)
if sd_config['type'] == 'split-lognormal':
loc = sd_config['mean'] if 'mean' in sd_config else sd_config['loc']
scale_1 = sd_config['std_1'] if 'std_1' in sd_config else sd_config['scale_1']
scale_2 = sd_config['std_2'] if 'std_2' in sd_config else sd_config['scale_2']
return partial(utils.rand_split_log_normal, loc=loc, scale_1=scale_1, scale_2=scale_2)
raise ValueError('Unknown sample density type')
import math
import os
from pathlib import Path
from cleanfid.inception_torchscript import InceptionV3W
import clip
from resize_right import resize
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from tqdm.auto import trange
from . import utils
class InceptionV3FeatureExtractor(nn.Module):
def __init__(self, device='cpu'):
super().__init__()
path = Path(os.environ.get('XDG_CACHE_HOME', Path.home() / '.cache')) / 'k-diffusion'
url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
digest = 'f58cb9b6ec323ed63459aa4fb441fe750cfe39fafad6da5cb504a16f19e958f4'
utils.download_file(path / 'inception-2015-12-05.pt', url, digest)
self.model = InceptionV3W(str(path), resize_inside=False).to(device)
self.size = (299, 299)
def forward(self, x):
if x.shape[2:4] != self.size:
x = resize(x, out_shape=self.size, pad_mode='reflect')
if x.shape[1] == 1:
x = torch.cat([x] * 3, dim=1)
x = (x * 127.5 + 127.5).clamp(0, 255)
return self.model(x)
class CLIPFeatureExtractor(nn.Module):
def __init__(self, name='ViT-L/14@336px', device='cpu'):
super().__init__()
self.model = clip.load(name, device=device)[0].eval().requires_grad_(False)
self.normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711))
self.size = (self.model.visual.input_resolution, self.model.visual.input_resolution)
def forward(self, x):
if x.shape[2:4] != self.size:
x = resize(x.add(1).div(2), out_shape=self.size, pad_mode='reflect').clamp(0, 1)
x = self.normalize(x)
x = self.model.encode_image(x).float()
x = F.normalize(x) * x.shape[1] ** 0.5
return x
def compute_features(accelerator, sample_fn, extractor_fn, n, batch_size):
n_per_proc = math.ceil(n / accelerator.num_processes)
feats_all = []
try:
for i in trange(0, n_per_proc, batch_size, disable=not accelerator.is_main_process):
cur_batch_size = min(n - i, batch_size)
samples = sample_fn(cur_batch_size)[:cur_batch_size]
feats_all.append(accelerator.gather(extractor_fn(samples)))
except StopIteration:
pass
return torch.cat(feats_all)[:n]
def polynomial_kernel(x, y):
d = x.shape[-1]
dot = x @ y.transpose(-2, -1)
return (dot / d + 1) ** 3
def squared_mmd(x, y, kernel=polynomial_kernel):
m = x.shape[-2]
n = y.shape[-2]
kxx = kernel(x, x)
kyy = kernel(y, y)
kxy = kernel(x, y)
kxx_sum = kxx.sum([-1, -2]) - kxx.diagonal(dim1=-1, dim2=-2).sum(-1)
kyy_sum = kyy.sum([-1, -2]) - kyy.diagonal(dim1=-1, dim2=-2).sum(-1)
kxy_sum = kxy.sum([-1, -2])
term_1 = kxx_sum / m / (m - 1)
term_2 = kyy_sum / n / (n - 1)
term_3 = kxy_sum * 2 / m / n
return term_1 + term_2 - term_3
@utils.tf32_mode(matmul=False)
def kid(x, y, max_size=5000):
x_size, y_size = x.shape[0], y.shape[0]
n_partitions = math.ceil(max(x_size / max_size, y_size / max_size))
total_mmd = x.new_zeros([])
for i in range(n_partitions):
cur_x = x[round(i * x_size / n_partitions):round((i + 1) * x_size / n_partitions)]
cur_y = y[round(i * y_size / n_partitions):round((i + 1) * y_size / n_partitions)]
total_mmd = total_mmd + squared_mmd(cur_x, cur_y)
return total_mmd / n_partitions
class _MatrixSquareRootEig(torch.autograd.Function):
@staticmethod
def forward(ctx, a):
vals, vecs = torch.linalg.eigh(a)
ctx.save_for_backward(vals, vecs)
return vecs @ vals.abs().sqrt().diag_embed() @ vecs.transpose(-2, -1)
@staticmethod
def backward(ctx, grad_output):
vals, vecs = ctx.saved_tensors
d = vals.abs().sqrt().unsqueeze(-1).repeat_interleave(vals.shape[-1], -1)
vecs_t = vecs.transpose(-2, -1)
return vecs @ (vecs_t @ grad_output @ vecs / (d + d.transpose(-2, -1))) @ vecs_t
def sqrtm_eig(a):
if a.ndim < 2:
raise RuntimeError('tensor of matrices must have at least 2 dimensions')
if a.shape[-2] != a.shape[-1]:
raise RuntimeError('tensor must be batches of square matrices')
return _MatrixSquareRootEig.apply(a)
@utils.tf32_mode(matmul=False)
def fid(x, y, eps=1e-8):
x_mean = x.mean(dim=0)
y_mean = y.mean(dim=0)
mean_term = (x_mean - y_mean).pow(2).sum()
x_cov = torch.cov(x.T)
y_cov = torch.cov(y.T)
eps_eye = torch.eye(x_cov.shape[0], device=x_cov.device, dtype=x_cov.dtype) * eps
x_cov = x_cov + eps_eye
y_cov = y_cov + eps_eye
x_cov_sqrt = sqrtm_eig(x_cov)
cov_term = torch.trace(x_cov + y_cov - 2 * sqrtm_eig(x_cov_sqrt @ y_cov @ x_cov_sqrt))
return mean_term + cov_term
import math
import torch
from torch import nn
from . import sampling, utils
class VDenoiser(nn.Module):
"""A v-diffusion-pytorch model wrapper for k-diffusion."""
def __init__(self, inner_model):
super().__init__()
self.inner_model = inner_model
self.sigma_data = 1.
def get_scalings(self, sigma):
c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
return c_skip, c_out, c_in
def sigma_to_t(self, sigma):
return sigma.atan() / math.pi * 2
def t_to_sigma(self, t):
return (t * math.pi / 2).tan()
def loss(self, input, noise, sigma, **kwargs):
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
model_output = self.inner_model(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
target = (input - c_skip * noised_input) / c_out
return (model_output - target).pow(2).flatten(1).mean(1)
def forward(self, input, sigma, **kwargs):
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
return self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip
class DiscreteSchedule(nn.Module):
"""A mapping between continuous noise levels (sigmas) and a list of discrete noise
levels."""
def __init__(self, sigmas, quantize):
super().__init__()
self.register_buffer('sigmas', sigmas)
self.register_buffer('log_sigmas', sigmas.log())
self.quantize = quantize
@property
def sigma_min(self):
return self.sigmas[0]
@property
def sigma_max(self):
return self.sigmas[-1]
def get_sigmas(self, n=None):
if n is None:
return sampling.append_zero(self.sigmas.flip(0))
t_max = len(self.sigmas) - 1
t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
return sampling.append_zero(self.t_to_sigma(t))
def sigma_to_t(self, sigma, quantize=None):
quantize = self.quantize if quantize is None else quantize
log_sigma = sigma.log()
dists = log_sigma - self.log_sigmas[:, None]
if quantize:
return dists.abs().argmin(dim=0).view(sigma.shape)
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
high_idx = low_idx + 1
low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
w = (low - log_sigma) / (low - high)
w = w.clamp(0, 1)
t = (1 - w) * low_idx + w * high_idx
return t.view(sigma.shape)
def t_to_sigma(self, t):
t = t.float()
low_idx = t.floor().long()
high_idx = t.ceil().long()
w = t-low_idx if t.device.type == 'mps' else t.frac()
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
return log_sigma.exp()
class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
"""A wrapper for discrete schedule DDPM models that output eps (the predicted
noise)."""
def __init__(self, model, alphas_cumprod, quantize):
super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
self.inner_model = model
self.sigma_data = 1.
def get_scalings(self, sigma):
c_out = -sigma
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
return c_out, c_in
def get_eps(self, *args, **kwargs):
return self.inner_model(*args, **kwargs)
def loss(self, input, noise, sigma, **kwargs):
c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
eps = self.get_eps(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
return (eps - noise).pow(2).flatten(1).mean(1)
def forward(self, input, sigma, **kwargs):
c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
return input + eps * c_out
class OpenAIDenoiser(DiscreteEpsDDPMDenoiser):
"""A wrapper for OpenAI diffusion models."""
def __init__(self, model, diffusion, quantize=False, has_learned_sigmas=True, device='cpu'):
alphas_cumprod = torch.tensor(diffusion.alphas_cumprod, device=device, dtype=torch.float32)
super().__init__(model, alphas_cumprod, quantize=quantize)
self.has_learned_sigmas = has_learned_sigmas
def get_eps(self, *args, **kwargs):
model_output = self.inner_model(*args, **kwargs)
if self.has_learned_sigmas:
return model_output.chunk(2, dim=1)[0]
return model_output
class CompVisDenoiser(DiscreteEpsDDPMDenoiser):
"""A wrapper for CompVis diffusion models."""
def __init__(self, model, quantize=False, device='cpu'):
super().__init__(model, model.alphas_cumprod, quantize=quantize)
def get_eps(self, *args, **kwargs):
return self.inner_model.apply_model(*args, **kwargs)
class DiscreteVDDPMDenoiser(DiscreteSchedule):
"""A wrapper for discrete schedule DDPM models that output v."""
def __init__(self, model, alphas_cumprod, quantize):
super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
self.inner_model = model
self.sigma_data = 1.
def get_scalings(self, sigma):
c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
return c_skip, c_out, c_in
def get_v(self, *args, **kwargs):
return self.inner_model(*args, **kwargs)
def loss(self, input, noise, sigma, **kwargs):
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
model_output = self.get_v(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
target = (input - c_skip * noised_input) / c_out
return (model_output - target).pow(2).flatten(1).mean(1)
def forward(self, input, sigma, **kwargs):
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
return self.get_v(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip
class CompVisVDenoiser(DiscreteVDDPMDenoiser):
"""A wrapper for CompVis diffusion models that output v."""
def __init__(self, model, quantize=False, device='cpu'):
super().__init__(model, model.alphas_cumprod, quantize=quantize)
def get_v(self, x, t, cond, **kwargs):
return self.inner_model.apply_model(x, t, cond)
import torch
from torch import nn
class DDPGradientStatsHook:
def __init__(self, ddp_module):
try:
ddp_module.register_comm_hook(self, self._hook_fn)
except AttributeError:
raise ValueError('DDPGradientStatsHook does not support non-DDP wrapped modules')
self._clear_state()
def _clear_state(self):
self.bucket_sq_norms_small_batch = []
self.bucket_sq_norms_large_batch = []
@staticmethod
def _hook_fn(self, bucket):
buf = bucket.buffer()
self.bucket_sq_norms_small_batch.append(buf.pow(2).sum())
fut = torch.distributed.all_reduce(buf, op=torch.distributed.ReduceOp.AVG, async_op=True).get_future()
def callback(fut):
buf = fut.value()[0]
self.bucket_sq_norms_large_batch.append(buf.pow(2).sum())
return buf
return fut.then(callback)
def get_stats(self):
sq_norm_small_batch = sum(self.bucket_sq_norms_small_batch)
sq_norm_large_batch = sum(self.bucket_sq_norms_large_batch)
self._clear_state()
stats = torch.stack([sq_norm_small_batch, sq_norm_large_batch])
torch.distributed.all_reduce(stats, op=torch.distributed.ReduceOp.AVG)
return stats[0].item(), stats[1].item()
class GradientNoiseScale:
"""Calculates the gradient noise scale (1 / SNR), or critical batch size,
from _An Empirical Model of Large-Batch Training_,
https://arxiv.org/abs/1812.06162).
Args:
beta (float): The decay factor for the exponential moving averages used to
calculate the gradient noise scale.
Default: 0.9998
eps (float): Added for numerical stability.
Default: 1e-8
"""
def __init__(self, beta=0.9998, eps=1e-8):
self.beta = beta
self.eps = eps
self.ema_sq_norm = 0.
self.ema_var = 0.
self.beta_cumprod = 1.
self.gradient_noise_scale = float('nan')
def state_dict(self):
"""Returns the state of the object as a :class:`dict`."""
return dict(self.__dict__.items())
def load_state_dict(self, state_dict):
"""Loads the object's state.
Args:
state_dict (dict): object state. Should be an object returned
from a call to :meth:`state_dict`.
"""
self.__dict__.update(state_dict)
def update(self, sq_norm_small_batch, sq_norm_large_batch, n_small_batch, n_large_batch):
"""Updates the state with a new batch's gradient statistics, and returns the
current gradient noise scale.
Args:
sq_norm_small_batch (float): The mean of the squared 2-norms of microbatch or
per sample gradients.
sq_norm_large_batch (float): The squared 2-norm of the mean of the microbatch or
per sample gradients.
n_small_batch (int): The batch size of the individual microbatch or per sample
gradients (1 if per sample).
n_large_batch (int): The total batch size of the mean of the microbatch or
per sample gradients.
"""
est_sq_norm = (n_large_batch * sq_norm_large_batch - n_small_batch * sq_norm_small_batch) / (n_large_batch - n_small_batch)
est_var = (sq_norm_small_batch - sq_norm_large_batch) / (1 / n_small_batch - 1 / n_large_batch)
self.ema_sq_norm = self.beta * self.ema_sq_norm + (1 - self.beta) * est_sq_norm
self.ema_var = self.beta * self.ema_var + (1 - self.beta) * est_var
self.beta_cumprod *= self.beta
self.gradient_noise_scale = max(self.ema_var, self.eps) / max(self.ema_sq_norm, self.eps)
return self.gradient_noise_scale
def get_gns(self):
"""Returns the current gradient noise scale."""
return self.gradient_noise_scale
def get_stats(self):
"""Returns the current (debiased) estimates of the squared mean gradient
and gradient variance."""
return self.ema_sq_norm / (1 - self.beta_cumprod), self.ema_var / (1 - self.beta_cumprod)
import math
from einops import rearrange, repeat
import torch
from torch import nn
from torch.nn import functional as F
from . import utils
# Karras et al. preconditioned denoiser
class Denoiser(nn.Module):
"""A Karras et al. preconditioner for denoising diffusion models."""
def __init__(self, inner_model, sigma_data=1.):
super().__init__()
self.inner_model = inner_model
self.sigma_data = sigma_data
def get_scalings(self, sigma):
c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
return c_skip, c_out, c_in
def loss(self, input, noise, sigma, **kwargs):
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
model_output = self.inner_model(noised_input * c_in, sigma, **kwargs)
target = (input - c_skip * noised_input) / c_out
return (model_output - target).pow(2).flatten(1).mean(1)
def forward(self, input, sigma, **kwargs):
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
return self.inner_model(input * c_in, sigma, **kwargs) * c_out + input * c_skip
class DenoiserWithVariance(Denoiser):
def loss(self, input, noise, sigma, **kwargs):
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
model_output, logvar = self.inner_model(noised_input * c_in, sigma, return_variance=True, **kwargs)
logvar = utils.append_dims(logvar, model_output.ndim)
target = (input - c_skip * noised_input) / c_out
losses = ((model_output - target) ** 2 / logvar.exp() + logvar) / 2
return losses.flatten(1).mean(1)
# Residual blocks
class ResidualBlock(nn.Module):
def __init__(self, *main, skip=None):
super().__init__()
self.main = nn.Sequential(*main)
self.skip = skip if skip else nn.Identity()
def forward(self, input):
return self.main(input) + self.skip(input)
# Noise level (and other) conditioning
class ConditionedModule(nn.Module):
pass
class UnconditionedModule(ConditionedModule):
def __init__(self, module):
super().__init__()
self.module = module
def forward(self, input, cond=None):
return self.module(input)
class ConditionedSequential(nn.Sequential, ConditionedModule):
def forward(self, input, cond):
for module in self:
if isinstance(module, ConditionedModule):
input = module(input, cond)
else:
input = module(input)
return input
class ConditionedResidualBlock(ConditionedModule):
def __init__(self, *main, skip=None):
super().__init__()
self.main = ConditionedSequential(*main)
self.skip = skip if skip else nn.Identity()
def forward(self, input, cond):
skip = self.skip(input, cond) if isinstance(self.skip, ConditionedModule) else self.skip(input)
return self.main(input, cond) + skip
class AdaGN(ConditionedModule):
def __init__(self, feats_in, c_out, num_groups, eps=1e-5, cond_key='cond'):
super().__init__()
self.num_groups = num_groups
self.eps = eps
self.cond_key = cond_key
self.mapper = nn.Linear(feats_in, c_out * 2)
def forward(self, input, cond):
weight, bias = self.mapper(cond[self.cond_key]).chunk(2, dim=-1)
input = F.group_norm(input, self.num_groups, eps=self.eps)
return torch.addcmul(utils.append_dims(bias, input.ndim), input, utils.append_dims(weight, input.ndim) + 1)
# Attention
class SelfAttention2d(ConditionedModule):
def __init__(self, c_in, n_head, norm, dropout_rate=0.):
super().__init__()
assert c_in % n_head == 0
self.norm_in = norm(c_in)
self.n_head = n_head
self.qkv_proj = nn.Conv2d(c_in, c_in * 3, 1)
self.out_proj = nn.Conv2d(c_in, c_in, 1)
self.dropout = nn.Dropout(dropout_rate)
def forward(self, input, cond):
n, c, h, w = input.shape
qkv = self.qkv_proj(self.norm_in(input, cond))
qkv = qkv.view([n, self.n_head * 3, c // self.n_head, h * w]).transpose(2, 3)
q, k, v = qkv.chunk(3, dim=1)
scale = k.shape[3] ** -0.25
att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
att = self.dropout(att)
y = (att @ v).transpose(2, 3).contiguous().view([n, c, h, w])
return input + self.out_proj(y)
class CrossAttention2d(ConditionedModule):
def __init__(self, c_dec, c_enc, n_head, norm_dec, dropout_rate=0.,
cond_key='cross', cond_key_padding='cross_padding'):
super().__init__()
assert c_dec % n_head == 0
self.cond_key = cond_key
self.cond_key_padding = cond_key_padding
self.norm_enc = nn.LayerNorm(c_enc)
self.norm_dec = norm_dec(c_dec)
self.n_head = n_head
self.q_proj = nn.Conv2d(c_dec, c_dec, 1)
self.kv_proj = nn.Linear(c_enc, c_dec * 2)
self.out_proj = nn.Conv2d(c_dec, c_dec, 1)
self.dropout = nn.Dropout(dropout_rate)
def forward(self, input, cond):
n, c, h, w = input.shape
q = self.q_proj(self.norm_dec(input, cond))
q = q.view([n, self.n_head, c // self.n_head, h * w]).transpose(2, 3)
kv = self.kv_proj(self.norm_enc(cond[self.cond_key]))
kv = kv.view([n, -1, self.n_head * 2, c // self.n_head]).transpose(1, 2)
k, v = kv.chunk(2, dim=1)
scale = k.shape[3] ** -0.25
att = ((q * scale) @ (k.transpose(2, 3) * scale))
att = att - (cond[self.cond_key_padding][:, None, None, :]) * 10000
att = att.softmax(3)
att = self.dropout(att)
y = (att @ v).transpose(2, 3)
y = y.contiguous().view([n, c, h, w])
return input + self.out_proj(y)
# Downsampling/upsampling
_kernels = {
'linear':
[1 / 8, 3 / 8, 3 / 8, 1 / 8],
'cubic':
[-0.01171875, -0.03515625, 0.11328125, 0.43359375,
0.43359375, 0.11328125, -0.03515625, -0.01171875],
'lanczos3':
[0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
-0.066637322306633, 0.13550527393817902, 0.44638532400131226,
0.44638532400131226, 0.13550527393817902, -0.066637322306633,
-0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
}
_kernels['bilinear'] = _kernels['linear']
_kernels['bicubic'] = _kernels['cubic']
class Downsample2d(nn.Module):
def __init__(self, kernel='linear', pad_mode='reflect'):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor([_kernels[kernel]])
self.pad = kernel_1d.shape[1] // 2 - 1
self.register_buffer('kernel', kernel_1d.T @ kernel_1d)
def forward(self, x):
x = F.pad(x, (self.pad,) * 4, self.pad_mode)
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
indices = torch.arange(x.shape[1], device=x.device)
weight[indices, indices] = self.kernel.to(weight)
return F.conv2d(x, weight, stride=2)
class Upsample2d(nn.Module):
def __init__(self, kernel='linear', pad_mode='reflect'):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor([_kernels[kernel]]) * 2
self.pad = kernel_1d.shape[1] // 2 - 1
self.register_buffer('kernel', kernel_1d.T @ kernel_1d)
def forward(self, x):
x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode)
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
indices = torch.arange(x.shape[1], device=x.device)
weight[indices, indices] = self.kernel.to(weight)
return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1)
# Embeddings
class FourierFeatures(nn.Module):
def __init__(self, in_features, out_features, std=1.):
super().__init__()
assert out_features % 2 == 0
self.register_buffer('weight', torch.randn([out_features // 2, in_features]) * std)
def forward(self, input):
f = 2 * math.pi * input @ self.weight.T
return torch.cat([f.cos(), f.sin()], dim=-1)
# U-Nets
class UNet(ConditionedModule):
def __init__(self, d_blocks, u_blocks, skip_stages=0):
super().__init__()
self.d_blocks = nn.ModuleList(d_blocks)
self.u_blocks = nn.ModuleList(u_blocks)
self.skip_stages = skip_stages
def forward(self, input, cond):
skips = []
for block in self.d_blocks[self.skip_stages:]:
input = block(input, cond)
skips.append(input)
for i, (block, skip) in enumerate(zip(self.u_blocks, reversed(skips))):
input = block(input, cond, skip if i > 0 else None)
return input
from .image_v1 import ImageDenoiserModelV1
import math
import torch
from torch import nn
from torch.nn import functional as F
from .. import layers, utils
def orthogonal_(module):
nn.init.orthogonal_(module.weight)
return module
class ResConvBlock(layers.ConditionedResidualBlock):
def __init__(self, feats_in, c_in, c_mid, c_out, group_size=32, dropout_rate=0.):
skip = None if c_in == c_out else orthogonal_(nn.Conv2d(c_in, c_out, 1, bias=False))
super().__init__(
layers.AdaGN(feats_in, c_in, max(1, c_in // group_size)),
nn.GELU(),
nn.Conv2d(c_in, c_mid, 3, padding=1),
nn.Dropout2d(dropout_rate, inplace=True),
layers.AdaGN(feats_in, c_mid, max(1, c_mid // group_size)),
nn.GELU(),
nn.Conv2d(c_mid, c_out, 3, padding=1),
nn.Dropout2d(dropout_rate, inplace=True),
skip=skip)
class DBlock(layers.ConditionedSequential):
def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., downsample=False, self_attn=False, cross_attn=False, c_enc=0):
modules = [nn.Identity()]
for i in range(n_layers):
my_c_in = c_in if i == 0 else c_mid
my_c_out = c_mid if i < n_layers - 1 else c_out
modules.append(ResConvBlock(feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate))
if self_attn:
norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
modules.append(layers.SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate))
if cross_attn:
norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
modules.append(layers.CrossAttention2d(my_c_out, c_enc, max(1, my_c_out // head_size), norm, dropout_rate))
super().__init__(*modules)
self.set_downsample(downsample)
def set_downsample(self, downsample):
self[0] = layers.Downsample2d() if downsample else nn.Identity()
return self
class UBlock(layers.ConditionedSequential):
def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., upsample=False, self_attn=False, cross_attn=False, c_enc=0):
modules = []
for i in range(n_layers):
my_c_in = c_in if i == 0 else c_mid
my_c_out = c_mid if i < n_layers - 1 else c_out
modules.append(ResConvBlock(feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate))
if self_attn:
norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
modules.append(layers.SelfAttention2d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate))
if cross_attn:
norm = lambda c_in: layers.AdaGN(feats_in, c_in, max(1, my_c_out // group_size))
modules.append(layers.CrossAttention2d(my_c_out, c_enc, max(1, my_c_out // head_size), norm, dropout_rate))
modules.append(nn.Identity())
super().__init__(*modules)
self.set_upsample(upsample)
def forward(self, input, cond, skip=None):
if skip is not None:
input = torch.cat([input, skip], dim=1)
return super().forward(input, cond)
def set_upsample(self, upsample):
self[-1] = layers.Upsample2d() if upsample else nn.Identity()
return self
class MappingNet(nn.Sequential):
def __init__(self, feats_in, feats_out, n_layers=2):
layers = []
for i in range(n_layers):
layers.append(orthogonal_(nn.Linear(feats_in if i == 0 else feats_out, feats_out)))
layers.append(nn.GELU())
super().__init__(*layers)
class ImageDenoiserModelV1(nn.Module):
def __init__(self, c_in, feats_in, depths, channels, self_attn_depths, cross_attn_depths=None, mapping_cond_dim=0, unet_cond_dim=0, cross_cond_dim=0, dropout_rate=0., patch_size=1, skip_stages=0, has_variance=False):
super().__init__()
self.c_in = c_in
self.channels = channels
self.unet_cond_dim = unet_cond_dim
self.patch_size = patch_size
self.has_variance = has_variance
self.timestep_embed = layers.FourierFeatures(1, feats_in)
if mapping_cond_dim > 0:
self.mapping_cond = nn.Linear(mapping_cond_dim, feats_in, bias=False)
self.mapping = MappingNet(feats_in, feats_in)
self.proj_in = nn.Conv2d((c_in + unet_cond_dim) * self.patch_size ** 2, channels[max(0, skip_stages - 1)], 1)
self.proj_out = nn.Conv2d(channels[max(0, skip_stages - 1)], c_in * self.patch_size ** 2 + (1 if self.has_variance else 0), 1)
nn.init.zeros_(self.proj_out.weight)
nn.init.zeros_(self.proj_out.bias)
if cross_cond_dim == 0:
cross_attn_depths = [False] * len(self_attn_depths)
d_blocks, u_blocks = [], []
for i in range(len(depths)):
my_c_in = channels[max(0, i - 1)]
d_blocks.append(DBlock(depths[i], feats_in, my_c_in, channels[i], channels[i], downsample=i > skip_stages, self_attn=self_attn_depths[i], cross_attn=cross_attn_depths[i], c_enc=cross_cond_dim, dropout_rate=dropout_rate))
for i in range(len(depths)):
my_c_in = channels[i] * 2 if i < len(depths) - 1 else channels[i]
my_c_out = channels[max(0, i - 1)]
u_blocks.append(UBlock(depths[i], feats_in, my_c_in, channels[i], my_c_out, upsample=i > skip_stages, self_attn=self_attn_depths[i], cross_attn=cross_attn_depths[i], c_enc=cross_cond_dim, dropout_rate=dropout_rate))
self.u_net = layers.UNet(d_blocks, reversed(u_blocks), skip_stages=skip_stages)
def forward(self, input, sigma, mapping_cond=None, unet_cond=None, cross_cond=None, cross_cond_padding=None, return_variance=False):
c_noise = sigma.log() / 4
timestep_embed = self.timestep_embed(utils.append_dims(c_noise, 2))
mapping_cond_embed = torch.zeros_like(timestep_embed) if mapping_cond is None else self.mapping_cond(mapping_cond)
mapping_out = self.mapping(timestep_embed + mapping_cond_embed)
cond = {'cond': mapping_out}
if unet_cond is not None:
input = torch.cat([input, unet_cond], dim=1)
if cross_cond is not None:
cond['cross'] = cross_cond
cond['cross_padding'] = cross_cond_padding
if self.patch_size > 1:
input = F.pixel_unshuffle(input, self.patch_size)
input = self.proj_in(input)
input = self.u_net(input, cond)
input = self.proj_out(input)
if self.has_variance:
input, logvar = input[:, :-1], input[:, -1].flatten(1).mean(1)
if self.patch_size > 1:
input = F.pixel_shuffle(input, self.patch_size)
if self.has_variance and return_variance:
return input, logvar
return input
def set_skip_stages(self, skip_stages):
self.proj_in = nn.Conv2d(self.proj_in.in_channels, self.channels[max(0, skip_stages - 1)], 1)
self.proj_out = nn.Conv2d(self.channels[max(0, skip_stages - 1)], self.proj_out.out_channels, 1)
nn.init.zeros_(self.proj_out.weight)
nn.init.zeros_(self.proj_out.bias)
self.u_net.skip_stages = skip_stages
for i, block in enumerate(self.u_net.d_blocks):
block.set_downsample(i > skip_stages)
for i, block in enumerate(reversed(self.u_net.u_blocks)):
block.set_upsample(i > skip_stages)
return self
def set_patch_size(self, patch_size):
self.patch_size = patch_size
self.proj_in = nn.Conv2d((self.c_in + self.unet_cond_dim) * self.patch_size ** 2, self.channels[max(0, self.u_net.skip_stages - 1)], 1)
self.proj_out = nn.Conv2d(self.channels[max(0, self.u_net.skip_stages - 1)], self.c_in * self.patch_size ** 2 + (1 if self.has_variance else 0), 1)
nn.init.zeros_(self.proj_out.weight)
nn.init.zeros_(self.proj_out.bias)
import math
from scipy import integrate
import torch
from torch import nn
from torchdiffeq import odeint
import torchsde
from tqdm.auto import trange, tqdm
from . import utils
def append_zero(x):
return torch.cat([x, x.new_zeros([1])])
def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
"""Constructs the noise schedule of Karras et al. (2022)."""
ramp = torch.linspace(0, 1, n, device=device)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return append_zero(sigmas).to(device)
def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'):
"""Constructs an exponential noise schedule."""
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp()
return append_zero(sigmas)
def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
"""Constructs an polynomial in log sigma noise schedule."""
ramp = torch.linspace(1, 0, n, device=device) ** rho
sigmas = torch.exp(ramp * (math.log(sigma_max) - math.log(sigma_min)) + math.log(sigma_min))
return append_zero(sigmas)
def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
"""Constructs a continuous VP noise schedule."""
t = torch.linspace(1, eps_s, n, device=device)
sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1)
return append_zero(sigmas)
def to_d(x, sigma, denoised):
"""Converts a denoiser output to a Karras ODE derivative."""
return (x - denoised) / utils.append_dims(sigma, x.ndim)
def get_ancestral_step(sigma_from, sigma_to, eta=1.):
"""Calculates the noise level (sigma_down) to step down to and the amount
of noise to add (sigma_up) when doing an ancestral sampling step."""
if not eta:
return sigma_to, 0.
sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5)
sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
return sigma_down, sigma_up
def default_noise_sampler(x):
return lambda sigma, sigma_next: torch.randn_like(x)
class BatchedBrownianTree:
"""A wrapper around torchsde.BrownianTree that enables batches of entropy."""
def __init__(self, x, t0, t1, seed=None, **kwargs):
t0, t1, self.sign = self.sort(t0, t1)
w0 = kwargs.get('w0', torch.zeros_like(x))
if seed is None:
seed = torch.randint(0, 2 ** 63 - 1, []).item()
self.batched = True
try:
assert len(seed) == x.shape[0]
w0 = w0[0]
except TypeError:
seed = [seed]
self.batched = False
self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
@staticmethod
def sort(a, b):
return (a, b, 1) if a < b else (b, a, -1)
def __call__(self, t0, t1):
t0, t1, sign = self.sort(t0, t1)
w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
return w if self.batched else w[0]
class BrownianTreeNoiseSampler:
"""A noise sampler backed by a torchsde.BrownianTree.
Args:
x (Tensor): The tensor whose shape, device and dtype to use to generate
random samples.
sigma_min (float): The low end of the valid interval.
sigma_max (float): The high end of the valid interval.
seed (int or List[int]): The random seed. If a list of seeds is
supplied instead of a single integer, then the noise sampler will
use one BrownianTree per batch item, each with its own seed.
transform (callable): A function that maps sigma to the sampler's
internal timestep.
"""
def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
self.transform = transform
t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
self.tree = BatchedBrownianTree(x, t0, t1, seed)
def __call__(self, sigma, sigma_next):
t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
@torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
eps = torch.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
d = to_d(x, sigma_hat, denoised)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
dt = sigmas[i + 1] - sigma_hat
# Euler method
x = x + d * dt
return x
@torch.no_grad()
def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with Euler method steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
d = to_d(x, sigmas[i], denoised)
# Euler method
dt = sigma_down - sigmas[i]
x = x + d * dt
if sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
@torch.no_grad()
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
"""Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
eps = torch.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
d = to_d(x, sigma_hat, denoised)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
dt = sigmas[i + 1] - sigma_hat
if sigmas[i + 1] == 0:
# Euler method
x = x + d * dt
else:
# Heun's method
x_2 = x + d * dt
denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
d_prime = (d + d_2) / 2
x = x + d_prime * dt
return x
@torch.no_grad()
def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
"""A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
eps = torch.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
d = to_d(x, sigma_hat, denoised)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
if sigmas[i + 1] == 0:
# Euler method
dt = sigmas[i + 1] - sigma_hat
x = x + d * dt
else:
# DPM-Solver-2
sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
dt_1 = sigma_mid - sigma_hat
dt_2 = sigmas[i + 1] - sigma_hat
x_2 = x + d * dt_1
denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
d_2 = to_d(x_2, sigma_mid, denoised_2)
x = x + d_2 * dt_2
return x
@torch.no_grad()
def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
d = to_d(x, sigmas[i], denoised)
if sigma_down == 0:
# Euler method
dt = sigma_down - sigmas[i]
x = x + d * dt
else:
# DPM-Solver-2
sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp()
dt_1 = sigma_mid - sigmas[i]
dt_2 = sigma_down - sigmas[i]
x_2 = x + d * dt_1
denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
d_2 = to_d(x_2, sigma_mid, denoised_2)
x = x + d_2 * dt_2
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
def linear_multistep_coeff(order, t, i, j):
if order - 1 > i:
raise ValueError(f'Order {order} too high for step {i}')
def fn(tau):
prod = 1.
for k in range(order):
if j == k:
continue
prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
return prod
return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0]
@torch.no_grad()
def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
sigmas_cpu = sigmas.detach().cpu().numpy()
ds = []
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
d = to_d(x, sigmas[i], denoised)
ds.append(d)
if len(ds) > order:
ds.pop(0)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
cur_order = min(i + 1, order)
coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
return x
@torch.no_grad()
def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4):
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
v = torch.randint_like(x, 2) * 2 - 1
fevals = 0
def ode_fn(sigma, x):
nonlocal fevals
with torch.enable_grad():
x = x[0].detach().requires_grad_()
denoised = model(x, sigma * s_in, **extra_args)
d = to_d(x, sigma, denoised)
fevals += 1
grad = torch.autograd.grad((d * v).sum(), x)[0]
d_ll = (v * grad).flatten(1).sum(1)
return d.detach(), d_ll
x_min = x, x.new_zeros([x.shape[0]])
t = x.new_tensor([sigma_min, sigma_max])
sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5')
latent, delta_ll = sol[0][-1], sol[1][-1]
ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1)
return ll_prior + delta_ll, {'fevals': fevals}
class PIDStepSizeController:
"""A PID controller for ODE adaptive step size control."""
def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):
self.h = h
self.b1 = (pcoeff + icoeff + dcoeff) / order
self.b2 = -(pcoeff + 2 * dcoeff) / order
self.b3 = dcoeff / order
self.accept_safety = accept_safety
self.eps = eps
self.errs = []
def limiter(self, x):
return 1 + math.atan(x - 1)
def propose_step(self, error):
inv_error = 1 / (float(error) + self.eps)
if not self.errs:
self.errs = [inv_error, inv_error, inv_error]
self.errs[0] = inv_error
factor = self.errs[0] ** self.b1 * self.errs[1] ** self.b2 * self.errs[2] ** self.b3
factor = self.limiter(factor)
accept = factor >= self.accept_safety
if accept:
self.errs[2] = self.errs[1]
self.errs[1] = self.errs[0]
self.h *= factor
return accept
class DPMSolver(nn.Module):
"""DPM-Solver. See https://arxiv.org/abs/2206.00927."""
def __init__(self, model, extra_args=None, eps_callback=None, info_callback=None):
super().__init__()
self.model = model
self.extra_args = {} if extra_args is None else extra_args
self.eps_callback = eps_callback
self.info_callback = info_callback
def t(self, sigma):
return -sigma.log()
def sigma(self, t):
return t.neg().exp()
def eps(self, eps_cache, key, x, t, *args, **kwargs):
if key in eps_cache:
return eps_cache[key], eps_cache
sigma = self.sigma(t) * x.new_ones([x.shape[0]])
eps = (x - self.model(x, sigma, *args, **self.extra_args, **kwargs)) / self.sigma(t)
if self.eps_callback is not None:
self.eps_callback()
return eps, {key: eps, **eps_cache}
def dpm_solver_1_step(self, x, t, t_next, eps_cache=None):
eps_cache = {} if eps_cache is None else eps_cache
h = t_next - t
eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
x_1 = x - self.sigma(t_next) * h.expm1() * eps
return x_1, eps_cache
def dpm_solver_2_step(self, x, t, t_next, r1=1 / 2, eps_cache=None):
eps_cache = {} if eps_cache is None else eps_cache
h = t_next - t
eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
s1 = t + r1 * h
u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
x_2 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / (2 * r1) * h.expm1() * (eps_r1 - eps)
return x_2, eps_cache
def dpm_solver_3_step(self, x, t, t_next, r1=1 / 3, r2=2 / 3, eps_cache=None):
eps_cache = {} if eps_cache is None else eps_cache
h = t_next - t
eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
s1 = t + r1 * h
s2 = t + r2 * h
u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
u2 = x - self.sigma(s2) * (r2 * h).expm1() * eps - self.sigma(s2) * (r2 / r1) * ((r2 * h).expm1() / (r2 * h) - 1) * (eps_r1 - eps)
eps_r2, eps_cache = self.eps(eps_cache, 'eps_r2', u2, s2)
x_3 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / r2 * (h.expm1() / h - 1) * (eps_r2 - eps)
return x_3, eps_cache
def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None):
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
if not t_end > t_start and eta:
raise ValueError('eta must be 0 for reverse sampling')
m = math.floor(nfe / 3) + 1
ts = torch.linspace(t_start, t_end, m + 1, device=x.device)
if nfe % 3 == 0:
orders = [3] * (m - 2) + [2, 1]
else:
orders = [3] * (m - 1) + [nfe % 3]
for i in range(len(orders)):
eps_cache = {}
t, t_next = ts[i], ts[i + 1]
if eta:
sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta)
t_next_ = torch.minimum(t_end, self.t(sd))
su = (self.sigma(t_next) ** 2 - self.sigma(t_next_) ** 2) ** 0.5
else:
t_next_, su = t_next, 0.
eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
denoised = x - self.sigma(t) * eps
if self.info_callback is not None:
self.info_callback({'x': x, 'i': i, 't': ts[i], 't_up': t, 'denoised': denoised})
if orders[i] == 1:
x, eps_cache = self.dpm_solver_1_step(x, t, t_next_, eps_cache=eps_cache)
elif orders[i] == 2:
x, eps_cache = self.dpm_solver_2_step(x, t, t_next_, eps_cache=eps_cache)
else:
x, eps_cache = self.dpm_solver_3_step(x, t, t_next_, eps_cache=eps_cache)
x = x + su * s_noise * noise_sampler(self.sigma(t), self.sigma(t_next))
return x
def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None):
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
if order not in {2, 3}:
raise ValueError('order should be 2 or 3')
forward = t_end > t_start
if not forward and eta:
raise ValueError('eta must be 0 for reverse sampling')
h_init = abs(h_init) * (1 if forward else -1)
atol = torch.tensor(atol)
rtol = torch.tensor(rtol)
s = t_start
x_prev = x
accept = True
pid = PIDStepSizeController(h_init, pcoeff, icoeff, dcoeff, 1.5 if eta else order, accept_safety)
info = {'steps': 0, 'nfe': 0, 'n_accept': 0, 'n_reject': 0}
while s < t_end - 1e-5 if forward else s > t_end + 1e-5:
eps_cache = {}
t = torch.minimum(t_end, s + pid.h) if forward else torch.maximum(t_end, s + pid.h)
if eta:
sd, su = get_ancestral_step(self.sigma(s), self.sigma(t), eta)
t_ = torch.minimum(t_end, self.t(sd))
su = (self.sigma(t) ** 2 - self.sigma(t_) ** 2) ** 0.5
else:
t_, su = t, 0.
eps, eps_cache = self.eps(eps_cache, 'eps', x, s)
denoised = x - self.sigma(s) * eps
if order == 2:
x_low, eps_cache = self.dpm_solver_1_step(x, s, t_, eps_cache=eps_cache)
x_high, eps_cache = self.dpm_solver_2_step(x, s, t_, eps_cache=eps_cache)
else:
x_low, eps_cache = self.dpm_solver_2_step(x, s, t_, r1=1 / 3, eps_cache=eps_cache)
x_high, eps_cache = self.dpm_solver_3_step(x, s, t_, eps_cache=eps_cache)
delta = torch.maximum(atol, rtol * torch.maximum(x_low.abs(), x_prev.abs()))
error = torch.linalg.norm((x_low - x_high) / delta) / x.numel() ** 0.5
accept = pid.propose_step(error)
if accept:
x_prev = x_low
x = x_high + su * s_noise * noise_sampler(self.sigma(s), self.sigma(t))
s = t
info['n_accept'] += 1
else:
info['n_reject'] += 1
info['nfe'] += order
info['steps'] += 1
if self.info_callback is not None:
self.info_callback({'x': x, 'i': info['steps'] - 1, 't': s, 't_up': s, 'denoised': denoised, 'error': error, 'h': pid.h, **info})
return x, info
@torch.no_grad()
def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1., noise_sampler=None):
"""DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927."""
if sigma_min <= 0 or sigma_max <= 0:
raise ValueError('sigma_min and sigma_max must not be 0')
with tqdm(total=n, disable=disable) as pbar:
dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
if callback is not None:
dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
return dpm_solver.dpm_solver_fast(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), n, eta, s_noise, noise_sampler)
@torch.no_grad()
def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None, return_info=False):
"""DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927."""
if sigma_min <= 0 or sigma_max <= 0:
raise ValueError('sigma_min and sigma_max must not be 0')
with tqdm(disable=disable) as pbar:
dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
if callback is not None:
dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
x, info = dpm_solver.dpm_solver_adaptive(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise, noise_sampler)
if return_info:
return x, info
return x
@torch.no_grad()
def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigma_down == 0:
# Euler method
d = to_d(x, sigmas[i], denoised)
dt = sigma_down - sigmas[i]
x = x + d * dt
else:
# DPM-Solver++(2S)
t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
r = 1 / 2
h = t_next - t
s = t + r * h
x_2 = (sigma_fn(s) / sigma_fn(t)) * x - (-h * r).expm1() * denoised
denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_2
# Noise addition
if sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
return x
@torch.no_grad()
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
"""DPM-Solver++ (stochastic)."""
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
# Euler method
d = to_d(x, sigmas[i], denoised)
dt = sigmas[i + 1] - sigmas[i]
x = x + d * dt
else:
# DPM-Solver++
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
h = t_next - t
s = t + h * r
fac = 1 / (2 * r)
# Step 1
sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
s_ = t_fn(sd)
x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
# Step 2
sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
t_next_ = t_fn(sd)
denoised_d = (1 - fac) * denoised + fac * denoised_2
x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
return x
@torch.no_grad()
def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
"""DPM-Solver++(2M)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
old_denoised = None
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
h = t_next - t
if old_denoised is None or sigmas[i + 1] == 0:
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
else:
h_last = t - t_fn(sigmas[i - 1])
r = h_last / h
denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
old_denoised = denoised
return x
from contextlib import contextmanager
import hashlib
import math
from pathlib import Path
import shutil
import urllib
import warnings
from PIL import Image
import torch
from torch import nn, optim
from torch.utils import data
from torchvision.transforms import functional as TF
def from_pil_image(x):
"""Converts from a PIL image to a tensor."""
x = TF.to_tensor(x)
if x.ndim == 2:
x = x[..., None]
return x * 2 - 1
def to_pil_image(x):
"""Converts from a tensor to a PIL image."""
if x.ndim == 4:
assert x.shape[0] == 1
x = x[0]
if x.shape[0] == 1:
x = x[0]
return TF.to_pil_image((x.clamp(-1, 1) + 1) / 2)
def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'):
"""Apply passed in transforms for HuggingFace Datasets."""
images = [transform(image.convert(mode)) for image in examples[image_key]]
return {image_key: images}
def append_dims(x, target_dims):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
expanded = x[(...,) + (None,) * dims_to_append]
# MPS will get inf values if it tries to index into the new axes, but detaching fixes this.
# https://github.com/pytorch/pytorch/issues/84364
return expanded.detach().clone() if expanded.device.type == 'mps' else expanded
def n_params(module):
"""Returns the number of trainable parameters in a module."""
return sum(p.numel() for p in module.parameters())
def download_file(path, url, digest=None):
"""Downloads a file if it does not exist, optionally checking its SHA-256 hash."""
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
if not path.exists():
with urllib.request.urlopen(url) as response, open(path, 'wb') as f:
shutil.copyfileobj(response, f)
if digest is not None:
file_digest = hashlib.sha256(open(path, 'rb').read()).hexdigest()
if digest != file_digest:
raise OSError(f'hash of {path} (url: {url}) failed to validate')
return path
@contextmanager
def train_mode(model, mode=True):
"""A context manager that places a model into training mode and restores
the previous mode on exit."""
modes = [module.training for module in model.modules()]
try:
yield model.train(mode)
finally:
for i, module in enumerate(model.modules()):
module.training = modes[i]
def eval_mode(model):
"""A context manager that places a model into evaluation mode and restores
the previous mode on exit."""
return train_mode(model, False)
@torch.no_grad()
def ema_update(model, averaged_model, decay):
"""Incorporates updated model parameters into an exponential moving averaged
version of a model. It should be called after each optimizer step."""
model_params = dict(model.named_parameters())
averaged_params = dict(averaged_model.named_parameters())
assert model_params.keys() == averaged_params.keys()
for name, param in model_params.items():
averaged_params[name].mul_(decay).add_(param, alpha=1 - decay)
model_buffers = dict(model.named_buffers())
averaged_buffers = dict(averaged_model.named_buffers())
assert model_buffers.keys() == averaged_buffers.keys()
for name, buf in model_buffers.items():
averaged_buffers[name].copy_(buf)
class EMAWarmup:
"""Implements an EMA warmup using an inverse decay schedule.
If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are
good values for models you plan to train for a million or more steps (reaches decay
factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models
you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
215.4k steps).
Args:
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
power (float): Exponential factor of EMA warmup. Default: 1.
min_value (float): The minimum EMA decay rate. Default: 0.
max_value (float): The maximum EMA decay rate. Default: 1.
start_at (int): The epoch to start averaging at. Default: 0.
last_epoch (int): The index of last epoch. Default: 0.
"""
def __init__(self, inv_gamma=1., power=1., min_value=0., max_value=1., start_at=0,
last_epoch=0):
self.inv_gamma = inv_gamma
self.power = power
self.min_value = min_value
self.max_value = max_value
self.start_at = start_at
self.last_epoch = last_epoch
def state_dict(self):
"""Returns the state of the class as a :class:`dict`."""
return dict(self.__dict__.items())
def load_state_dict(self, state_dict):
"""Loads the class's state.
Args:
state_dict (dict): scaler state. Should be an object returned
from a call to :meth:`state_dict`.
"""
self.__dict__.update(state_dict)
def get_value(self):
"""Gets the current EMA decay rate."""
epoch = max(0, self.last_epoch - self.start_at)
value = 1 - (1 + epoch / self.inv_gamma) ** -self.power
return 0. if epoch < 0 else min(self.max_value, max(self.min_value, value))
def step(self):
"""Updates the step count."""
self.last_epoch += 1
class InverseLR(optim.lr_scheduler._LRScheduler):
"""Implements an inverse decay learning rate schedule with an optional exponential
warmup. When last_epoch=-1, sets initial lr as lr.
inv_gamma is the number of steps/epochs required for the learning rate to decay to
(1 / 2)**power of its original value.
Args:
optimizer (Optimizer): Wrapped optimizer.
inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1.
power (float): Exponential factor of learning rate decay. Default: 1.
warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
Default: 0.
min_lr (float): The minimum learning rate. Default: 0.
last_epoch (int): The index of last epoch. Default: -1.
verbose (bool): If ``True``, prints a message to stdout for
each update. Default: ``False``.
"""
def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., min_lr=0.,
last_epoch=-1, verbose=False):
self.inv_gamma = inv_gamma
self.power = power
if not 0. <= warmup < 1:
raise ValueError('Invalid value for warmup')
self.warmup = warmup
self.min_lr = min_lr
super().__init__(optimizer, last_epoch, verbose)
def get_lr(self):
if not self._get_lr_called_within_step:
warnings.warn("To get the last learning rate computed by the scheduler, "
"please use `get_last_lr()`.")
return self._get_closed_form_lr()
def _get_closed_form_lr(self):
warmup = 1 - self.warmup ** (self.last_epoch + 1)
lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power
return [warmup * max(self.min_lr, base_lr * lr_mult)
for base_lr in self.base_lrs]
class ExponentialLR(optim.lr_scheduler._LRScheduler):
"""Implements an exponential learning rate schedule with an optional exponential
warmup. When last_epoch=-1, sets initial lr as lr. Decays the learning rate
continuously by decay (default 0.5) every num_steps steps.
Args:
optimizer (Optimizer): Wrapped optimizer.
num_steps (float): The number of steps to decay the learning rate by decay in.
decay (float): The factor by which to decay the learning rate every num_steps
steps. Default: 0.5.
warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
Default: 0.
min_lr (float): The minimum learning rate. Default: 0.
last_epoch (int): The index of last epoch. Default: -1.
verbose (bool): If ``True``, prints a message to stdout for
each update. Default: ``False``.
"""
def __init__(self, optimizer, num_steps, decay=0.5, warmup=0., min_lr=0.,
last_epoch=-1, verbose=False):
self.num_steps = num_steps
self.decay = decay
if not 0. <= warmup < 1:
raise ValueError('Invalid value for warmup')
self.warmup = warmup
self.min_lr = min_lr
super().__init__(optimizer, last_epoch, verbose)
def get_lr(self):
if not self._get_lr_called_within_step:
warnings.warn("To get the last learning rate computed by the scheduler, "
"please use `get_last_lr()`.")
return self._get_closed_form_lr()
def _get_closed_form_lr(self):
warmup = 1 - self.warmup ** (self.last_epoch + 1)
lr_mult = (self.decay ** (1 / self.num_steps)) ** self.last_epoch
return [warmup * max(self.min_lr, base_lr * lr_mult)
for base_lr in self.base_lrs]
def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
"""Draws samples from an lognormal distribution."""
return (torch.randn(shape, device=device, dtype=dtype) * scale + loc).exp()
def rand_log_logistic(shape, loc=0., scale=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
"""Draws samples from an optionally truncated log-logistic distribution."""
min_value = torch.as_tensor(min_value, device=device, dtype=torch.float64)
max_value = torch.as_tensor(max_value, device=device, dtype=torch.float64)
min_cdf = min_value.log().sub(loc).div(scale).sigmoid()
max_cdf = max_value.log().sub(loc).div(scale).sigmoid()
u = torch.rand(shape, device=device, dtype=torch.float64) * (max_cdf - min_cdf) + min_cdf
return u.logit().mul(scale).add(loc).exp().to(dtype)
def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32):
"""Draws samples from an log-uniform distribution."""
min_value = math.log(min_value)
max_value = math.log(max_value)
return (torch.rand(shape, device=device, dtype=dtype) * (max_value - min_value) + min_value).exp()
def rand_v_diffusion(shape, sigma_data=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
"""Draws samples from a truncated v-diffusion training timestep distribution."""
min_cdf = math.atan(min_value / sigma_data) * 2 / math.pi
max_cdf = math.atan(max_value / sigma_data) * 2 / math.pi
u = torch.rand(shape, device=device, dtype=dtype) * (max_cdf - min_cdf) + min_cdf
return torch.tan(u * math.pi / 2) * sigma_data
def rand_split_log_normal(shape, loc, scale_1, scale_2, device='cpu', dtype=torch.float32):
"""Draws samples from a split lognormal distribution."""
n = torch.randn(shape, device=device, dtype=dtype).abs()
u = torch.rand(shape, device=device, dtype=dtype)
n_left = n * -scale_1 + loc
n_right = n * scale_2 + loc
ratio = scale_1 / (scale_1 + scale_2)
return torch.where(u < ratio, n_left, n_right).exp()
class FolderOfImages(data.Dataset):
"""Recursively finds all images in a directory. It does not support
classes/targets."""
IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'}
def __init__(self, root, transform=None):
super().__init__()
self.root = Path(root)
self.transform = nn.Identity() if transform is None else transform
self.paths = sorted(path for path in self.root.rglob('*') if path.suffix.lower() in self.IMG_EXTENSIONS)
def __repr__(self):
return f'FolderOfImages(root="{self.root}", len: {len(self)})'
def __len__(self):
return len(self.paths)
def __getitem__(self, key):
path = self.paths[key]
with open(path, 'rb') as f:
image = Image.open(f).convert('RGB')
image = self.transform(image)
return image,
class CSVLogger:
def __init__(self, filename, columns):
self.filename = Path(filename)
self.columns = columns
if self.filename.exists():
self.file = open(self.filename, 'a')
else:
self.file = open(self.filename, 'w')
self.write(*self.columns)
def write(self, *args):
print(*args, sep=',', file=self.file, flush=True)
@contextmanager
def tf32_mode(cudnn=None, matmul=None):
"""A context manager that sets whether TF32 is allowed on cuDNN or matmul."""
cudnn_old = torch.backends.cudnn.allow_tf32
matmul_old = torch.backends.cuda.matmul.allow_tf32
try:
if cudnn is not None:
torch.backends.cudnn.allow_tf32 = cudnn
if matmul is not None:
torch.backends.cuda.matmul.allow_tf32 = matmul
yield
finally:
if cudnn is not None:
torch.backends.cudnn.allow_tf32 = cudnn_old
if matmul is not None:
torch.backends.cuda.matmul.allow_tf32 = matmul_old
import torch
from ldm.modules.midas.api import load_midas_transform
class AddMiDaS(object):
def __init__(self, model_type):
super().__init__()
self.transform = load_midas_transform(model_type)
def pt2np(self, x):
x = ((x + 1.0) * .5).detach().cpu().numpy()
return x
def np2pt(self, x):
x = torch.from_numpy(x) * 2 - 1.
return x
def __call__(self, sample):
# sample['jpg'] is tensor hwc in [-1, 1] at this point
x = self.pt2np(sample['jpg'])
x = self.transform({"image": x})["image"]
sample['midas_in'] = x
return sample
\ No newline at end of file
import torch
# import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager
from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from ldm.util import instantiate_from_config
from ldm.modules.ema import LitEma
# class AutoencoderKL(pl.LightningModule):
class AutoencoderKL(torch.nn.Module):
def __init__(self,
ddconfig,
lossconfig,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
ema_decay=None,
learn_logvar=False
):
super().__init__()
self.learn_logvar = learn_logvar
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
assert ddconfig["double_z"]
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
self.use_ema = ema_decay is not None
if self.use_ema:
self.ema_decay = ema_decay
assert 0. < ema_decay < 1.
self.model_ema = LitEma(self, decay=ema_decay)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def init_from_ckpt(self, path, ignore_keys=list()):
if path.lower().endswith(".safetensors"):
import safetensors.torch
sd = safetensors.torch.load_file(path, device="cpu")
else:
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
self.load_state_dict(sd, strict=False)
print(f"Restored from {path}")
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
print(f"{context}: Restored training weights")
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.model_ema(self)
def encode(self, x):
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, z):
z = self.post_quant_conv(z)
dec = self.decoder(z)
return dec
def forward(self, input, sample_posterior=True):
posterior = self.encode(input)
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
dec = self.decode(z)
return dec, posterior
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
return x
def training_step(self, batch, batch_idx, optimizer_idx):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
if optimizer_idx == 0:
# train encoder+decoder+logvar
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
return aeloss
if optimizer_idx == 1:
# train the discriminator
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
return discloss
def validation_step(self, batch, batch_idx):
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
return log_dict
def _validation_step(self, batch, batch_idx, postfix=""):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
last_layer=self.get_last_layer(), split="val"+postfix)
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
last_layer=self.get_last_layer(), split="val"+postfix)
self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def configure_optimizers(self):
lr = self.learning_rate
ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
if self.learn_logvar:
print(f"{self.__class__.__name__}: Learning logvar")
ae_params_list.append(self.loss.logvar)
opt_ae = torch.optim.Adam(ae_params_list,
lr=lr, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
lr=lr, betas=(0.5, 0.9))
return [opt_ae, opt_disc], []
def get_last_layer(self):
return self.decoder.conv_out.weight
@torch.no_grad()
def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
log = dict()
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if not only_inputs:
xrec, posterior = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
log["reconstructions"] = xrec
if log_ema or self.use_ema:
with self.ema_scope():
xrec_ema, posterior_ema = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec_ema.shape[1] > 3
xrec_ema = self.to_rgb(xrec_ema)
log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
log["reconstructions_ema"] = xrec_ema
log["inputs"] = x
return log
def to_rgb(self, x):
assert self.image_key == "segmentation"
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
return x
class IdentityFirstStage(torch.nn.Module):
def __init__(self, *args, vq_interface=False, **kwargs):
self.vq_interface = vq_interface
super().__init__()
def encode(self, x, *args, **kwargs):
return x
def decode(self, x, *args, **kwargs):
return x
def quantize(self, x, *args, **kwargs):
if self.vq_interface:
return x, None, [None, None, None]
return x
def forward(self, x, *args, **kwargs):
return x
"""SAMPLING ONLY."""
import torch
import numpy as np
from tqdm import tqdm
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
class DDIMSampler(object):
def __init__(self, model, schedule="linear", **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,verbose=verbose)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
dynamic_threshold=None,
ucg_schedule=None,
**kwargs
):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list): ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
elif isinstance(conditioning, list):
for ctmp in conditioning:
if ctmp.shape[0] != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
samples, intermediates = self.ddim_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
ucg_schedule=ucg_schedule
)
return samples, intermediates
@torch.no_grad()
def ddim_sampling(self, cond, shape,
x_T=None, ddim_use_original_steps=False,
callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
ucg_schedule=None):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
if ucg_schedule is not None:
assert len(ucg_schedule) == len(time_range)
unconditional_guidance_scale = ucg_schedule[i]
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, temperature=temperature,
noise_dropout=noise_dropout, score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold)
img, pred_x0 = outs
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
return img, intermediates
@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,
dynamic_threshold=None):
b, *_, device = *x.shape, x.device
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
model_output = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
if isinstance(c, dict):
assert isinstance(unconditional_conditioning, dict)
c_in = dict()
for k in c:
if isinstance(c[k], list):
c_in[k] = [torch.cat([
unconditional_conditioning[k][i],
c[k][i]]) for i in range(len(c[k]))]
else:
c_in[k] = torch.cat([
unconditional_conditioning[k],
c[k]])
elif isinstance(c, list):
c_in = list()
assert isinstance(unconditional_conditioning, list)
for i in range(len(c)):
c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
else:
c_in = torch.cat([unconditional_conditioning, c])
model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
if self.model.parameterization == "v":
e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
else:
e_t = model_output
if score_corrector is not None:
assert self.model.parameterization == "eps", 'not implemented'
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
# current prediction for x_0
if self.model.parameterization != "v":
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
else:
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
if dynamic_threshold is not None:
raise NotImplementedError()
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
@torch.no_grad()
def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
assert t_enc <= num_reference_steps
num_steps = t_enc
if use_original_steps:
alphas_next = self.alphas_cumprod[:num_steps]
alphas = self.alphas_cumprod_prev[:num_steps]
else:
alphas_next = self.ddim_alphas[:num_steps]
alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
x_next = x0
intermediates = []
inter_steps = []
for i in tqdm(range(num_steps), desc='Encoding Image'):
t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
if unconditional_guidance_scale == 1.:
noise_pred = self.model.apply_model(x_next, t, c)
else:
assert unconditional_conditioning is not None
e_t_uncond, noise_pred = torch.chunk(
self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
torch.cat((unconditional_conditioning, c))), 2)
noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
weighted_noise_pred = alphas_next[i].sqrt() * (
(1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
x_next = xt_weighted + weighted_noise_pred
if return_intermediates and i % (
num_steps // return_intermediates) == 0 and i < num_steps - 1:
intermediates.append(x_next)
inter_steps.append(i)
elif return_intermediates and i >= num_steps - 2:
intermediates.append(x_next)
inter_steps.append(i)
if callback: callback(i)
out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
if return_intermediates:
out.update({'intermediates': intermediates})
return x_next, out
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
# fast, but does not allow for exact reconstruction
# t serves as an index to gather the correct alphas
if use_original_steps:
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
else:
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
if noise is None:
noise = torch.randn_like(x0)
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
@torch.no_grad()
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
use_original_steps=False, callback=None):
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
timesteps = timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
x_dec = x_latent
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
if callback: callback(i)
return x_dec
\ No newline at end of file
"""
wild mixture of
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
https://github.com/CompVis/taming-transformers
-- merci
"""
import torch
import torch.nn as nn
import numpy as np
# import pytorch_lightning as pl
from torch.optim.lr_scheduler import LambdaLR
from einops import rearrange, repeat
from contextlib import contextmanager, nullcontext
from functools import partial
import itertools
from tqdm import tqdm
from torchvision.utils import make_grid
# from pytorch_lightning.utilities.distributed import rank_zero_only
from omegaconf import ListConfig
from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
from ldm.modules.ema import LitEma
from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
from ldm.models.diffusion.ddim import DDIMSampler
__conditioning_keys__ = {'concat': 'c_concat',
'crossattn': 'c_crossattn',
'adm': 'y'}
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
def uniform_on_device(r1, r2, shape, device):
return (r1 - r2) * torch.rand(*shape, device=device) + r2
# class DDPM(pl.LightningModule):
class DDPM(torch.nn.Module):
# classic DDPM with Gaussian diffusion, in image space
def __init__(self,
unet_config,
timesteps=1000,
beta_schedule="linear",
loss_type="l2",
ckpt_path=None,
ignore_keys=[],
load_only_unet=False,
monitor="val/loss",
use_ema=True,
first_stage_key="image",
image_size=256,
channels=3,
log_every_t=100,
clip_denoised=True,
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3,
given_betas=None,
original_elbo_weight=0.,
v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
l_simple_weight=1.,
conditioning_key=None,
parameterization="eps", # all assuming fixed variance schedules
scheduler_config=None,
use_positional_encodings=False,
learn_logvar=False,
logvar_init=0.,
make_it_fit=False,
ucg_training=None,
reset_ema=False,
reset_num_ema_updates=False,
):
super().__init__()
assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"'
self.parameterization = parameterization
print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
self.cond_stage_model = None
self.clip_denoised = clip_denoised
self.log_every_t = log_every_t
self.first_stage_key = first_stage_key
self.image_size = image_size # try conv?
self.channels = channels
self.use_positional_encodings = use_positional_encodings
self.model = DiffusionWrapper(unet_config, conditioning_key)
count_params(self.model, verbose=True)
self.use_ema = use_ema
if self.use_ema:
self.model_ema = LitEma(self.model)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
self.use_scheduler = scheduler_config is not None
if self.use_scheduler:
self.scheduler_config = scheduler_config
self.v_posterior = v_posterior
self.original_elbo_weight = original_elbo_weight
self.l_simple_weight = l_simple_weight
if monitor is not None:
self.monitor = monitor
self.make_it_fit = make_it_fit
if reset_ema: assert exists(ckpt_path)
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
if reset_ema:
assert self.use_ema
print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
self.model_ema = LitEma(self.model)
if reset_num_ema_updates:
print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
assert self.use_ema
self.model_ema.reset_num_updates()
self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
self.loss_type = loss_type
self.learn_logvar = learn_logvar
self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
if self.learn_logvar:
self.logvar = nn.Parameter(self.logvar, requires_grad=True)
self.ucg_training = ucg_training or dict()
if self.ucg_training:
self.ucg_prng = np.random.RandomState()
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if exists(given_betas):
betas = given_betas
else:
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
cosine_s=cosine_s)
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
to_torch = partial(torch.tensor, dtype=torch.float32)
self.register_buffer('betas', to_torch(betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
1. - alphas_cumprod) + self.v_posterior * betas
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.register_buffer('posterior_variance', to_torch(posterior_variance))
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
self.register_buffer('posterior_mean_coef1', to_torch(
betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
self.register_buffer('posterior_mean_coef2', to_torch(
(1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
if self.parameterization == "eps":
lvlb_weights = self.betas ** 2 / (
2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
elif self.parameterization == "x0":
lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
elif self.parameterization == "v":
lvlb_weights = torch.ones_like(self.betas ** 2 / (
2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)))
else:
raise NotImplementedError("mu not supported")
lvlb_weights[0] = lvlb_weights[1]
self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
assert not torch.isnan(self.lvlb_weights).all()
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.model.parameters())
self.model_ema.copy_to(self.model)
if context is not None:
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.model.parameters())
if context is not None:
print(f"{context}: Restored training weights")
@torch.no_grad()
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
sd = torch.load(path, map_location="cpu")
if "state_dict" in list(sd.keys()):
sd = sd["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
if self.make_it_fit:
n_params = len([name for name, _ in
itertools.chain(self.named_parameters(),
self.named_buffers())])
for name, param in tqdm(
itertools.chain(self.named_parameters(),
self.named_buffers()),
desc="Fitting old weights to new weights",
total=n_params
):
if not name in sd:
continue
old_shape = sd[name].shape
new_shape = param.shape
assert len(old_shape) == len(new_shape)
if len(new_shape) > 2:
# we only modify first two axes
assert new_shape[2:] == old_shape[2:]
# assumes first axis corresponds to output dim
if not new_shape == old_shape:
new_param = param.clone()
old_param = sd[name]
if len(new_shape) == 1:
for i in range(new_param.shape[0]):
new_param[i] = old_param[i % old_shape[0]]
elif len(new_shape) >= 2:
for i in range(new_param.shape[0]):
for j in range(new_param.shape[1]):
new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]]
n_used_old = torch.ones(old_shape[1])
for j in range(new_param.shape[1]):
n_used_old[j % old_shape[1]] += 1
n_used_new = torch.zeros(new_shape[1])
for j in range(new_param.shape[1]):
n_used_new[j] = n_used_old[j % old_shape[1]]
n_used_new = n_used_new[None, :]
while len(n_used_new.shape) < len(new_shape):
n_used_new = n_used_new.unsqueeze(-1)
new_param /= n_used_new
sd[name] = new_param
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0:
print(f"Missing Keys:\n {missing}")
if len(unexpected) > 0:
print(f"\nUnexpected Keys:\n {unexpected}")
def q_mean_variance(self, x_start, t):
"""
Get the distribution q(x_t | x_0).
:param x_start: the [N x C x ...] tensor of noiseless inputs.
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
"""
mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
return mean, variance, log_variance
def predict_start_from_noise(self, x_t, t, noise):
return (
extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
)
def predict_start_from_z_and_v(self, x_t, t, v):
# self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
# self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
return (
extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
)
def predict_eps_from_z_and_v(self, x_t, t, v):
return (
extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v +
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t
)
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, x, t, clip_denoised: bool):
model_out = self.model(x, t)
if self.parameterization == "eps":
x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
elif self.parameterization == "x0":
x_recon = model_out
if clip_denoised:
x_recon.clamp_(-1., 1.)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, shape, return_intermediates=False):
device = self.betas.device
b = shape[0]
img = torch.randn(shape, device=device)
intermediates = [img]
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
clip_denoised=self.clip_denoised)
if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
intermediates.append(img)
if return_intermediates:
return img, intermediates
return img
@torch.no_grad()
def sample(self, batch_size=16, return_intermediates=False):
image_size = self.image_size
channels = self.channels
return self.p_sample_loop((batch_size, channels, image_size, image_size),
return_intermediates=return_intermediates)
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
def get_v(self, x, noise, t):
return (
extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
)
def get_loss(self, pred, target, mean=True):
if self.loss_type == 'l1':
loss = (target - pred).abs()
if mean:
loss = loss.mean()
elif self.loss_type == 'l2':
if mean:
loss = torch.nn.functional.mse_loss(target, pred)
else:
loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
else:
raise NotImplementedError("unknown loss type '{loss_type}'")
return loss
def p_losses(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
model_out = self.model(x_noisy, t)
loss_dict = {}
if self.parameterization == "eps":
target = noise
elif self.parameterization == "x0":
target = x_start
elif self.parameterization == "v":
target = self.get_v(x_start, noise, t)
else:
raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
log_prefix = 'train' if self.training else 'val'
loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
loss_simple = loss.mean() * self.l_simple_weight
loss_vlb = (self.lvlb_weights[t] * loss).mean()
loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
loss = loss_simple + self.original_elbo_weight * loss_vlb
loss_dict.update({f'{log_prefix}/loss': loss})
return loss, loss_dict
def forward(self, x, *args, **kwargs):
# b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
# assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
return self.p_losses(x, t, *args, **kwargs)
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = rearrange(x, 'b h w c -> b c h w')
x = x.to(memory_format=torch.contiguous_format).float()
return x
def shared_step(self, batch):
x = self.get_input(batch, self.first_stage_key)
loss, loss_dict = self(x)
return loss, loss_dict
def training_step(self, batch, batch_idx):
for k in self.ucg_training:
p = self.ucg_training[k]["p"]
val = self.ucg_training[k]["val"]
if val is None:
val = ""
for i in range(len(batch[k])):
if self.ucg_prng.choice(2, p=[1 - p, p]):
batch[k][i] = val
loss, loss_dict = self.shared_step(batch)
self.log_dict(loss_dict, prog_bar=True,
logger=True, on_step=True, on_epoch=True)
self.log("global_step", self.global_step,
prog_bar=True, logger=True, on_step=True, on_epoch=False)
if self.use_scheduler:
lr = self.optimizers().param_groups[0]['lr']
self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
return loss
@torch.no_grad()
def validation_step(self, batch, batch_idx):
_, loss_dict_no_ema = self.shared_step(batch)
with self.ema_scope():
_, loss_dict_ema = self.shared_step(batch)
loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.model_ema(self.model)
def _get_rows_from_list(self, samples):
n_imgs_per_row = len(samples)
denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
return denoise_grid
@torch.no_grad()
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
log = dict()
x = self.get_input(batch, self.first_stage_key)
N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row)
x = x.to(self.device)[:N]
log["inputs"] = x
# get diffusion row
diffusion_row = list()
x_start = x[:n_row]
for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
t = t.to(self.device).long()
noise = torch.randn_like(x_start)
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
diffusion_row.append(x_noisy)
log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
if sample:
# get denoise row
with self.ema_scope("Plotting"):
samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
log["samples"] = samples
log["denoise_row"] = self._get_rows_from_list(denoise_row)
if return_keys:
if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
return log
else:
return {key: log[key] for key in return_keys}
return log
def configure_optimizers(self):
lr = self.learning_rate
params = list(self.model.parameters())
if self.learn_logvar:
params = params + [self.logvar]
opt = torch.optim.AdamW(params, lr=lr)
return opt
class LatentDiffusion(DDPM):
"""main class"""
def __init__(self,
first_stage_config,
cond_stage_config,
num_timesteps_cond=None,
cond_stage_key="image",
cond_stage_trainable=False,
concat_mode=True,
cond_stage_forward=None,
conditioning_key=None,
scale_factor=1.0,
scale_by_std=False,
force_null_conditioning=False,
*args, **kwargs):
self.force_null_conditioning = force_null_conditioning
self.num_timesteps_cond = default(num_timesteps_cond, 1)
self.scale_by_std = scale_by_std
assert self.num_timesteps_cond <= kwargs['timesteps']
# for backwards compatibility after implementation of DiffusionWrapper
if conditioning_key is None:
conditioning_key = 'concat' if concat_mode else 'crossattn'
if cond_stage_config == '__is_unconditional__' and not self.force_null_conditioning:
conditioning_key = None
ckpt_path = kwargs.pop("ckpt_path", None)
reset_ema = kwargs.pop("reset_ema", False)
reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False)
ignore_keys = kwargs.pop("ignore_keys", [])
super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
self.concat_mode = concat_mode
self.cond_stage_trainable = cond_stage_trainable
self.cond_stage_key = cond_stage_key
try:
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
except:
self.num_downs = 0
if not scale_by_std:
self.scale_factor = scale_factor
else:
self.register_buffer('scale_factor', torch.tensor(scale_factor))
# self.instantiate_first_stage(first_stage_config)
# self.instantiate_cond_stage(cond_stage_config)
self.first_stage_config = first_stage_config
self.cond_stage_config = cond_stage_config
self.cond_stage_forward = cond_stage_forward
self.clip_denoised = False
self.bbox_tokenizer = None
self.restarted_from_ckpt = False
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys)
self.restarted_from_ckpt = True
if reset_ema:
assert self.use_ema
print(
f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
self.model_ema = LitEma(self.model)
if reset_num_ema_updates:
print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
assert self.use_ema
self.model_ema.reset_num_updates()
def make_cond_schedule(self, ):
self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
self.cond_ids[:self.num_timesteps_cond] = ids
# @rank_zero_only
@torch.no_grad()
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
# only for very first batch
if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
# set rescale weight to 1./std of encodings
print("### USING STD-RESCALING ###")
x = super().get_input(batch, self.first_stage_key)
x = x.to(self.device)
encoder_posterior = self.encode_first_stage(x)
z = self.get_first_stage_encoding(encoder_posterior).detach()
del self.scale_factor
self.register_buffer('scale_factor', 1. / z.flatten().std())
print(f"setting self.scale_factor to {self.scale_factor}")
print("### USING STD-RESCALING ###")
def register_schedule(self,
given_betas=None, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
self.shorten_cond_schedule = self.num_timesteps_cond > 1
if self.shorten_cond_schedule:
self.make_cond_schedule()
def instantiate_first_stage(self, config):
model = instantiate_from_config(config)
self.first_stage_model = model.eval()
self.first_stage_model.train = disabled_train
for param in self.first_stage_model.parameters():
param.requires_grad = False
def instantiate_cond_stage(self, config):
if not self.cond_stage_trainable:
if config == "__is_first_stage__":
print("Using first stage also as cond stage.")
self.cond_stage_model = self.first_stage_model
elif config == "__is_unconditional__":
print(f"Training {self.__class__.__name__} as an unconditional model.")
self.cond_stage_model = None
# self.be_unconditional = True
else:
model = instantiate_from_config(config)
self.cond_stage_model = model.eval()
self.cond_stage_model.train = disabled_train
for param in self.cond_stage_model.parameters():
param.requires_grad = False
else:
assert config != '__is_first_stage__'
assert config != '__is_unconditional__'
model = instantiate_from_config(config)
self.cond_stage_model = model
def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
denoise_row = []
for zd in tqdm(samples, desc=desc):
denoise_row.append(self.decode_first_stage(zd.to(self.device),
force_not_quantize=force_no_decoder_quantization))
n_imgs_per_row = len(denoise_row)
denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
return denoise_grid
def get_first_stage_encoding(self, encoder_posterior):
if isinstance(encoder_posterior, DiagonalGaussianDistribution):
z = encoder_posterior.sample()
elif isinstance(encoder_posterior, torch.Tensor):
z = encoder_posterior
else:
raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
return self.scale_factor * z
def get_learned_conditioning(self, c):
if self.cond_stage_forward is None:
if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
c = self.cond_stage_model.encode(c)
if isinstance(c, DiagonalGaussianDistribution):
c = c.mode()
else:
c = self.cond_stage_model(c)
else:
assert hasattr(self.cond_stage_model, self.cond_stage_forward)
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
return c
def meshgrid(self, h, w):
y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
arr = torch.cat([y, x], dim=-1)
return arr
def delta_border(self, h, w):
"""
:param h: height
:param w: width
:return: normalized distance to image border,
wtith min distance = 0 at border and max dist = 0.5 at image center
"""
lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
arr = self.meshgrid(h, w) / lower_right_corner
dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
return edge_dist
def get_weighting(self, h, w, Ly, Lx, device):
weighting = self.delta_border(h, w)
weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
self.split_input_params["clip_max_weight"], )
weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
if self.split_input_params["tie_braker"]:
L_weighting = self.delta_border(Ly, Lx)
L_weighting = torch.clip(L_weighting,
self.split_input_params["clip_min_tie_weight"],
self.split_input_params["clip_max_tie_weight"])
L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
weighting = weighting * L_weighting
return weighting
def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
"""
:param x: img of size (bs, c, h, w)
:return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
"""
bs, nc, h, w = x.shape
# number of crops in image
Ly = (h - kernel_size[0]) // stride[0] + 1
Lx = (w - kernel_size[1]) // stride[1] + 1
if uf == 1 and df == 1:
fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
unfold = torch.nn.Unfold(**fold_params)
fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
elif uf > 1 and df == 1:
fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
unfold = torch.nn.Unfold(**fold_params)
fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
dilation=1, padding=0,
stride=(stride[0] * uf, stride[1] * uf))
fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
elif df > 1 and uf == 1:
fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
unfold = torch.nn.Unfold(**fold_params)
fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
dilation=1, padding=0,
stride=(stride[0] // df, stride[1] // df))
fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
else:
raise NotImplementedError
return fold, unfold, normalization, weighting
@torch.no_grad()
def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
cond_key=None, return_original_cond=False, bs=None, return_x=False):
x = super().get_input(batch, k)
if bs is not None:
x = x[:bs]
x = x.to(self.device)
encoder_posterior = self.encode_first_stage(x)
z = self.get_first_stage_encoding(encoder_posterior).detach()
if self.model.conditioning_key is not None and not self.force_null_conditioning:
if cond_key is None:
cond_key = self.cond_stage_key
if cond_key != self.first_stage_key:
if cond_key in ['caption', 'coordinates_bbox', "txt"]:
xc = batch[cond_key]
elif cond_key in ['class_label', 'cls']:
xc = batch
else:
xc = super().get_input(batch, cond_key).to(self.device)
else:
xc = x
if not self.cond_stage_trainable or force_c_encode:
if isinstance(xc, dict) or isinstance(xc, list):
c = self.get_learned_conditioning(xc)
else:
c = self.get_learned_conditioning(xc.to(self.device))
else:
c = xc
if bs is not None:
c = c[:bs]
if self.use_positional_encodings:
pos_x, pos_y = self.compute_latent_shifts(batch)
ckey = __conditioning_keys__[self.model.conditioning_key]
c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
else:
c = None
xc = None
if self.use_positional_encodings:
pos_x, pos_y = self.compute_latent_shifts(batch)
c = {'pos_x': pos_x, 'pos_y': pos_y}
out = [z, c]
if return_first_stage_outputs:
xrec = self.decode_first_stage(z)
out.extend([x, xrec])
if return_x:
out.extend([x])
if return_original_cond:
out.append(xc)
return out
@torch.no_grad()
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
if predict_cids:
if z.dim() == 4:
z = torch.argmax(z.exp(), dim=1).long()
z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
z = rearrange(z, 'b h w c -> b c h w').contiguous()
z = 1. / self.scale_factor * z
return self.first_stage_model.decode(z)
@torch.no_grad()
def encode_first_stage(self, x):
return self.first_stage_model.encode(x)
def shared_step(self, batch, **kwargs):
x, c = self.get_input(batch, self.first_stage_key)
loss = self(x, c)
return loss
def forward(self, x, c, *args, **kwargs):
t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
if self.model.conditioning_key is not None:
assert c is not None
if self.cond_stage_trainable:
c = self.get_learned_conditioning(c)
if self.shorten_cond_schedule: # TODO: drop this option
tc = self.cond_ids[t].to(self.device)
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
return self.p_losses(x, c, t, *args, **kwargs)
def apply_model(self, x_noisy, t, cond, return_ids=False):
if isinstance(cond, dict):
# hybrid case, cond is expected to be a dict
pass
else:
if not isinstance(cond, list):
cond = [cond]
key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
cond = {key: cond}
x_recon = self.model(x_noisy, t, **cond)
if isinstance(x_recon, tuple) and not return_ids:
return x_recon[0]
else:
return x_recon
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
def _prior_bpd(self, x_start):
"""
Get the prior KL term for the variational lower-bound, measured in
bits-per-dim.
This term can't be optimized, as it only depends on the encoder.
:param x_start: the [N x C x ...] tensor of inputs.
:return: a batch of [N] KL values (in bits), one per batch element.
"""
batch_size = x_start.shape[0]
t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
return mean_flat(kl_prior) / np.log(2.0)
def p_losses(self, x_start, cond, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
model_output = self.apply_model(x_noisy, t, cond)
loss_dict = {}
prefix = 'train' if self.training else 'val'
if self.parameterization == "x0":
target = x_start
elif self.parameterization == "eps":
target = noise
elif self.parameterization == "v":
target = self.get_v(x_start, noise, t)
else:
raise NotImplementedError()
loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
logvar_t = self.logvar[t].to(self.device)
loss = loss_simple / torch.exp(logvar_t) + logvar_t
# loss = loss_simple / torch.exp(self.logvar) + self.logvar
if self.learn_logvar:
loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
loss_dict.update({'logvar': self.logvar.data.mean()})
loss = self.l_simple_weight * loss.mean()
loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
loss += (self.original_elbo_weight * loss_vlb)
loss_dict.update({f'{prefix}/loss': loss})
return loss, loss_dict
def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
return_x0=False, score_corrector=None, corrector_kwargs=None):
t_in = t
model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
if score_corrector is not None:
assert self.parameterization == "eps"
model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
if return_codebook_ids:
model_out, logits = model_out
if self.parameterization == "eps":
x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
elif self.parameterization == "x0":
x_recon = model_out
else:
raise NotImplementedError()
if clip_denoised:
x_recon.clamp_(-1., 1.)
if quantize_denoised:
x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
if return_codebook_ids:
return model_mean, posterior_variance, posterior_log_variance, logits
elif return_x0:
return model_mean, posterior_variance, posterior_log_variance, x_recon
else:
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
return_codebook_ids=False, quantize_denoised=False, return_x0=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
b, *_, device = *x.shape, x.device
outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
return_codebook_ids=return_codebook_ids,
quantize_denoised=quantize_denoised,
return_x0=return_x0,
score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
if return_codebook_ids:
raise DeprecationWarning("Support dropped.")
model_mean, _, model_log_variance, logits = outputs
elif return_x0:
model_mean, _, model_log_variance, x0 = outputs
else:
model_mean, _, model_log_variance = outputs
noise = noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
if return_codebook_ids:
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
if return_x0:
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
else:
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
log_every_t=None):
if not log_every_t:
log_every_t = self.log_every_t
timesteps = self.num_timesteps
if batch_size is not None:
b = batch_size if batch_size is not None else shape[0]
shape = [batch_size] + list(shape)
else:
b = batch_size = shape[0]
if x_T is None:
img = torch.randn(shape, device=self.device)
else:
img = x_T
intermediates = []
if cond is not None:
if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
if start_T is not None:
timesteps = min(timesteps, start_T)
iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
total=timesteps) if verbose else reversed(
range(0, timesteps))
if type(temperature) == float:
temperature = [temperature] * timesteps
for i in iterator:
ts = torch.full((b,), i, device=self.device, dtype=torch.long)
if self.shorten_cond_schedule:
assert self.model.conditioning_key != 'hybrid'
tc = self.cond_ids[ts].to(cond.device)
cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
img, x0_partial = self.p_sample(img, cond, ts,
clip_denoised=self.clip_denoised,
quantize_denoised=quantize_denoised, return_x0=True,
temperature=temperature[i], noise_dropout=noise_dropout,
score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
if mask is not None:
assert x0 is not None
img_orig = self.q_sample(x0, ts)
img = img_orig * mask + (1. - mask) * img
if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(x0_partial)
if callback: callback(i)
if img_callback: img_callback(img, i)
return img, intermediates
@torch.no_grad()
def p_sample_loop(self, cond, shape, return_intermediates=False,
x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, start_T=None,
log_every_t=None):
if not log_every_t:
log_every_t = self.log_every_t
device = self.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
intermediates = [img]
if timesteps is None:
timesteps = self.num_timesteps
if start_T is not None:
timesteps = min(timesteps, start_T)
iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
range(0, timesteps))
if mask is not None:
assert x0 is not None
assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
for i in iterator:
ts = torch.full((b,), i, device=device, dtype=torch.long)
if self.shorten_cond_schedule:
assert self.model.conditioning_key != 'hybrid'
tc = self.cond_ids[ts].to(cond.device)
cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
img = self.p_sample(img, cond, ts,
clip_denoised=self.clip_denoised,
quantize_denoised=quantize_denoised)
if mask is not None:
img_orig = self.q_sample(x0, ts)
img = img_orig * mask + (1. - mask) * img
if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(img)
if callback: callback(i)
if img_callback: img_callback(img, i)
if return_intermediates:
return img, intermediates
return img
@torch.no_grad()
def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
verbose=True, timesteps=None, quantize_denoised=False,
mask=None, x0=None, shape=None, **kwargs):
if shape is None:
shape = (batch_size, self.channels, self.image_size, self.image_size)
if cond is not None:
if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
return self.p_sample_loop(cond,
shape,
return_intermediates=return_intermediates, x_T=x_T,
verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
mask=mask, x0=x0)
@torch.no_grad()
def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
if ddim:
ddim_sampler = DDIMSampler(self)
shape = (self.channels, self.image_size, self.image_size)
samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size,
shape, cond, verbose=False, **kwargs)
else:
samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
return_intermediates=True, **kwargs)
return samples, intermediates
@torch.no_grad()
def get_unconditional_conditioning(self, batch_size, null_label=None):
if null_label is not None:
xc = null_label
if isinstance(xc, ListConfig):
xc = list(xc)
if isinstance(xc, dict) or isinstance(xc, list):
c = self.get_learned_conditioning(xc)
else:
if hasattr(xc, "to"):
xc = xc.to(self.device)
c = self.get_learned_conditioning(xc)
else:
if self.cond_stage_key in ["class_label", "cls"]:
xc = self.cond_stage_model.get_unconditional_conditioning(batch_size, device=self.device)
return self.get_learned_conditioning(xc)
else:
raise NotImplementedError("todo")
if isinstance(c, list): # in case the encoder gives us a list
for i in range(len(c)):
c[i] = repeat(c[i], '1 ... -> b ...', b=batch_size).to(self.device)
else:
c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device)
return c
@torch.no_grad()
def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0., return_keys=None,
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
use_ema_scope=True,
**kwargs):
ema_scope = self.ema_scope if use_ema_scope else nullcontext
use_ddim = ddim_steps is not None
log = dict()
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
return_first_stage_outputs=True,
force_c_encode=True,
return_original_cond=True,
bs=N)
N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row)
log["inputs"] = x
log["reconstruction"] = xrec
if self.model.conditioning_key is not None:
if hasattr(self.cond_stage_model, "decode"):
xc = self.cond_stage_model.decode(c)
log["conditioning"] = xc
elif self.cond_stage_key in ["caption", "txt"]:
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
log["conditioning"] = xc
elif self.cond_stage_key in ['class_label', "cls"]:
try:
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
log['conditioning'] = xc
except KeyError:
# probably no "human_label" in batch
pass
elif isimage(xc):
log["conditioning"] = xc
if ismap(xc):
log["original_conditioning"] = self.to_rgb(xc)
if plot_diffusion_rows:
# get diffusion row
diffusion_row = list()
z_start = z[:n_row]
for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
t = t.to(self.device).long()
noise = torch.randn_like(z_start)
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
diffusion_row.append(self.decode_first_stage(z_noisy))
diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
log["diffusion_row"] = diffusion_grid
if sample:
# get denoise row
with ema_scope("Sampling"):
samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
ddim_steps=ddim_steps, eta=ddim_eta)
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
x_samples = self.decode_first_stage(samples)
log["samples"] = x_samples
if plot_denoise_rows:
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
log["denoise_row"] = denoise_grid
if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
self.first_stage_model, IdentityFirstStage):
# also display when quantizing x0 while sampling
with ema_scope("Plotting Quantized Denoised"):
samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
ddim_steps=ddim_steps, eta=ddim_eta,
quantize_denoised=True)
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
# quantize_denoised=True)
x_samples = self.decode_first_stage(samples.to(self.device))
log["samples_x0_quantized"] = x_samples
if unconditional_guidance_scale > 1.0:
uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
if self.model.conditioning_key == "crossattn-adm":
uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
with ema_scope("Sampling with classifier-free guidance"):
samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
ddim_steps=ddim_steps, eta=ddim_eta,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=uc,
)
x_samples_cfg = self.decode_first_stage(samples_cfg)
log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
if inpaint:
# make a simple center square
b, h, w = z.shape[0], z.shape[2], z.shape[3]
mask = torch.ones(N, h, w).to(self.device)
# zeros will be filled in
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
mask = mask[:, None, ...]
with ema_scope("Plotting Inpaint"):
samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
ddim_steps=ddim_steps, x0=z[:N], mask=mask)
x_samples = self.decode_first_stage(samples.to(self.device))
log["samples_inpainting"] = x_samples
log["mask"] = mask
# outpaint
mask = 1. - mask
with ema_scope("Plotting Outpaint"):
samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
ddim_steps=ddim_steps, x0=z[:N], mask=mask)
x_samples = self.decode_first_stage(samples.to(self.device))
log["samples_outpainting"] = x_samples
if plot_progressive_rows:
with ema_scope("Plotting Progressives"):
img, progressives = self.progressive_denoising(c,
shape=(self.channels, self.image_size, self.image_size),
batch_size=N)
prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
log["progressive_row"] = prog_row
if return_keys:
if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
return log
else:
return {key: log[key] for key in return_keys}
return log
def configure_optimizers(self):
lr = self.learning_rate
params = list(self.model.parameters())
if self.cond_stage_trainable:
print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
params = params + list(self.cond_stage_model.parameters())
if self.learn_logvar:
print('Diffusion model optimizing logvar')
params.append(self.logvar)
opt = torch.optim.AdamW(params, lr=lr)
if self.use_scheduler:
assert 'target' in self.scheduler_config
scheduler = instantiate_from_config(self.scheduler_config)
print("Setting up LambdaLR scheduler...")
scheduler = [
{
'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
'interval': 'step',
'frequency': 1
}]
return [opt], scheduler
return opt
@torch.no_grad()
def to_rgb(self, x):
x = x.float()
if not hasattr(self, "colorize"):
self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
x = nn.functional.conv2d(x, weight=self.colorize)
x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
return x
# class DiffusionWrapper(pl.LightningModule):
class DiffusionWrapper(torch.nn.Module):
def __init__(self, diff_model_config, conditioning_key):
super().__init__()
self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
self.diffusion_model = instantiate_from_config(diff_model_config)
self.conditioning_key = conditioning_key
assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
if self.conditioning_key is None:
out = self.diffusion_model(x, t)
elif self.conditioning_key == 'concat':
xc = torch.cat([x] + c_concat, dim=1)
out = self.diffusion_model(xc, t)
elif self.conditioning_key == 'crossattn':
if not self.sequential_cross_attn:
cc = torch.cat(c_crossattn, 1)
else:
cc = c_crossattn
out = self.diffusion_model(x, t, context=cc)
elif self.conditioning_key == 'hybrid':
xc = torch.cat([x] + c_concat, dim=1)
cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(xc, t, context=cc)
elif self.conditioning_key == 'hybrid-adm':
assert c_adm is not None
xc = torch.cat([x] + c_concat, dim=1)
cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(xc, t, context=cc, y=c_adm)
elif self.conditioning_key == 'crossattn-adm':
assert c_adm is not None
cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(x, t, context=cc, y=c_adm)
elif self.conditioning_key == 'adm':
cc = c_crossattn[0]
out = self.diffusion_model(x, t, y=cc)
else:
raise NotImplementedError()
return out
class LatentUpscaleDiffusion(LatentDiffusion):
def __init__(self, *args, low_scale_config, low_scale_key="LR", noise_level_key=None, **kwargs):
super().__init__(*args, **kwargs)
# assumes that neither the cond_stage nor the low_scale_model contain trainable params
assert not self.cond_stage_trainable
self.instantiate_low_stage(low_scale_config)
self.low_scale_key = low_scale_key
self.noise_level_key = noise_level_key
def instantiate_low_stage(self, config):
model = instantiate_from_config(config)
self.low_scale_model = model.eval()
self.low_scale_model.train = disabled_train
for param in self.low_scale_model.parameters():
param.requires_grad = False
@torch.no_grad()
def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
if not log_mode:
z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
else:
z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
force_c_encode=True, return_original_cond=True, bs=bs)
x_low = batch[self.low_scale_key][:bs]
x_low = rearrange(x_low, 'b h w c -> b c h w')
x_low = x_low.to(memory_format=torch.contiguous_format).float()
zx, noise_level = self.low_scale_model(x_low)
if self.noise_level_key is not None:
# get noise level from batch instead, e.g. when extracting a custom noise level for bsr
raise NotImplementedError('TODO')
all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
if log_mode:
# TODO: maybe disable if too expensive
x_low_rec = self.low_scale_model.decode(zx)
return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level
return z, all_conds
@torch.no_grad()
def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True,
unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True,
**kwargs):
ema_scope = self.ema_scope if use_ema_scope else nullcontext
use_ddim = ddim_steps is not None
log = dict()
z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N,
log_mode=True)
N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row)
log["inputs"] = x
log["reconstruction"] = xrec
log["x_lr"] = x_low
log[f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"] = x_low_rec
if self.model.conditioning_key is not None:
if hasattr(self.cond_stage_model, "decode"):
xc = self.cond_stage_model.decode(c)
log["conditioning"] = xc
elif self.cond_stage_key in ["caption", "txt"]:
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
log["conditioning"] = xc
elif self.cond_stage_key in ['class_label', 'cls']:
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
log['conditioning'] = xc
elif isimage(xc):
log["conditioning"] = xc
if ismap(xc):
log["original_conditioning"] = self.to_rgb(xc)
if plot_diffusion_rows:
# get diffusion row
diffusion_row = list()
z_start = z[:n_row]
for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
t = t.to(self.device).long()
noise = torch.randn_like(z_start)
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
diffusion_row.append(self.decode_first_stage(z_noisy))
diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
log["diffusion_row"] = diffusion_grid
if sample:
# get denoise row
with ema_scope("Sampling"):
samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
ddim_steps=ddim_steps, eta=ddim_eta)
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
x_samples = self.decode_first_stage(samples)
log["samples"] = x_samples
if plot_denoise_rows:
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
log["denoise_row"] = denoise_grid
if unconditional_guidance_scale > 1.0:
uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label)
# TODO explore better "unconditional" choices for the other keys
# maybe guide away from empty text label and highest noise level and maximally degraded zx?
uc = dict()
for k in c:
if k == "c_crossattn":
assert isinstance(c[k], list) and len(c[k]) == 1
uc[k] = [uc_tmp]
elif k == "c_adm": # todo: only run with text-based guidance?
assert isinstance(c[k], torch.Tensor)
#uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
uc[k] = c[k]
elif isinstance(c[k], list):
uc[k] = [c[k][i] for i in range(len(c[k]))]
else:
uc[k] = c[k]
with ema_scope("Sampling with classifier-free guidance"):
samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
ddim_steps=ddim_steps, eta=ddim_eta,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=uc,
)
x_samples_cfg = self.decode_first_stage(samples_cfg)
log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
if plot_progressive_rows:
with ema_scope("Plotting Progressives"):
img, progressives = self.progressive_denoising(c,
shape=(self.channels, self.image_size, self.image_size),
batch_size=N)
prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
log["progressive_row"] = prog_row
return log
class LatentFinetuneDiffusion(LatentDiffusion):
"""
Basis for different finetunas, such as inpainting or depth2image
To disable finetuning mode, set finetune_keys to None
"""
def __init__(self,
concat_keys: tuple,
finetune_keys=("model.diffusion_model.input_blocks.0.0.weight",
"model_ema.diffusion_modelinput_blocks00weight"
),
keep_finetune_dims=4,
# if model was trained without concat mode before and we would like to keep these channels
c_concat_log_start=None, # to log reconstruction of c_concat codes
c_concat_log_end=None,
*args, **kwargs
):
ckpt_path = kwargs.pop("ckpt_path", None)
ignore_keys = kwargs.pop("ignore_keys", list())
super().__init__(*args, **kwargs)
self.finetune_keys = finetune_keys
self.concat_keys = concat_keys
self.keep_dims = keep_finetune_dims
self.c_concat_log_start = c_concat_log_start
self.c_concat_log_end = c_concat_log_end
if exists(self.finetune_keys): assert exists(ckpt_path), 'can only finetune from a given checkpoint'
if exists(ckpt_path):
self.init_from_ckpt(ckpt_path, ignore_keys)
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
sd = torch.load(path, map_location="cpu")
if "state_dict" in list(sd.keys()):
sd = sd["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
# make it explicit, finetune by including extra input channels
if exists(self.finetune_keys) and k in self.finetune_keys:
new_entry = None
for name, param in self.named_parameters():
if name in self.finetune_keys:
print(
f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only")
new_entry = torch.zeros_like(param) # zero init
assert exists(new_entry), 'did not find matching parameter to modify'
new_entry[:, :self.keep_dims, ...] = sd[k]
sd[k] = new_entry
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0:
print(f"Missing Keys: {missing}")
if len(unexpected) > 0:
print(f"Unexpected Keys: {unexpected}")
@torch.no_grad()
def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
use_ema_scope=True,
**kwargs):
ema_scope = self.ema_scope if use_ema_scope else nullcontext
use_ddim = ddim_steps is not None
log = dict()
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True)
c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row)
log["inputs"] = x
log["reconstruction"] = xrec
if self.model.conditioning_key is not None:
if hasattr(self.cond_stage_model, "decode"):
xc = self.cond_stage_model.decode(c)
log["conditioning"] = xc
elif self.cond_stage_key in ["caption", "txt"]:
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
log["conditioning"] = xc
elif self.cond_stage_key in ['class_label', 'cls']:
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
log['conditioning'] = xc
elif isimage(xc):
log["conditioning"] = xc
if ismap(xc):
log["original_conditioning"] = self.to_rgb(xc)
if not (self.c_concat_log_start is None and self.c_concat_log_end is None):
log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start:self.c_concat_log_end])
if plot_diffusion_rows:
# get diffusion row
diffusion_row = list()
z_start = z[:n_row]
for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
t = t.to(self.device).long()
noise = torch.randn_like(z_start)
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
diffusion_row.append(self.decode_first_stage(z_noisy))
diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
log["diffusion_row"] = diffusion_grid
if sample:
# get denoise row
with ema_scope("Sampling"):
samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
batch_size=N, ddim=use_ddim,
ddim_steps=ddim_steps, eta=ddim_eta)
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
x_samples = self.decode_first_stage(samples)
log["samples"] = x_samples
if plot_denoise_rows:
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
log["denoise_row"] = denoise_grid
if unconditional_guidance_scale > 1.0:
uc_cross = self.get_unconditional_conditioning(N, unconditional_guidance_label)
uc_cat = c_cat
uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
with ema_scope("Sampling with classifier-free guidance"):
samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
batch_size=N, ddim=use_ddim,
ddim_steps=ddim_steps, eta=ddim_eta,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=uc_full,
)
x_samples_cfg = self.decode_first_stage(samples_cfg)
log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
return log
class LatentInpaintDiffusion(LatentFinetuneDiffusion):
"""
can either run as pure inpainting model (only concat mode) or with mixed conditionings,
e.g. mask as concat and text via cross-attn.
To disable finetuning mode, set finetune_keys to None
"""
def __init__(self,
concat_keys=("mask", "masked_image"),
masked_image_key="masked_image",
*args, **kwargs
):
super().__init__(concat_keys, *args, **kwargs)
self.masked_image_key = masked_image_key
assert self.masked_image_key in concat_keys
@torch.no_grad()
def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
# note: restricted to non-trainable encoders currently
assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting'
z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
force_c_encode=True, return_original_cond=True, bs=bs)
assert exists(self.concat_keys)
c_cat = list()
for ck in self.concat_keys:
cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
if bs is not None:
cc = cc[:bs]
cc = cc.to(self.device)
bchw = z.shape
if ck != self.masked_image_key:
cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
else:
cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
c_cat.append(cc)
c_cat = torch.cat(c_cat, dim=1)
all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
if return_first_stage_outputs:
return z, all_conds, x, xrec, xc
return z, all_conds
@torch.no_grad()
def log_images(self, *args, **kwargs):
log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs)
log["masked_image"] = rearrange(args[0]["masked_image"],
'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
return log
class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
"""
condition on monocular depth estimation
"""
def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs):
super().__init__(concat_keys=concat_keys, *args, **kwargs)
self.depth_model = instantiate_from_config(depth_stage_config)
self.depth_stage_key = concat_keys[0]
@torch.no_grad()
def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
# note: restricted to non-trainable encoders currently
assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for depth2img'
z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
force_c_encode=True, return_original_cond=True, bs=bs)
assert exists(self.concat_keys)
assert len(self.concat_keys) == 1
c_cat = list()
for ck in self.concat_keys:
cc = batch[ck]
if bs is not None:
cc = cc[:bs]
cc = cc.to(self.device)
cc = self.depth_model(cc)
cc = torch.nn.functional.interpolate(
cc,
size=z.shape[2:],
mode="bicubic",
align_corners=False,
)
depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3],
keepdim=True)
cc = 2. * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1.
c_cat.append(cc)
c_cat = torch.cat(c_cat, dim=1)
all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
if return_first_stage_outputs:
return z, all_conds, x, xrec, xc
return z, all_conds
@torch.no_grad()
def log_images(self, *args, **kwargs):
log = super().log_images(*args, **kwargs)
depth = self.depth_model(args[0][self.depth_stage_key])
depth_min, depth_max = torch.amin(depth, dim=[1, 2, 3], keepdim=True), \
torch.amax(depth, dim=[1, 2, 3], keepdim=True)
log["depth"] = 2. * (depth - depth_min) / (depth_max - depth_min) - 1.
return log
class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
"""
condition on low-res image (and optionally on some spatial noise augmentation)
"""
def __init__(self, concat_keys=("lr",), reshuffle_patch_size=None,
low_scale_config=None, low_scale_key=None, *args, **kwargs):
super().__init__(concat_keys=concat_keys, *args, **kwargs)
self.reshuffle_patch_size = reshuffle_patch_size
self.low_scale_model = None
if low_scale_config is not None:
print("Initializing a low-scale model")
assert exists(low_scale_key)
self.instantiate_low_stage(low_scale_config)
self.low_scale_key = low_scale_key
def instantiate_low_stage(self, config):
model = instantiate_from_config(config)
self.low_scale_model = model.eval()
self.low_scale_model.train = disabled_train
for param in self.low_scale_model.parameters():
param.requires_grad = False
@torch.no_grad()
def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
# note: restricted to non-trainable encoders currently
assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for upscaling-ft'
z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
force_c_encode=True, return_original_cond=True, bs=bs)
assert exists(self.concat_keys)
assert len(self.concat_keys) == 1
# optionally make spatial noise_level here
c_cat = list()
noise_level = None
for ck in self.concat_keys:
cc = batch[ck]
cc = rearrange(cc, 'b h w c -> b c h w')
if exists(self.reshuffle_patch_size):
assert isinstance(self.reshuffle_patch_size, int)
cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w',
p1=self.reshuffle_patch_size, p2=self.reshuffle_patch_size)
if bs is not None:
cc = cc[:bs]
cc = cc.to(self.device)
if exists(self.low_scale_model) and ck == self.low_scale_key:
cc, noise_level = self.low_scale_model(cc)
c_cat.append(cc)
c_cat = torch.cat(c_cat, dim=1)
if exists(noise_level):
all_conds = {"c_concat": [c_cat], "c_crossattn": [c], "c_adm": noise_level}
else:
all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
if return_first_stage_outputs:
return z, all_conds, x, xrec, xc
return z, all_conds
@torch.no_grad()
def log_images(self, *args, **kwargs):
log = super().log_images(*args, **kwargs)
log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w')
return log
from .sampler import DPMSolverSampler
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment