import os
from functools import partial

import torch
from safetensors.torch import load_file as safe_load_file

from transformers.utils import (
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
)
from transformers.utils.hub import cached_file, get_checkpoint_shard_files


def state_dict_from_pretrained(model_name, device=None, dtype=None):
    """Load a pretrained model's state dict from local files or the HF hub.

    Supports both regular PyTorch checkpoints (``pytorch_model.bin``) and
    safetensors checkpoints, sharded or unsharded. Files under ``model_name``
    (treated as a local directory) take precedence; if none are present the
    checkpoint is fetched from the Hugging Face hub.

    Args:
        model_name: Local checkpoint directory or HF hub model id.
        device: Final device for the returned tensors. ``None`` leaves the
            tensors wherever they were loaded.
        dtype: Optional dtype to cast every tensor to before the device move.

    Returns:
        Dict mapping parameter names to tensors.

    Raises:
        EnvironmentError: if no checkpoint file can be resolved for
            ``model_name``, locally or on the hub.
    """
    # If we're going to cast to a non-fp32 dtype, load to CPU first so the
    # full-precision copy never has to fit on the GPU; cast, then move.
    mapped_device = "cpu" if dtype not in [torch.float32, None] else device
    is_sharded = False
    load_safe = False
    resolved_archive_file = None

    # Candidate checkpoint layouts in precedence order:
    # (file name, is a sharded index file, is safetensors format).
    candidates = (
        (WEIGHTS_NAME, False, False),
        (WEIGHTS_INDEX_NAME, True, False),
        (SAFE_WEIGHTS_NAME, False, True),
        (SAFE_WEIGHTS_INDEX_NAME, True, True),
    )

    # First, look for the checkpoint as local files under model_name.
    for weight_name, sharded, safe in candidates:
        if os.path.isfile(os.path.join(model_name, weight_name)):
            resolved_archive_file = cached_file(
                model_name, weight_name, _raise_exceptions_for_missing_entries=False
            )
            is_sharded, load_safe = sharded, safe
            break
    else:
        # Not found locally: try the HF hub. Previously only the
        # non-safetensors names were tried here, so hub models distributed
        # exclusively as safetensors could not be loaded; now all four
        # layouts are attempted, .bin names first for backward compatibility.
        for weight_name, sharded, safe in candidates:
            resolved_archive_file = cached_file(
                model_name, weight_name, _raise_exceptions_for_missing_entries=False
            )
            if resolved_archive_file is not None:
                is_sharded, load_safe = sharded, safe
                break

    if resolved_archive_file is None:
        raise EnvironmentError(f"Model name {model_name} was not found.")

    if load_safe:
        # safetensors' load_file requires a concrete device (it does not
        # accept None, unlike torch.load's map_location), so default to CPU.
        loader = partial(
            safe_load_file,
            device=mapped_device if mapped_device is not None else "cpu",
        )
    else:
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources (consider weights_only=True where
        # the checkpoint contains tensors only).
        loader = partial(torch.load, map_location=mapped_device)

    if is_sharded:
        # For a sharded checkpoint the resolved index file expands into the
        # list of shard files; merge every shard into one state dict.
        # (The shard metadata returned alongside is not needed here.)
        resolved_archive_file, _ = get_checkpoint_shard_files(
            model_name, resolved_archive_file
        )
        state_dict = {}
        for sharded_file in resolved_archive_file:
            state_dict.update(loader(sharded_file))
    else:
        state_dict = loader(resolved_archive_file)
    # Convert dtype before moving to GPU to save memory.
    if dtype is not None:
        state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
    state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
    return state_dict