"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "e643a297228c8cb2c189fe4c93e11125f938d20b"
Unverified Commit 59dcea3f authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[`PreTrainedModel`] Wrap `cuda` and `to` method correctly (#25206)

wrap `cuda` and `to` method correctly
parent 67b85f24
...@@ -25,7 +25,7 @@ import tempfile ...@@ -25,7 +25,7 @@ import tempfile
import warnings import warnings
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial from functools import partial, wraps
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch import torch
...@@ -1912,6 +1912,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1912,6 +1912,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
mem = mem + mem_bufs mem = mem + mem_bufs
return mem return mem
@wraps(torch.nn.Module.cuda)
def cuda(self, *args, **kwargs): def cuda(self, *args, **kwargs):
# Checks if the model has been loaded in 8-bit # Checks if the model has been loaded in 8-bit
if getattr(self, "is_quantized", False): if getattr(self, "is_quantized", False):
...@@ -1922,6 +1923,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1922,6 +1923,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
else: else:
return super().cuda(*args, **kwargs) return super().cuda(*args, **kwargs)
@wraps(torch.nn.Module.to)
def to(self, *args, **kwargs): def to(self, *args, **kwargs):
# Checks if the model has been loaded in 8-bit # Checks if the model has been loaded in 8-bit
if getattr(self, "is_quantized", False): if getattr(self, "is_quantized", False):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment