"vscode:/vscode.git/clone" did not exist on "c84d52cda451718dfd77936066ea9587c2f9bd2c"
Unverified Commit 0b065c09 authored by hlky, committed by GitHub

Move buffers to device (#10523)

* Move buffers to device

* add test

* named_buffers
parent b785ddb6
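Context for the change: when a model is instantiated on the meta device and its weights are then materialized key by key from a checkpoint, only tensors present in the state dict land on the target device. Registered buffers that are absent from the checkpoint (non-persistent buffers, for example) are never touched and can be left behind on the wrong device. This commit threads model.named_buffers() through load_model_dict_into_meta so buffers are moved alongside the parameters. A minimal, hedged sketch of the issue follows; TinyModel is a hypothetical stand-in, while init_empty_weights and set_module_tensor_to_device are the accelerate utilities the loader itself uses.

import torch
import torch.nn as nn
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

class TinyModel(nn.Module):  # hypothetical stand-in for a diffusers model
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        # Non-persistent buffer: excluded from state_dict, so key-by-key
        # loading never materializes it off the meta device.
        self.register_buffer("scale", torch.ones(4), persistent=False)

with init_empty_weights(include_buffers=True):
    model = TinyModel()  # every parameter and buffer starts on "meta"

# Materialize only what the checkpoint contains, as meta loading does.
for name, value in TinyModel().state_dict().items():
    set_module_tensor_to_device(model, name, "cpu", value=value)

print(model.linear.weight.device)  # cpu
print(model.scale.device)          # meta -- the buffer was left behind

# The fix mirrors the diff below: walk named_buffers() and move each one.
for name, buf in model.named_buffers():
    if buf.device.type == "meta":
        set_module_tensor_to_device(model, name, "cpu", value=torch.ones_like(buf, device="cpu"))
print(model.scale.device)          # cpu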
@@ -362,6 +362,7 @@ class FromOriginalModelMixin:
         if is_accelerate_available():
             param_device = torch.device(device) if device else torch.device("cpu")
+            named_buffers = model.named_buffers()
             unexpected_keys = load_model_dict_into_meta(
                 model,
                 diffusers_format_checkpoint,
@@ -369,6 +370,7 @@ class FromOriginalModelMixin:
                 device=param_device,
                 hf_quantizer=hf_quantizer,
                 keep_in_fp32_modules=keep_in_fp32_modules,
+                named_buffers=named_buffers,
             )
         else:
...
@@ -20,7 +20,7 @@ import os
 from array import array
 from collections import OrderedDict
 from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Dict, Iterator, List, Optional, Tuple, Union

 import safetensors
 import torch
@@ -193,6 +193,7 @@ def load_model_dict_into_meta(
     model_name_or_path: Optional[str] = None,
     hf_quantizer=None,
     keep_in_fp32_modules=None,
+    named_buffers: Optional[Iterator[Tuple[str, torch.Tensor]]] = None,
 ) -> List[str]:
     if device is not None and not isinstance(device, (str, torch.device)):
         raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.")
@@ -254,6 +255,20 @@ def load_model_dict_into_meta(
         else:
             set_module_tensor_to_device(model, param_name, device, value=param)

+    if named_buffers is None:
+        return unexpected_keys
+
+    for param_name, param in named_buffers:
+        if is_quantized and (
+            hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
+        ):
+            hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
+        else:
+            if accepts_dtype:
+                set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
+            else:
+                set_module_tensor_to_device(model, param_name, device, value=param)
+
     return unexpected_keys
...
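One detail worth noting about the new parameter typed as Optional[Iterator[Tuple[str, torch.Tensor]]]: torch.nn.Module.named_buffers() returns a lazy generator of (qualified_name, tensor) pairs that is exhausted after a single pass, which is why the call sites in this commit create it immediately before handing it to load_model_dict_into_meta. A quick illustration, using nn.BatchNorm1d only because it conveniently registers buffers:

import torch.nn as nn

m = nn.BatchNorm1d(8)  # registers running_mean, running_var, num_batches_tracked
buffers = m.named_buffers()
print([name for name, _ in buffers])
# ['running_mean', 'running_var', 'num_batches_tracked']
print(list(buffers))  # [] -- the generator is already exhausted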
@@ -913,6 +913,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                         " those weights or else make sure your checkpoint file is correct."
                     )

+                named_buffers = model.named_buffers()
+
                 unexpected_keys = load_model_dict_into_meta(
                     model,
                     state_dict,
@@ -921,6 +923,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                     model_name_or_path=pretrained_model_name_or_path,
                     hf_quantizer=hf_quantizer,
                     keep_in_fp32_modules=keep_in_fp32_modules,
+                    named_buffers=named_buffers,
                 )

                 if cls._keys_to_ignore_on_load_unexpected is not None:
...
@@ -20,7 +20,14 @@ import numpy as np
 import pytest
 from huggingface_hub import hf_hub_download

-from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel, logging
+from diffusers import (
+    BitsAndBytesConfig,
+    DiffusionPipeline,
+    FluxTransformer2DModel,
+    SanaTransformer2DModel,
+    SD3Transformer2DModel,
+    logging,
+)
 from diffusers.utils import is_accelerate_version
 from diffusers.utils.testing_utils import (
     CaptureLogger,
@@ -302,6 +309,33 @@ class BnB8bitBasicTests(Base8bitTests):
         _ = self.model_fp16.cuda()


+class Bnb8bitDeviceTests(Base8bitTests):
+    def setUp(self) -> None:
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
+        self.model_8bit = SanaTransformer2DModel.from_pretrained(
+            "Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers",
+            subfolder="transformer",
+            quantization_config=mixed_int8_config,
+        )
+
+    def tearDown(self):
+        del self.model_8bit
+
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def test_buffers_device_assignment(self):
+        for buffer_name, buffer in self.model_8bit.named_buffers():
+            self.assertEqual(
+                buffer.device.type,
+                torch.device(torch_device).type,
+                f"Expected device {torch_device} for {buffer_name} got {buffer.device}.",
+            )
+
+
 class BnB8bitTrainingTests(Base8bitTests):
     def setUp(self):
         gc.collect()
...
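The new Bnb8bitDeviceTests case can be run in isolation with something like the following (the test-file path is assumed from the diffusers repository layout; it needs a CUDA GPU with bitsandbytes installed, and the Sana checkpoint is downloaded on first use):

pytest tests/quantization/bnb/test_mixed_int8.py -k test_buffers_device_assignment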