activation.py 15.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom activation functions."""
4
import math
5
from typing import Optional
zhuwenwen's avatar
zhuwenwen committed
6
import optimus
7

Woosuk Kwon's avatar
Woosuk Kwon committed
8
9
import torch
import torch.nn as nn
10
import torch.nn.functional as F
Woosuk Kwon's avatar
Woosuk Kwon committed
11

12
13
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
14
from vllm.model_executor.custom_op import CustomOp
15
from vllm.model_executor.utils import set_weight_attrs
16
from vllm.platforms import current_platform
17
from vllm.utils import LazyDict
zhuwenwen's avatar
zhuwenwen committed
18
import vllm.envs as envs
Woosuk Kwon's avatar
Woosuk Kwon committed
19
20


21
@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
    """Gated FATReLU activation: FATReLU(x[:d]) * x[d:], d = x.shape[-1] // 2.

    FATReLU keeps a gate value only when it is strictly above ``threshold``
    and replaces it with 0.0 otherwise (``F.threshold`` semantics).
    This is used in openbmb/MiniCPM-S-1B-sft.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, threshold: float = 0.):
        super().__init__()
        self.threshold = threshold
        if current_platform.is_cuda_alike():
            # Fused kernel used by forward_cuda.
            self.op = torch.ops._C.fatrelu_and_mul
        elif current_platform.is_cpu():
            # No CPU kernel is wired up here, so pin the PyTorch path.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        half = x.shape[-1] // 2
        gate, up = x[..., :half], x[..., half:]
        return F.threshold(gate, self.threshold, 0.0) * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        half = x.shape[-1] // 2
        # Same dtype/device as x; the kernel writes every element.
        out = x.new_empty((*x.shape[:-1], half))
        self.op(out, x, self.threshold)
        return out
55
56


zhuwenwen's avatar
zhuwenwen committed
57
58
59
60
61
62
63
64
class OptimusSiluAndMul(nn.Module):
    """SiLU-and-multiply backed by the Optimus ``SiluDot_forward`` kernel.

    NOTE(review): the math is defined by the external kernel. By its name it
    presumably matches SiluAndMul (silu(x[:d]) * x[d:]) — confirm. The op is
    resolved from ``torch.ops.Optimus`` at call time. It is presumably
    registered by importing the ``optimus`` package — verify.
    """

    def forward(self,
                x: torch.Tensor,
                output: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run ``SiluDot_forward`` on ``x``.

        Args:
            x: input tensor handed to the kernel.
            output: optional preallocated result buffer, forwarded as ``out=``.

        Returns:
            The tensor returned by the kernel.
        """
        kernel = torch.ops.Optimus.SiluDot_forward
        return kernel(x, out=output)
    
    
65
@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
    """An activation function for SwiGLU.

    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.silu_and_mul
            # Alternative kernel selected when envs.VLLM_USE_OPT_OP is set.
            self.op_opt = torch.ops._C.silu_and_mul_opt
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.silu_and_mul
        # Other platforms register no kernel: self.op is left unset.

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        # With VLLM_ENABLE_TBO (and outside torch.compile tracing), the kernel
        # path is preferred. Only take it when __init__ actually registered a
        # kernel for this platform; otherwise forward_cuda would raise
        # AttributeError on `self.op`.
        if (not torch.compiler.is_compiling() and envs.VLLM_ENABLE_TBO
                and hasattr(self, "op")):
            return self.forward_cuda(x)
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Run the registered kernel into a freshly allocated (..., d) buffer."""
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        # `op_opt` is only registered on CUDA-alike/CPU platforms. Fall back to
        # the default kernel elsewhere (e.g. XPU reached via the TBO path).
        if envs.VLLM_USE_OPT_OP and hasattr(self, "op_opt"):
            self.op_opt(out, x)
        else:
            self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    def forward_neuron(self, x: torch.Tensor) -> torch.Tensor:
        # Flatten leading dims to 2-D, compute silu(a) * b elementwise, then
        # restore the original leading dims.
        d = x.shape[-1] // 2
        x_reshaped = x.view(-1, x.shape[-1])
        s = x_reshaped[:, :d] * F.sigmoid(x_reshaped[:, :d])
        result = s * x_reshaped[:, d:]
        return result.view(*x.shape[:-1], d)

117

118
119
120
@CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp):
    """An activation function for SwiGLU.

    The function computes x -> x[:d] * silu(x[d:]) where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike():
            self.op = torch.ops._C.mul_and_silu
        elif current_platform.is_cpu() or current_platform.is_xpu():
            # No kernel with this operand order is available on these
            # platforms, so always use the PyTorch path. Note that
            # ipex_ops.silu_and_mul (used by SiluAndMul) applies SiLU to the
            # *first* half. That is the mirror image of this op, so it must
            # not be used here.
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return x[..., :d] * F.silu(x[..., d:])

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    # TODO: add a matching XPU kernel and forward_xpu for MulAndSilu.

154

Robert Shaw's avatar
Robert Shaw committed
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
@CustomOp.register("gelu_and_mul_sparse")
class GeluAndMulSparse(CustomOp):
    """An activation function for GeluAndMulSparse.
    This activation function is used in Gemma3n. It computes:
        up_proj = self.up_proj(x)
        gate_proj = self.gate_proj(x)
        gate_proj = self._gaussian_topk(gate_proj) # sparsity
        activations = self.act_fn(gate_proj) # gelu
        down_proj = self.down_proj(activations * up_proj)
    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)

    Args:
        activation_sparsity: target fraction of gate values to zero out.
            Must lie in the open interval (0, 1).
        approximate: GELU variant, "none" or "tanh" (as in ``F.gelu``).

    Raises:
        ValueError: for an unknown ``approximate`` mode, or an
            ``activation_sparsity`` outside (0, 1).
    """

    def __init__(self, activation_sparsity: float, approximate: str = "none"):
        super().__init__()
        # Gelu.
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")

        # Sparsity.
        if activation_sparsity == 0.0:
            raise ValueError(
                "activation_sparsity is 0.0. Please use GeluAndMul.")
        # The cutoff multiplier is the standard-normal quantile at
        # activation_sparsity. Outside (0, 1) that quantile is +/-inf or NaN,
        # which silently turns the activation into all zeros or NaNs.
        # The check is written so that NaN is rejected as well.
        if not 0.0 < activation_sparsity < 1.0:
            raise ValueError(
                "activation_sparsity must be in the open interval (0, 1), "
                f"got {activation_sparsity}.")
        target_sparsity_tensor = torch.tensor(activation_sparsity,
                                              dtype=torch.float32)
        normal_dist = torch.distributions.normal.Normal(0, 1)
        self.std_multiplier = normal_dist.icdf(target_sparsity_tensor)

    def _gaussian_topk(self, x: torch.Tensor) -> torch.Tensor:
        """Keep only values above a per-row Gaussian cutoff, shifted down.

        The cutoff is mean + std * std_multiplier over the last dim (biased
        std). Values at or below it become 0; the rest are reduced by it.
        """
        # NOTE(rob): for TP>1, we could all-gather to get the means/std.
        # But we do not do this because in expectation they are the same
        # and in practice the eval scores are good without gathering.
        mean = torch.mean(x, dim=-1, keepdim=True)
        std = torch.std(x, dim=-1, keepdim=True, unbiased=False)
        cutoff_x = mean + std * self.std_multiplier
        return F.relu(x - cutoff_x)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        out = self._gaussian_topk(x[..., :d])
        out = F.gelu(out, approximate=self.approximate)
        return out * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # No fused kernel; reuse the native implementation.
        return self.forward_native(x)


206
@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
    """GeGLU activation: GELU(x[:d]) * x[d:] with d = x.shape[-1] // 2.

    ``approximate`` selects exact ("none") or tanh-approximated ("tanh")
    GELU, with the same meaning as in ``F.gelu``.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def __init__(self, approximate: str = "none"):
        super().__init__()
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")

        # After validation, "none" vs "tanh" is a two-way choice.
        use_tanh = approximate == "tanh"
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            if use_tanh:
                self.op = torch.ops._C.gelu_tanh_and_mul
                self.op_opt = torch.ops._C.gelu_tanh_and_mul_opt
            else:
                self.op = torch.ops._C.gelu_and_mul
                self.op_opt = torch.ops._C.gelu_and_mul_opt
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = (ipex_ops.gelu_tanh_and_mul
                       if use_tanh else ipex_ops.gelu_and_mul)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        half = x.shape[-1] // 2
        gate = F.gelu(x[..., :half], approximate=self.approximate)
        return gate * x[..., half:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        half = x.shape[-1] // 2
        out = x.new_empty((*x.shape[:-1], half))
        # VLLM_USE_OPT_OP switches to the alternative "_opt" kernel.
        kernel = self.op_opt if envs.VLLM_USE_OPT_OP else self.op
        kernel(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        half = x.shape[-1] // 2
        out = x.new_empty((*x.shape[:-1], half))
        self.op(out, x)
        return out

    def extra_repr(self) -> str:
        return f"approximate={self.approximate!r}"

261

262
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
    """Tanh-approximated GELU ("gelu_new"), applied elementwise.

    Computes 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))).
    The output has the same shape as ``x``.
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_new
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_new

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        sqrt_2_over_pi = math.sqrt(2.0 / math.pi)
        inner = sqrt_2_over_pi * (x + 0.044715 * torch.pow(x, 3.0))
        return 0.5 * x * (1.0 + torch.tanh(inner))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        result = torch.empty_like(x)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        # The XPU op is called out-of-place and its result returned directly.
        return self.op(x)
286

287

288
@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):
    """Fast tanh-based GELU approximation, applied elementwise.

    Computes 0.5 * x * (1 + tanh(0.7978845608 * x * (1 + 0.044715 * x^2))),
    where 0.7978845608 is sqrt(2 / pi). The output has the shape of ``x``.
    """

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_fast
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_fast

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        inner = x * 0.7978845608 * (1.0 + 0.044715 * x * x)
        return 0.5 * x * (1.0 + torch.tanh(inner))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        result = torch.empty_like(x)
        self.op(result, x)
        return result

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        # The XPU op is called out-of-place and its result returned directly.
        return self.op(x)
311

312

313
@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
    """QuickGELU activation: x * sigmoid(1.702 * x), applied elementwise.

    The output has the same shape as ``x``.
    """
    # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_cpu():
            self.op = torch.ops._C.gelu_quick
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops
            self.op = ipex_ops.gelu_quick

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return x * torch.sigmoid(1.702 * x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        # Unlike NewGELU/FastGELU, the XPU op here writes into `out`.
        out = torch.empty_like(x)
        self.op(out, x)
        return out

341

342
@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
    """Squared ReLU, relu(x) ** 2, applied elementwise.

    Introduced in https://arxiv.org/abs/2109.08668v2
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        rectified = F.relu(x)
        return rectified.square()

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # No dedicated kernel: CUDA uses the native expression as well.
        return self.forward_native(x)


356
357
358
359
360
361
362
363
364
class ScaledActivation(nn.Module):
    """An activation function followed by division by per-channel scales.

    ``forward`` returns ``act(x) / scales``. This is used for some
    quantization methods like AWQ. When ``input_is_parallel`` is true,
    ``scales`` holds only this tensor-parallel rank's slice of
    ``intermediate_size``.
    """

    def __init__(
        self,
        act_module: nn.Module,
        intermediate_size: int,
        input_is_parallel: bool = True,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.act = act_module
        self.input_is_parallel = input_is_parallel
        if input_is_parallel:
            local_size = divide(intermediate_size,
                                get_tensor_model_parallel_world_size())
        else:
            local_size = intermediate_size
        dtype = (torch.get_default_dtype()
                 if params_dtype is None else params_dtype)
        # Allocated uninitialized; expected to be filled via weight_loader,
        # which is attached to the parameter right below.
        self.scales = nn.Parameter(torch.empty(local_size, dtype=dtype))
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        activated = self.act(x)
        return activated / self.scales

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` (this rank's slice if sharded) into ``param``.

        Raises:
            AssertionError: if the (sliced) weight's shape differs from
                ``param``'s.
        """
        target = param.data
        if self.input_is_parallel:
            # Each rank takes a contiguous block of dim 0 sized like param.
            width = target.shape[0]
            offset = get_tensor_model_parallel_rank() * width
            loaded_weight = loaded_weight.narrow(0, offset, width)
        assert target.shape == loaded_weight.shape
        target.copy_(loaded_weight)

397

398
399
400
401
402
403
404
405
406
407
408
409
410
# Plain (non-gated) activations keyed by lower-case name. Values are
# zero-argument factories wrapped in LazyDict (which, per its name,
# presumably builds each module on first access — confirm in vllm.utils).
_ACTIVATION_REGISTRY = LazyDict({
    "gelu": lambda: nn.GELU(),
    "gelu_fast": lambda: FastGELU(),
    "gelu_new": lambda: NewGELU(),
    "gelu_pytorch_tanh": lambda: nn.GELU(approximate="tanh"),
    "relu": lambda: nn.ReLU(),
    "relu2": lambda: ReLUSquaredActivation(),
    "silu": lambda: nn.SiLU(),
    "quick_gelu": lambda: QuickGELU(),
})
416
417


418
def get_act_fn(act_fn_name: str) -> nn.Module:
    """Get an activation function by name.

    The lookup is case-insensitive.

    Raises:
        ValueError: if no activation is registered under the name.
    """
    key = act_fn_name.lower()
    if key in _ACTIVATION_REGISTRY:
        return _ACTIVATION_REGISTRY[key]
    raise ValueError(f"Activation function {key!r} is not supported.")
426
427
428
429
430


# Gated "activation-and-mul" modules keyed by lower-case name; "geglu" also
# maps to GeluAndMul. Values are zero-argument factories wrapped in LazyDict.
_ACTIVATION_AND_MUL_REGISTRY = LazyDict({
    "gelu": lambda: GeluAndMul(),
    "silu": lambda: SiluAndMul(),
    "geglu": lambda: GeluAndMul(),
})


435
def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
    """Get an activation-and-mul (i.e. SiluAndMul) function by name.

    The lookup is case-insensitive.

    Raises:
        ValueError: if no gated activation is registered under the name.
    """
    key = act_fn_name.lower()
    if key in _ACTIVATION_AND_MUL_REGISTRY:
        return _ACTIVATION_AND_MUL_REGISTRY[key]
    raise ValueError(f"Activation function {key!r} is not supported.")