# pretrained.py
1
2
3
4
5
6
import torch

from transformers.utils import WEIGHTS_NAME
from transformers.utils.hub import cached_file


7
8
9
10
11
def state_dict_from_pretrained(model_name, device=None, dtype=None):
    """Load the cached pretrained weights for ``model_name`` as a state dict.

    The weights file is located with ``cached_file(model_name, WEIGHTS_NAME)``
    and deserialized with ``torch.load``.

    Args:
        model_name: Model identifier forwarded to ``cached_file``.
        device: Forwarded to ``torch.load`` as ``map_location``. ``None``
            leaves torch's default placement in effect.
        dtype: If not ``None``, every floating-point tensor in the checkpoint
            is cast to this dtype. Non-floating-point entries (e.g. integer
            index buffers or boolean masks) are returned unchanged, because
            casting them to a low-precision float dtype would silently
            corrupt their values.

    Returns:
        The object produced by ``torch.load`` (expected to be a mapping from
        parameter names to tensors), with floating-point tensors cast to
        ``dtype`` when one is requested.
    """
    # NOTE(security): torch.load relies on pickle unless restricted via
    # ``weights_only``; only load checkpoints from sources you trust.
    state_dict = torch.load(cached_file(model_name, WEIGHTS_NAME), map_location=device)
    if dtype is not None:
        # Cast only floating-point tensors; integer/bool buffers keep their
        # original dtype so their contents stay exact.
        state_dict = {
            k: v.to(dtype) if torch.is_tensor(v) and v.is_floating_point() else v
            for k, v in state_dict.items()
        }
    return state_dict