"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "93579650f8f3fd2c49d665c7dc582d5111583d02"
Unverified Commit 98d0cd57 authored by Aryan's avatar Aryan Committed by GitHub
Browse files

Use torch.device instead of current device index for BnB quantizer (#10069)



* update

* apply review suggestion

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 0d11ab26
...@@ -176,6 +176,8 @@ def load_model_dict_into_meta( ...@@ -176,6 +176,8 @@ def load_model_dict_into_meta(
hf_quantizer=None, hf_quantizer=None,
keep_in_fp32_modules=None, keep_in_fp32_modules=None,
) -> List[str]: ) -> List[str]:
if device is not None and not isinstance(device, (str, torch.device)):
raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.")
if hf_quantizer is None: if hf_quantizer is None:
device = device or torch.device("cpu") device = device or torch.device("cpu")
dtype = dtype or torch.float32 dtype = dtype or torch.float32
......
...@@ -836,7 +836,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): ...@@ -836,7 +836,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
param_device = "cpu" param_device = "cpu"
# TODO (sayakpaul, SunMarc): remove this after model loading refactor # TODO (sayakpaul, SunMarc): remove this after model loading refactor
elif is_quant_method_bnb: elif is_quant_method_bnb:
param_device = torch.cuda.current_device() param_device = torch.device(torch.cuda.current_device())
state_dict = load_state_dict(model_file, variant=variant) state_dict = load_state_dict(model_file, variant=variant)
model._convert_deprecated_attention_blocks(state_dict) model._convert_deprecated_attention_blocks(state_dict)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment