vocal_parallel_embedding.py 4.69 KB
Newer Older
Jee Jee Li's avatar
Jee Jee Li committed
1
2
3
4
5
6
7
8
9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig

10
from vllm.config.lora import LoRAConfig
11
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
Jee Jee Li's avatar
Jee Jee Li committed
12
13
14
15
16
17
18
19
20
from vllm.platforms import current_platform

from .base import BaseLayerWithLoRA


class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
    """LoRA-enabled wrapper around a ``VocabParallelEmbedding`` layer.

    The wrapped layer is kept as ``base_layer``. ``create_lora_weights``
    allocates zero-initialised stacked LoRA A/B buffers with one slot per
    adapter, ``set_lora`` / ``reset_lora`` fill or clear a single slot, and
    ``forward`` combines the base embedding output with the LoRA
    contribution through ``self.punica_wrapper``.

    NOTE(review): ``punica_wrapper`` is not assigned anywhere in this class;
    presumably it is provided by ``BaseLayerWithLoRA`` or attached by the
    caller -- confirm.
    """

    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
        """Store the wrapped layer; LoRA buffers are created later.

        Args:
            base_layer: The embedding layer to wrap.
        """
        super().__init__()
        self.base_layer = base_layer
        # Declared here for type checkers only; both attributes are actually
        # assigned in create_lora_weights().
        self.embeddings_slice: tuple[int, int] | None
        self.embeddings_weights: torch.Tensor | None

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Allocate the stacked LoRA buffers for up to ``max_loras`` adapters.

        Args:
            max_loras: Number of adapter slots to allocate.
            lora_config: Supplies ``max_lora_rank`` and ``lora_dtype`` for
                the new buffers.
            model_config: Accepted as part of the signature; not read by
                this method.
        """
        if self.base_layer.num_added_embeddings_per_partition > 0:
            # We can start adding lora weights
            # Rows of the base weight that follow the original-vocabulary
            # rows on this partition, i.e. the "added embeddings" region.
            self.embeddings_weights = self.base_layer.weight.data[
                self.base_layer.num_org_embeddings_per_partition : self.base_layer.num_org_embeddings_per_partition  # noqa: E501
                + self.base_layer.num_added_embeddings_per_partition
            ]
            # The same region expressed as (start, end) offsets relative to
            # the original vocabulary size.
            self.embeddings_slice = (
                self.base_layer.shard_indices.added_vocab_start_index
                - self.base_layer.org_vocab_size,
                self.base_layer.shard_indices.added_vocab_end_index
                - self.base_layer.org_vocab_size,
            )
            # Zero every row after the original-vocabulary rows, in place.
            # ``embeddings_weights`` is a slice of the same storage, so it is
            # zeroed by this call as well.
            self.base_layer.weight.data[
                self.base_layer.num_org_embeddings_per_partition :
            ].fill_(0)
        else:
            self.embeddings_slice = None
            self.embeddings_weights = None

        # LoRA A table: (max_loras, org_vocab_size, max_lora_rank).
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                self.base_layer.org_vocab_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        # LoRA B: (max_loras, 1, embedding_dim, max_lora_rank).
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.base_layer.embedding_dim,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        # 2-D view over the same storage, (max_loras * org_vocab_size, rank),
        # so that forward() can index every slot with one F.embedding call.
        self.lora_a_stacked_2d = self.lora_a_stacked.view(
            self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
            self.lora_a_stacked.shape[2],
        )

    def reset_lora(self, index: int):
        """Zero the LoRA A and B buffers of adapter slot ``index``."""
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
    ):
        """Load one adapter's weights into slot ``index``.

        The slot is cleared first, so entries outside the copied region stay
        zero when the adapter is smaller than the allocated buffers.

        Args:
            index: Adapter slot to overwrite.
            lora_a: 2-D LoRA A weight; it is transposed before being copied
                into the leading ``(lora_a.shape[1], lora_a.shape[0])``
                corner of the slot.
            lora_b: 2-D LoRA B weight; copied untransposed into the leading
                ``(lora_b.shape[0], lora_b.shape[1])`` corner of the slot.
        """
        self.reset_lora(index)
        # NOTE self.lora_a_stacked is row-major, and lora_a is col-major,
        # so we need transpose here
        self.lora_a_stacked[index, : lora_a.shape[1], : lora_a.shape[0]].copy_(
            lora_a.T, non_blocking=True
        )
        self.lora_b_stacked[index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_(
            lora_b, non_blocking=True
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` with the base layer and apply the LoRA contribution.

        Args:
            x: Index tensor passed to the base layer; its first dimension is
                used as the token count.

        Returns:
            A tensor shaped like the base layer's output.
        """
        # NB: Don't use torch.narrow here. torch.narrow triggers some
        # Dynamic Shape specialization in torch.compile
        num_tokens = x.shape[0]
        # NOTE(review): presumably per-token offsets that move each index
        # into its adapter's section of the flattened A table -- confirm
        # against the punica wrapper.
        indices_1 = self.punica_wrapper._embeddings_indices[1][:num_tokens]

        full_lora_a_embeddings = F.embedding(
            x + indices_1,
            self.lora_a_stacked_2d,
        )
        full_output = self.base_layer.forward(x)

        # Remember the original output so its shape can be restored on return.
        full_output_org = full_output
        # Collapse 3-D tensors to 2-D by merging the first two dimensions
        # before handing them to the punica wrapper.
        if full_output.ndim == 3:
            full_output = full_output.view(
                full_output.shape[0] * full_output.shape[1], -1
            )
        if full_lora_a_embeddings.ndim == 3:
            full_lora_a_embeddings = full_lora_a_embeddings.view(
                full_lora_a_embeddings.shape[0] * full_lora_a_embeddings.shape[1],
                -1,
            )

        lora_output: torch.Tensor | None = self.punica_wrapper.add_lora_embedding(
            full_output, full_lora_a_embeddings, self.lora_b_stacked, add_input=True
        )

        # On platforms that cannot update ``full_output`` in place, the
        # result is taken from the wrapper's return value instead.
        if not current_platform.can_update_inplace():
            full_output = lora_output

        return full_output.view_as(full_output_org)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None,
    ) -> bool:
        """Return True only when ``source_layer`` is exactly a
        ``VocabParallelEmbedding``.

        The check compares the concrete type, so subclasses of
        ``VocabParallelEmbedding`` are not matched. The remaining arguments
        are not consulted.
        """
        return type(source_layer) is VocabParallelEmbedding

    @property
    def weight(self):
        """The wrapped base layer's ``weight``."""
        return self.base_layer.weight