Unverified Commit 98d0cd57 authored by Aryan's avatar Aryan Committed by GitHub
Browse files

Use torch.device instead of current device index for BnB quantizer (#10069)



* update

* apply review suggestion

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 0d11ab26
......@@ -176,6 +176,8 @@ def load_model_dict_into_meta(
hf_quantizer=None,
keep_in_fp32_modules=None,
) -> List[str]:
if device is not None and not isinstance(device, (str, torch.device)):
raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.")
if hf_quantizer is None:
device = device or torch.device("cpu")
dtype = dtype or torch.float32
......
......@@ -836,7 +836,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
param_device = "cpu"
# TODO (sayakpaul, SunMarc): remove this after model loading refactor
elif is_quant_method_bnb:
param_device = torch.cuda.current_device()
param_device = torch.device(torch.cuda.current_device())
state_dict = load_state_dict(model_file, variant=variant)
model._convert_deprecated_attention_blocks(state_dict)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment