pretrained.py 246 Bytes
Newer Older
1
2
3
4
5
6
import torch

from transformers.utils import WEIGHTS_NAME
from transformers.utils.hub import cached_file


7
8
def state_dict_from_pretrained(model_name, device=None):
    """Load the raw PyTorch state dict for ``model_name``.

    The checkpoint file is located with transformers' ``cached_file`` using the
    ``WEIGHTS_NAME`` filename, then deserialized with ``torch.load``.

    Args:
        model_name: Identifier passed straight to ``cached_file`` (a Hub repo id
            or local directory, per transformers' conventions).
        device: Forwarded to ``torch.load`` as ``map_location``; ``None`` keeps
            torch's default placement.

    Returns:
        Whatever object ``torch.load`` produces from the checkpoint, presumably
        a mapping of parameter names to tensors.

    NOTE(review): errors from ``cached_file`` propagate unchanged. Repos that
    ship only safetensors or sharded weights may not contain ``WEIGHTS_NAME``.
    Confirm this is acceptable for the intended models.
    """
    # Resolve the on-disk location of the weights file (may trigger a download).
    checkpoint_path = cached_file(model_name, WEIGHTS_NAME)
    # NOTE(security): torch.load is pickle-based. On torch versions where
    # weights_only is not the default, loading a checkpoint from an untrusted
    # repo can execute arbitrary code, so only load checkpoints you trust.
    return torch.load(checkpoint_path, map_location=device)