Unverified Commit 2a787201 authored by YQ's avatar YQ Committed by GitHub
Browse files

override .cuda() to check if model is already quantized (#25166)

parent c1dba111
...@@ -1912,6 +1912,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
            mem = mem + mem_bufs
        return mem
def cuda(self, *args, **kwargs):
    """Move the model to CUDA device(s), unless it was loaded quantized.

    Quantized (4-bit / 8-bit) models are placed on their target devices
    and cast to the proper ``dtype`` at load time, so relocating them is
    refused rather than silently corrupting the quantized weights.

    Raises:
        ValueError: if the model carries the ``is_quantized`` flag.
    """
    # Guard clause: quantized models must stay where loading put them.
    quantized = getattr(self, "is_quantized", False)
    if quantized:
        raise ValueError(
            "Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the"
            " model has already been set to the correct devices and casted to the correct `dtype`."
        )
    # Not quantized: defer to the regular nn.Module behaviour.
    return super().cuda(*args, **kwargs)
def to(self, *args, **kwargs): def to(self, *args, **kwargs):
# Checks if the model has been loaded in 8-bit # Checks if the model has been loaded in 8-bit
if getattr(self, "is_quantized", False): if getattr(self, "is_quantized", False):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment