# weights.py
from pathlib import Path
from typing import List, Dict, Optional, Tuple
from safetensors import safe_open


class Weights:
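    """Lazily load tensors from a set of safetensors checkpoint files.

    Each tensor name is routed to the file that contains it (optionally
    resolved through `aliases`), loaded on demand, cast to `dtype`, moved to
    `device`, and, for tensor parallelism, sharded across `process_group`.

    Illustrative usage (the file name, device, dtype, and process group below
    are placeholders, not values defined by this module):

        weights = Weights(
            filenames=[Path("model-00001-of-00002.safetensors")],
            device="cuda:0",
            dtype=torch.float16,
            process_group=group,
        )
        shard = weights.get_sharded("lm_head.weight", dim=0)
    """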
    def __init__(
        self,
        filenames: List[Path],
        device,
        dtype,
        process_group,
        aliases: Optional[Dict[str, List[str]]] = None,
    ):
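        # Map every tensor name to the file that contains it; a name that
        # appears in more than one file is ambiguous and rejected.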
        routing = {}
        for filename in filenames:
            with safe_open(filename, framework="pytorch") as f:
                for k in f.keys():
                    if k in routing:
                        raise RuntimeError(
                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                        )
                    routing[k] = filename
        if aliases is None:
            aliases = {}
        self.aliases = aliases
        self.routing = routing
        self.device = device
        self.dtype = dtype
        self.process_group = process_group
        self._handles = {}

    def _get_handle(self, filename):
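        # Open each safetensors file at most once and cache the handle.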
        if filename not in self._handles:
            f = safe_open(filename, framework="pytorch")
            self._handles[filename] = f

        return self._handles[filename]

    def get_filename(self, tensor_name: str) -> Tuple[str, str]:
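        """Return the file containing `tensor_name` and the name to read it under.

        Falls back to the configured aliases when the exact name is not present
        in any file.
        """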
        filename = self.routing.get(tensor_name, None)
        if filename is None:
            aliases = self.aliases.get(tensor_name, [])
            for alias in aliases:
                filename = self.routing.get(alias, None)
                if filename is not None:
                    return str(filename), alias
            raise RuntimeError(f"weight {tensor_name} does not exist")
        return str(filename), tensor_name

    def _get_slice(self, tensor_name: str):
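        # Return a lazy slice object so callers can inspect the shape or read
        # only a shard without materializing the whole tensor.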
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        return slice_

    def get_shape(self, tensor_name: str):
        return self._get_slice(tensor_name).get_shape()

    def get_tensor(self, tensor_name: str):
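        # Load the full tensor, then cast it to the target dtype and move it
        # to the target device.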
        filename, tensor_name = self.get_filename(tensor_name)
        f = self._get_handle(filename)
        tensor = f.get_tensor(tensor_name)
        tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
        return tensor

    def get_sharded(self, tensor_name: str, dim: int):
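        """Load the shard of `tensor_name` along `dim` owned by this rank.

        The size of `dim` must be evenly divisible by the world size.
        """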
        filename, tensor_name = self.get_filename(tensor_name)
        world_size = self.process_group.size()
        rank = self.process_group.rank()

        f = self._get_handle(filename)
        slice_ = f.get_slice(tensor_name)
        size = slice_.get_shape()[dim]
        block_size = size // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size

        assert (
            size % world_size == 0
        ), f"The chosen size {size} is not compatible with sharding on {world_size} shards"

        if dim == 0:
            tensor = slice_[start:stop]
        elif dim == 1:
            tensor = slice_[:, start:stop]
        else:
            raise NotImplementedError("Let's make that generic when needed")
        tensor = tensor.to(dtype=self.dtype)
        tensor = tensor.to(device=self.device)
        return tensor