"images/git@developer.sourcefind.cn:cnjsdfcy/simbricks.git" did not exist on "3057567771e9878e1102bf9cf50e0b1b7e322bd4"
Unverified Commit 2a787201 authored by YQ's avatar YQ Committed by GitHub
Browse files

override .cuda() to check if model is already quantized (#25166)

parent c1dba111
......@@ -1912,6 +1912,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
mem = mem + mem_bufs
return mem
def cuda(self, *args, **kwargs):
    """Delegate to the parent class's ``cuda()`` unless the model is quantized.

    All positional and keyword arguments are forwarded unchanged to
    ``super().cuda()`` and its result is returned.

    Raises:
        ValueError: If the instance has a truthy ``is_quantized`` attribute.
    """
    # A missing `is_quantized` attribute is treated as "not quantized".
    already_quantized = getattr(self, "is_quantized", False)
    if not already_quantized:
        return super().cuda(*args, **kwargs)

    # Quantized (4-bit / 8-bit) models must not be moved by the caller.
    raise ValueError(
        "Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the"
        " model has already been set to the correct devices and casted to the correct `dtype`."
    )
def to(self, *args, **kwargs):
# Checks if the model has been loaded in 8-bit
if getattr(self, "is_quantized", False):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment