pretrained.py 212 Bytes
Newer Older
1
2
3
4
5
6
7
8
import torch

from transformers.utils import WEIGHTS_NAME
from transformers.utils.hub import cached_file


def state_dict_from_pretrained(model_name, device=None, dtype=None):
    """Load the raw checkpoint state dict of a pretrained model.

    The weights file is located with ``cached_file(model_name, WEIGHTS_NAME)``
    and deserialized with ``torch.load``.

    Args:
        model_name: Model identifier or path, passed through to
            ``cached_file`` unchanged.
        device: Optional ``map_location`` handed to ``torch.load`` (for
            example ``"cpu"``). The default ``None`` keeps ``torch.load``'s
            own placement, i.e. tensors are restored to the devices they
            were saved from, which fails for CUDA-saved checkpoints on a
            host without CUDA.
        dtype: Optional torch dtype. When given, every floating-point tensor
            in the checkpoint is cast to it; integer/bool tensors and
            non-tensor entries are left untouched. The default ``None``
            performs no casting.

    Returns:
        The object stored in the checkpoint when ``dtype`` is ``None``
        (presumably a mapping of parameter names to tensors -- confirm
        against the checkpoints in use). When ``dtype`` is given, a new
        ``dict`` with the same keys in the same order.
    """
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only call this with checkpoints from trusted sources.
    state_dict = torch.load(
        cached_file(model_name, WEIGHTS_NAME), map_location=device
    )
    if dtype is None:
        return state_dict
    # Cast only floating-point tensors: converting integer buffers (index or
    # position tensors, counters) to a float dtype would corrupt them.
    return {
        name: (
            value.to(dtype)
            if torch.is_tensor(value) and value.is_floating_point()
            else value
        )
        for name, value in state_dict.items()
    }