Unverified Commit 59dcea3f authored by Younes Belkada, committed by GitHub

[`PreTrainedModel`] Wrap `cuda` and `to` method correctly (#25206)

wrap `cuda` and `to` method correctly
parent 67b85f24
@@ -25,7 +25,7 @@ import tempfile
 import warnings
 from contextlib import contextmanager
 from dataclasses import dataclass
-from functools import partial
+from functools import partial, wraps
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -1912,6 +1912,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
             mem = mem + mem_bufs
         return mem
 
+    @wraps(torch.nn.Module.cuda)
     def cuda(self, *args, **kwargs):
         # Checks if the model has been loaded in 8-bit
         if getattr(self, "is_quantized", False):
@@ -1922,6 +1923,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         else:
             return super().cuda(*args, **kwargs)
 
+    @wraps(torch.nn.Module.to)
     def to(self, *args, **kwargs):
         # Checks if the model has been loaded in 8-bit
         if getattr(self, "is_quantized", False):
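
A note on why `functools.wraps` helps here: without it, the overridden `cuda` and `to` methods expose their own (empty) docstrings and metadata rather than those of `torch.nn.Module`, so `help(model.to)` and other introspection would lose the upstream documentation. Below is a minimal sketch of the effect; `TinyModel` and its error message are illustrative stand-ins, not part of this diff.

```python
# Minimal sketch (assumption: TinyModel stands in for PreTrainedModel).
# functools.wraps copies __doc__, __name__, and related metadata from the
# wrapped torch.nn.Module method onto the override, so the upstream docs
# survive the quantization guard placed around .cuda().
from functools import wraps

import torch


class TinyModel(torch.nn.Module):
    @wraps(torch.nn.Module.cuda)
    def cuda(self, *args, **kwargs):
        # Illustrative guard mirroring the is_quantized check in the diff
        # above (the exact error message here is hypothetical).
        if getattr(self, "is_quantized", False):
            raise ValueError("Quantized models cannot be moved with `.cuda()`.")
        return super().cuda(*args, **kwargs)


model = TinyModel()
# Thanks to @wraps, the override reports torch's original documentation.
assert model.cuda.__doc__ == torch.nn.Module.cuda.__doc__
print(model.cuda.__name__)  # -> "cuda"
```

The same reasoning applies to the wrapped `to` override.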