test_layers.py 2.47 KB
Newer Older
jixx's avatar
init  
jixx committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
from text_generation_server.layers import (
    TensorParallelEmbedding,
)


class ProcessGroup:
    """Lightweight process-group object used by the tests in this file.

    It only stores a rank and a world size and hands them back through
    the ``rank()`` and ``size()`` accessors.
    """

    def __init__(self, rank: int, world_size: int):
        # The rank lives under a private name because ``rank`` is also the
        # name of the accessor method below.
        self._rank, self.world_size = rank, world_size

    def rank(self) -> int:
        """Return the rank stored at construction time."""
        return self._rank

    def size(self) -> int:
        """Return the current value of ``world_size``."""
        return self.world_size


class Weights:
    """In-memory weight source holding a single float matrix.

    The matrix has shape ``(vocab_size, hidden_dim)`` and is filled with
    consecutive values ``0, 1, 2, ...`` in row-major order, so each row can
    be recognised by its content. The instance also carries a
    ``ProcessGroup`` whose rank and size select the shard returned by
    ``get_partial_sharded``.
    """

    def __init__(self, rank: int, world_size: int, vocab_size: int, hidden_dim: int):
        element_count = vocab_size * hidden_dim
        self.weight = torch.arange(element_count).float().view(vocab_size, hidden_dim)
        self.process_group = ProcessGroup(rank, world_size)

    def get_partial_sharded(self, name: str, dim: int):
        """Return this rank's contiguous slice of rows of ``self.weight``.

        ``name`` is accepted but not used. Only ``dim == 0`` is supported
        (enforced with an ``assert``). The per-rank row count is the row
        total divided by the group size, rounded up, so the last rank may
        receive fewer rows than the others.
        """
        assert dim == 0

        group = self.process_group
        shard_index = group.rank()
        shard_count = group.size()
        total_rows = self.weight.shape[dim]

        # Round up so that shard_count * rows_per_shard covers every row.
        rows_per_shard = (total_rows + shard_count - 1) // shard_count
        first_row = shard_index * rows_per_shard
        # Slicing past the end is clamped, which shortens the last shard.
        return self.weight[first_row : first_row + rows_per_shard]

    def get_shape(self, name: str):
        """Return the shape of the full matrix; ``name`` is not used."""
        return self.weight.shape


def test_weight_hub_files_offline_error():
    """Check ``TensorParallelEmbedding`` with world size 1 and world size 2.

    Using a 17 x 256 table of consecutive values, the test asserts that:

    * with world size 1, ``min_id``/``max_id`` are 0/17 and looking up
      every id returns the full table;
    * with world size 2, rank 0 reports ids 0..9 and rank 1 ids 9..17, and
      each layer's ``weight`` equals its rows of the table followed by one
      all-zero row;
    * the two per-rank outputs add up to the world-size-1 output.

    NOTE(review): the function name does not describe what the body
    tests; it is kept unchanged so test selection by name keeps working.
    """
    vocab_size = 17
    hidden_dim = 256
    # Rank 0 of 2 holds the first 9 rows; rank 1 holds the remaining 8.
    rows_rank0 = 9
    rows_rank1 = vocab_size - rows_rank0
    zero_row = torch.zeros(hidden_dim)

    # --- world size 1 -------------------------------------------------
    weights = Weights(
        rank=0, world_size=1, vocab_size=vocab_size, hidden_dim=hidden_dim
    )
    embeddings = TensorParallelEmbedding("", weights)

    input_ids = torch.arange(vocab_size)
    output = embeddings.forward(input_ids)
    assert embeddings.min_id == 0
    assert embeddings.max_id == vocab_size
    full_table = (
        torch.arange(hidden_dim * vocab_size).float().view(vocab_size, hidden_dim)
    )
    torch.testing.assert_close(output, full_table)

    # --- world size 2 -------------------------------------------------
    weights_0_2 = Weights(
        rank=0, world_size=2, vocab_size=vocab_size, hidden_dim=hidden_dim
    )
    weights_1_2 = Weights(
        rank=1, world_size=2, vocab_size=vocab_size, hidden_dim=hidden_dim
    )

    # NOTE(review): ``reduce=False`` presumably disables the cross-rank
    # reduction so the partial outputs can be summed by hand below - confirm
    # against the layer implementation.
    embeddings_0_2 = TensorParallelEmbedding("", weights_0_2, reduce=False)
    assert embeddings_0_2.min_id == 0
    assert embeddings_0_2.max_id == rows_rank0
    expected_weight_0 = (
        torch.cat([torch.arange(rows_rank0 * hidden_dim), zero_row], dim=0)
        .view(rows_rank0 + 1, hidden_dim)
        .float()
    )
    torch.testing.assert_close(embeddings_0_2.weight, expected_weight_0)

    embeddings_1_2 = TensorParallelEmbedding("", weights_1_2, reduce=False)
    assert embeddings_1_2.min_id == rows_rank0
    assert embeddings_1_2.max_id == vocab_size
    # Values of the rows after rank 0's shard, i.e. 9 * 256 .. 17 * 256 - 1.
    tail_values = torch.arange(rows_rank0 * hidden_dim, vocab_size * hidden_dim)
    expected_weight_1 = (
        torch.cat([tail_values, zero_row], dim=0)
        .view(rows_rank1 + 1, hidden_dim)
        .float()
    )
    torch.testing.assert_close(embeddings_1_2.weight, expected_weight_1)

    # Both ranks see every id; their outputs must sum to the unsharded one.
    output_tp_0 = embeddings_0_2.forward(input_ids)
    output_tp_1 = embeddings_1_2.forward(input_ids)

    torch.testing.assert_close(output, output_tp_0 + output_tp_1)