test_layers.py
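"""Tests for TensorParallelEmbedding sharding, using small in-process stubs
for the process group and weight store (no real distributed setup needed)."""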
import torch
from text_generation_server.layers import (
    TensorParallelEmbedding,
)


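# Minimal stand-in for a torch.distributed process group: exposes only the
# rank()/size() methods that TensorParallelEmbedding relies on.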
class ProcessGroup:
    def __init__(self, rank: int, world_size: int):
        self._rank = rank
        self.world_size = world_size

    def size(self) -> int:
        return self.world_size

    def rank(self) -> int:
        return self._rank


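# Fake weight store: a deterministic arange weight makes each shard's
# expected values easy to compute in the assertions below.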
class Weights:
    def __init__(self, rank: int, world_size: int, vocab_size: int, hidden_dim: int):
        self.weight = (
            torch.arange(vocab_size * hidden_dim).float().view(vocab_size, hidden_dim)
        )
        self.process_group = ProcessGroup(rank, world_size)

    def get_partial_sharded(self, name: str, dim: int):
        assert dim == 0

        rank = self.process_group.rank()
        world_size = self.process_group.size()
        size = self.weight.shape[dim]

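        # Shard along dim 0 using ceiling division: every rank gets block_size
        # rows, and the last rank's slice is simply truncated when vocab_size
        # is not divisible by world_size.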
        block_size = (size + world_size - 1) // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
        return self.weight[start:stop]

    def get_shape(self, name: str):
        return self.weight.shape


def test_tensor_parallel_embedding():
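    """Sharded (reduce=False) lookups must sum to the unsharded embedding."""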
    vocab_size = 17
    weights = Weights(
        rank=0,
        world_size=1,
        vocab_size=vocab_size,
        hidden_dim=256,
    )
    embeddings = TensorParallelEmbedding("", weights)

    input_ids = torch.arange(vocab_size)
    output = embeddings.forward(input_ids)
    assert embeddings.min_id == 0
    assert embeddings.max_id == 17
    torch.testing.assert_close(output, torch.arange(256 * 17).float().view(17, 256))

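    # With world_size=2, ceil(17 / 2) = 9: rank 0 owns ids [0, 9) and rank 1
    # owns ids [9, 17).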
    weights_0_2 = Weights(rank=0, world_size=2, vocab_size=vocab_size, hidden_dim=256)
    weights_1_2 = Weights(rank=1, world_size=2, vocab_size=vocab_size, hidden_dim=256)
    embeddings_0_2 = TensorParallelEmbedding("", weights_0_2, reduce=False)
    assert embeddings_0_2.min_id == 0
    assert embeddings_0_2.max_id == 9
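    # Each shard's weight carries one extra zero row (10 rows for 9 ids here),
    # presumably a padding row used for ids outside the shard's range.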
    torch.testing.assert_close(
        embeddings_0_2.weight,
        torch.cat([torch.arange(9 * 256), torch.zeros(256)], dim=0)
        .view(10, 256)
        .float(),
    )
    embeddings_1_2 = TensorParallelEmbedding("", weights_1_2, reduce=False)
    assert embeddings_1_2.min_id == 9
    assert embeddings_1_2.max_id == 17
    torch.testing.assert_close(
        embeddings_1_2.weight,
        torch.cat([torch.arange(8 * 256) + 9 * 256, torch.zeros(256)], dim=0)
        .view(9, 256)
        .float(),
    )
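    # reduce=False shards embed only their own id range and emit zeros
    # elsewhere, so the two partial outputs must sum to the full result.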
    output_tp_0 = embeddings_0_2.forward(input_ids)
    output_tp_1 = embeddings_1_2.forward(input_ids)

    torch.testing.assert_close(output, output_tp_0 + output_tp_1)