OpenDAS / ColossalAI · Commit 1cf6d92d (unverified)
Authored Dec 23, 2022 by BlueRum; committed by GitHub on Dec 23, 2022
[example] diffuser, support quant inference for stable diffusion (#2186)
parent bc0e271e

Showing 3 changed files with 116 additions and 4 deletions (+116 / -4)
examples/images/diffusion/scripts/img2img.py    +15 -1
examples/images/diffusion/scripts/txt2img.py    +18 -3
examples/images/diffusion/scripts/utils.py      +83 -0
examples/images/diffusion/scripts/img2img.py

@@ -22,6 +22,7 @@ from imwatermark import WatermarkEncoder
 from scripts.txt2img import put_watermark
 from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
+from utils import replace_module, getModelSize


 def chunk(it, size):
@@ -44,7 +45,6 @@ def load_model_from_config(config, ckpt, verbose=False):
         print("unexpected keys:")
         print(u)

-    model.cuda()
     model.eval()
     return model

@@ -183,6 +183,12 @@ def main():
         choices=["full", "autocast"],
         default="autocast"
     )
+    parser.add_argument(
+        "--use_int8",
+        type=bool,
+        default=False,
+        help="use int8 for inference",
+    )
     opt = parser.parse_args()

     seed_everything(opt.seed)
@@ -193,6 +199,12 @@ def main():
     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
     model = model.to(device)

+    # quantize model
+    if opt.use_int8:
+        model = replace_module(model)
+        # # to compute the model size
+        # getModelSize(model)
+
     sampler = DDIMSampler(model)

     os.makedirs(opt.outdir, exist_ok=True)
@@ -280,3 +292,5 @@ def main():

 if __name__ == "__main__":
     main()
+    # # to compute the mem allocated
+    # print(torch.cuda.max_memory_allocated() / 1024 / 1024)
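The peak-memory check at the end of the script is left commented out. Below is a minimal sketch, not part of the commit, of how one might read that statistic around the sampling call using standard PyTorch CUDA memory statistics; report_peak_cuda_memory is a hypothetical helper.

import torch

def report_peak_cuda_memory(fn):
    # Hypothetical helper mirroring the commented-out
    # torch.cuda.max_memory_allocated() line in the diff: reset the peak
    # counter, run the given callable, then print the peak usage in MB.
    torch.cuda.reset_peak_memory_stats()
    result = fn()
    peak_mb = torch.cuda.max_memory_allocated() / 1024 / 1024
    print("peak CUDA memory: {:.1f} MB".format(peak_mb))
    return result

# Wrapping the sampling step once with the fp16 model and once after
# replace_module(model) gives a direct comparison of the two peaks.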
examples/images/diffusion/scripts/txt2img.py

@@ -20,6 +20,7 @@ from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler
 from ldm.models.diffusion.dpm_solver import DPMSolverSampler
+from utils import replace_module, getModelSize

 torch.set_grad_enabled(False)

@@ -43,7 +44,6 @@ def load_model_from_config(config, ckpt, verbose=False):
         print("unexpected keys:")
         print(u)

-    model.cuda()
     model.eval()
     return model

@@ -174,6 +174,12 @@ def parse_args():
         default=1,
         help="repeat each prompt in file this often",
     )
+    parser.add_argument(
+        "--use_int8",
+        type=bool,
+        default=False,
+        help="use int8 for inference",
+    )
     opt = parser.parse_args()
     return opt

@@ -193,8 +199,15 @@ def main(opt):
     model = load_model_from_config(config, f"{opt.ckpt}")

     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
     model = model.to(device)

+    # quantize model
+    if opt.use_int8:
+        model = replace_module(model)
+        # # to compute the model size
+        # getModelSize(model)
+
     if opt.plms:
         sampler = PLMSSampler(model)
     elif opt.dpm:
@@ -290,3 +303,5 @@ def main(opt):

 if __name__ == "__main__":
     opt = parse_args()
     main(opt)
+    # # to compute the mem allocated
+    # print(torch.cuda.max_memory_allocated() / 1024 / 1024)
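Both scripts register the new flag with type=bool. A small, self-contained sketch (not from the commit) of how argparse handles that: bool() is applied to the raw string, so any explicit value, including the string "False", comes out truthy, and the flag is effectively enabled by passing any value and disabled only by omitting it.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use_int8", type=bool, default=False, help="use int8 for inference")

# bool("False") is True: any non-empty string passed on the command line
# enables quantization; only omitting the flag keeps the default of False.
print(parser.parse_args(["--use_int8", "False"]).use_int8)  # True
print(parser.parse_args([]).use_int8)                       # False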
examples/images/diffusion/scripts/utils.py  (new file, 0 → 100644)

import bitsandbytes as bnb
import torch.nn as nn
import torch


class Linear8bit(nn.Linear):

    def __init__(self,
                 input_features,
                 output_features,
                 bias=True,
                 has_fp16_weights=False,
                 memory_efficient_backward=False,
                 threshold=6.0,
                 weight_data=None,
                 bias_data=None):
        super(Linear8bit, self).__init__(input_features, output_features, bias)
        self.state = bnb.MatmulLtState()
        self.bias = bias_data
        self.state.threshold = threshold
        self.state.has_fp16_weights = has_fp16_weights
        self.state.memory_efficient_backward = memory_efficient_backward
        if threshold > 0.0 and not has_fp16_weights:
            self.state.use_pool = True

        self.register_parameter("SCB", nn.Parameter(torch.empty(0), requires_grad=False))
        self.weight = weight_data
        self.quant()

    def quant(self):
        weight = self.weight.data.contiguous().half().cuda()
        CB, _, SCB, _, _ = bnb.functional.double_quant(weight)
        delattr(self, "weight")
        setattr(self, "weight", nn.Parameter(CB, requires_grad=False))
        delattr(self, "SCB")
        setattr(self, "SCB", nn.Parameter(SCB, requires_grad=False))
        del weight

    def forward(self, x):
        self.state.is_training = self.training

        if self.bias is not None and self.bias.dtype != torch.float16:
            self.bias.data = self.bias.data.half()

        self.state.CB = self.weight.data
        self.state.SCB = self.SCB.data

        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
        del self.state.CxB

        return out


def replace_module(model):
    for name, module in model.named_children():
        if len(list(module.children())) > 0:
            replace_module(module)

        if isinstance(module, nn.Linear) and "out_proj" not in name:
            model._modules[name] = Linear8bit(
                input_features=module.in_features,
                output_features=module.out_features,
                threshold=6.0,
                weight_data=module.weight,
                bias_data=module.bias,
            )
    return model


def getModelSize(model):
    param_size = 0
    param_sum = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
        param_sum += param.nelement()
    buffer_size = 0
    buffer_sum = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()
        buffer_sum += buffer.nelement()
    all_size = (param_size + buffer_size) / 1024 / 1024
    print('Model Size: {:.3f}MB'.format(all_size))
    return (param_size, param_sum, buffer_size, buffer_sum, all_size)
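A minimal usage sketch, not part of the commit, showing how replace_module and getModelSize fit together on a toy module. It assumes a CUDA device and bitsandbytes are available (Linear8bit.quant() moves weights to the GPU) and that it runs from the scripts directory so that utils is importable.

import torch.nn as nn
from utils import replace_module, getModelSize

# Toy stand-in for the diffusion model: two large Linear layers.
toy = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))

getModelSize(toy)            # fp32 weights: roughly 32 MB
toy = replace_module(toy)    # nn.Linear -> Linear8bit; weights quantized to int8 on the GPU
getModelSize(toy)            # int8 weights plus scales and fp32 biases: roughly a quarter of the size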