Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
0b065c09
Unverified
Commit
0b065c09
authored
Jan 16, 2025
by
hlky
Committed by
GitHub
Jan 16, 2025
Browse files
Move buffers to device (#10523)
* Move buffers to device
* add test
* named_buffers
parent
b785ddb6
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
56 additions
and
2 deletions
+56
-2
src/diffusers/loaders/single_file_model.py
src/diffusers/loaders/single_file_model.py
+2
-0
src/diffusers/models/model_loading_utils.py
src/diffusers/models/model_loading_utils.py
+16
-1
src/diffusers/models/modeling_utils.py
src/diffusers/models/modeling_utils.py
+3
-0
tests/quantization/bnb/test_mixed_int8.py
tests/quantization/bnb/test_mixed_int8.py
+35
-1
No files found.
src/diffusers/loaders/single_file_model.py
View file @
0b065c09
...
...
@@ -362,6 +362,7 @@ class FromOriginalModelMixin:
if
is_accelerate_available
():
param_device
=
torch
.
device
(
device
)
if
device
else
torch
.
device
(
"cpu"
)
named_buffers
=
model
.
named_buffers
()
unexpected_keys
=
load_model_dict_into_meta
(
model
,
diffusers_format_checkpoint
,
...
...
@@ -369,6 +370,7 @@ class FromOriginalModelMixin:
device
=
param_device
,
hf_quantizer
=
hf_quantizer
,
keep_in_fp32_modules
=
keep_in_fp32_modules
,
named_buffers
=
named_buffers
,
)
else
:
...
...
src/diffusers/models/model_loading_utils.py
View file @
0b065c09
...
...
@@ -20,7 +20,7 @@ import os
from
array
import
array
from
collections
import
OrderedDict
from
pathlib
import
Path
from
typing
import
Dict
,
List
,
Optional
,
Union
from
typing
import
Dict
,
Iterator
,
List
,
Optional
,
Tuple
,
Union
import
safetensors
import
torch
...
...
@@ -193,6 +193,7 @@ def load_model_dict_into_meta(
model_name_or_path
:
Optional
[
str
]
=
None
,
hf_quantizer
=
None
,
keep_in_fp32_modules
=
None
,
named_buffers
:
Optional
[
Iterator
[
Tuple
[
str
,
torch
.
Tensor
]]]
=
None
,
)
->
List
[
str
]:
if
device
is
not
None
and
not
isinstance
(
device
,
(
str
,
torch
.
device
)):
raise
ValueError
(
f
"Expected device to have type `str` or `torch.device`, but got
{
type
(
device
)
=
}
."
)
...
...
@@ -254,6 +255,20 @@ def load_model_dict_into_meta(
else
:
set_module_tensor_to_device
(
model
,
param_name
,
device
,
value
=
param
)
if
named_buffers
is
None
:
return
unexpected_keys
for
param_name
,
param
in
named_buffers
:
if
is_quantized
and
(
hf_quantizer
.
check_if_quantized_param
(
model
,
param
,
param_name
,
state_dict
,
param_device
=
device
)
):
hf_quantizer
.
create_quantized_param
(
model
,
param
,
param_name
,
device
,
state_dict
,
unexpected_keys
)
else
:
if
accepts_dtype
:
set_module_tensor_to_device
(
model
,
param_name
,
device
,
value
=
param
,
**
set_module_kwargs
)
else
:
set_module_tensor_to_device
(
model
,
param_name
,
device
,
value
=
param
)
return
unexpected_keys
...
...
src/diffusers/models/modeling_utils.py
View file @
0b065c09
...
...
@@ -913,6 +913,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
" those weights or else make sure your checkpoint file is correct."
)
named_buffers
=
model
.
named_buffers
()
unexpected_keys
=
load_model_dict_into_meta
(
model
,
state_dict
,
...
...
@@ -921,6 +923,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
model_name_or_path
=
pretrained_model_name_or_path
,
hf_quantizer
=
hf_quantizer
,
keep_in_fp32_modules
=
keep_in_fp32_modules
,
named_buffers
=
named_buffers
,
)
if
cls
.
_keys_to_ignore_on_load_unexpected
is
not
None
:
...
...
tests/quantization/bnb/test_mixed_int8.py
View file @
0b065c09
...
...
@@ -20,7 +20,14 @@ import numpy as np
import
pytest
from
huggingface_hub
import
hf_hub_download
from
diffusers
import
BitsAndBytesConfig
,
DiffusionPipeline
,
FluxTransformer2DModel
,
SD3Transformer2DModel
,
logging
from
diffusers
import
(
BitsAndBytesConfig
,
DiffusionPipeline
,
FluxTransformer2DModel
,
SanaTransformer2DModel
,
SD3Transformer2DModel
,
logging
,
)
from
diffusers.utils
import
is_accelerate_version
from
diffusers.utils.testing_utils
import
(
CaptureLogger
,
...
...
@@ -302,6 +309,33 @@ class BnB8bitBasicTests(Base8bitTests):
_
=
self
.
model_fp16
.
cuda
()
class Bnb8bitDeviceTests(Base8bitTests):
    """Device-placement checks for a model loaded with 8-bit quantization.

    Each test runs against a ``SanaTransformer2DModel`` loaded in ``setUp``
    with ``BitsAndBytesConfig(load_in_8bit=True)`` and released in
    ``tearDown``.
    """

    def setUp(self) -> None:
        """Clear cached memory, then load the 8-bit quantized transformer."""
        gc.collect()
        torch.cuda.empty_cache()

        self.model_8bit = SanaTransformer2DModel.from_pretrained(
            "Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers",
            subfolder="transformer",
            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        )

    def tearDown(self):
        """Drop the model reference and clear cached memory again."""
        del self.model_8bit

        gc.collect()
        torch.cuda.empty_cache()

    def test_buffers_device_assignment(self):
        """Every registered buffer must report the same device type as ``torch_device``."""
        # NOTE(review): `torch_device` is defined outside this block —
        # presumably the shared testing-utils device string; confirm.
        for buffer_name, buffer in self.model_8bit.named_buffers():
            actual_type = buffer.device.type
            expected_type = torch.device(torch_device).type
            self.assertEqual(
                actual_type,
                expected_type,
                f"Expected device {torch_device} for {buffer_name} got {buffer.device}.",
            )
class
BnB8bitTrainingTests
(
Base8bitTests
):
def
setUp
(
self
):
gc
.
collect
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment