Unverified commit 0b065c09 authored by hlky, committed by GitHub

Move buffers to device (#10523)

* Move buffers to device

* add test

* named_buffers
parent b785ddb6
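For context (not part of the diff): a PyTorch module carries two kinds of state, parameters and buffers. The meta-device loading path patched here iterated only the checkpoint parameters, so buffers created with `register_buffer` did not reliably follow the model onto the target device. A minimal sketch of the distinction, with illustrative module and buffer names:

```python
import torch
import torch.nn as nn


class TinyBlock(nn.Module):
    """Toy module with one parameter tensor and one buffer."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)  # parameters: proj.weight, proj.bias
        # A buffer is module state without gradients (e.g. fixed embeddings,
        # normalization statistics); it still must live on the right device.
        self.register_buffer("pos_embed", torch.zeros(1, 4))


block = TinyBlock()
print([n for n, _ in block.named_parameters()])  # ['proj.weight', 'proj.bias']
print([n for n, _ in block.named_buffers()])     # ['pos_embed']
```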
@@ -362,6 +362,7 @@ class FromOriginalModelMixin:
         if is_accelerate_available():
             param_device = torch.device(device) if device else torch.device("cpu")
+            named_buffers = model.named_buffers()
             unexpected_keys = load_model_dict_into_meta(
                 model,
                 diffusers_format_checkpoint,
@@ -369,6 +370,7 @@
                 device=param_device,
                 hf_quantizer=hf_quantizer,
                 keep_in_fp32_modules=keep_in_fp32_modules,
+                named_buffers=named_buffers,
             )
         else:
......
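A detail behind the hunks above: `Module.named_buffers()` returns a one-shot generator, which is why the call site creates it immediately before handing it to `load_model_dict_into_meta` to consume. A short demonstration of that standard PyTorch behavior:

```python
import torch.nn as nn

bn = nn.BatchNorm1d(2)  # buffers: running_mean, running_var, num_batches_tracked
buffers = bn.named_buffers()     # a lazy, single-pass generator
assert len(list(buffers)) == 3   # first iteration yields all buffers
assert list(buffers) == []       # second iteration: already exhausted
```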
@@ -20,7 +20,7 @@ import os
 from array import array
 from collections import OrderedDict
 from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Dict, Iterator, List, Optional, Tuple, Union
 
 import safetensors
 import torch
@@ -193,6 +193,7 @@ def load_model_dict_into_meta(
     model_name_or_path: Optional[str] = None,
     hf_quantizer=None,
     keep_in_fp32_modules=None,
+    named_buffers: Optional[Iterator[Tuple[str, torch.Tensor]]] = None,
 ) -> List[str]:
     if device is not None and not isinstance(device, (str, torch.device)):
         raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.")
@@ -254,6 +255,20 @@ def load_model_dict_into_meta(
             else:
                 set_module_tensor_to_device(model, param_name, device, value=param)
 
+    if named_buffers is None:
+        return unexpected_keys
+
+    for param_name, param in named_buffers:
+        if is_quantized and (
+            hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
+        ):
+            hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
+        else:
+            if accepts_dtype:
+                set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
+            else:
+                set_module_tensor_to_device(model, param_name, device, value=param)
+
     return unexpected_keys
......
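The helper's new buffer pass, sketched standalone for the non-quantized branch. This relies on accelerate's `set_module_tensor_to_device`, which the real code also uses; the `dtype` plumbing and the quantizer branch are omitted, and the function name here is hypothetical:

```python
import torch
from accelerate.utils import set_module_tensor_to_device


def move_buffers_to_device(model: torch.nn.Module, device: str) -> None:
    # Mirrors the added loop: re-register every buffer on `device`,
    # replacing whatever placeholder (CPU or meta) currently holds it.
    for buffer_name, buffer in model.named_buffers():
        set_module_tensor_to_device(model, buffer_name, device, value=buffer)
```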
@@ -913,6 +913,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                         " those weights or else make sure your checkpoint file is correct."
                     )
 
+                named_buffers = model.named_buffers()
+
                 unexpected_keys = load_model_dict_into_meta(
                     model,
                     state_dict,
@@ -921,6 +923,7 @@
                     model_name_or_path=pretrained_model_name_or_path,
                     hf_quantizer=hf_quantizer,
                     keep_in_fp32_modules=keep_in_fp32_modules,
+                    named_buffers=named_buffers,
                 )
 
                 if cls._keys_to_ignore_on_load_unexpected is not None:
......
@@ -20,7 +20,14 @@ import numpy as np
 import pytest
 from huggingface_hub import hf_hub_download
 
-from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel, logging
+from diffusers import (
+    BitsAndBytesConfig,
+    DiffusionPipeline,
+    FluxTransformer2DModel,
+    SanaTransformer2DModel,
+    SD3Transformer2DModel,
+    logging,
+)
 from diffusers.utils import is_accelerate_version
 from diffusers.utils.testing_utils import (
     CaptureLogger,
@@ -302,6 +309,33 @@ class BnB8bitBasicTests(Base8bitTests):
         _ = self.model_fp16.cuda()
 
 
+class Bnb8bitDeviceTests(Base8bitTests):
+    def setUp(self) -> None:
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
+        self.model_8bit = SanaTransformer2DModel.from_pretrained(
+            "Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers",
+            subfolder="transformer",
+            quantization_config=mixed_int8_config,
+        )
+
+    def tearDown(self):
+        del self.model_8bit
+
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def test_buffers_device_assignment(self):
+        for buffer_name, buffer in self.model_8bit.named_buffers():
+            self.assertEqual(
+                buffer.device.type,
+                torch.device(torch_device).type,
+                f"Expected device {torch_device} for {buffer_name} got {buffer.device}.",
+            )
+
+
 class BnB8bitTrainingTests(Base8bitTests):
     def setUp(self):
         gc.collect()
......
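Outside the test harness, the fixed behavior can be spot-checked with a short script (the model id is taken from the test above; assumes a CUDA machine with bitsandbytes installed):

```python
from diffusers import BitsAndBytesConfig, SanaTransformer2DModel

model = SanaTransformer2DModel.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_4Kpx_BF16_diffusers",
    subfolder="transformer",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)

# After this fix, buffers land on the same device as the int8 weights.
for name, buf in model.named_buffers():
    assert buf.device.type == "cuda", f"{name} is on {buf.device}"
```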