rmsnorm.h 3.03 KB
Newer Older
Xiaowei.zhang's avatar
Xiaowei.zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#pragma once
/*
 
 * Copyright (c) 2024, The vLLM team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <torch/extension.h>

// RMS normalization entry point (declaration only; the definition is not in this header).
// Takes an output tensor, an input tensor, a weight tensor and an epsilon value.
// NOTE(review): presumably writes the normalized, weight-scaled `input` into `out`
// and uses `epsilon` for numerical stability — confirm against the definition.
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, double epsilon);

// Fused residual-add + RMS normalization (declaration only; defined elsewhere).
// There is no separate output parameter and the function returns void, so results
// must be delivered through the non-const tensor references.
// NOTE(review): presumably `input` and `residual` are both updated in place —
// confirm which tensor receives which result in the definition.
void fused_add_rms_norm(torch::Tensor& input,
                        torch::Tensor& residual,
                        torch::Tensor& weight,
                        double epsilon);

// ck
// 2D RMS normalization that returns its result as a new torch::Tensor
// (declaration only; defined elsewhere).
// NOTE(review): the previous trailing comment ("0: Use default RMSNorm; 1: Use
// T5-like implementation") described a 0/1 selector, but this declaration has no
// such parameter — the last argument is the double `epsilon`. Confirm whether a
// selector argument was removed or is still expected by callers.
torch::Tensor
rmsnorm2d(torch::Tensor& input,
          torch::Tensor& weight,
          double epsilon);

// 2D RMS normalization variant that also takes a residual input and a residual
// output tensor (declaration only; defined elsewhere).
// NOTE(review): presumably residual_out = input + residual_in and `out` is the
// normalized residual_out — confirm against the definition.
// NOTE(review): the previous trailing comment ("0: Use default RMSNorm; 1: Use
// T5-like implementation") described a 0/1 selector that is not a parameter of
// this declaration; the last argument is the double `epsilon`.
void rmsnorm2d_with_add(
    torch::Tensor& out,          // [m ,n]
    torch::Tensor& input,        // [m ,n]
    torch::Tensor& residual_in,  // [m ,n]
    torch::Tensor& residual_out, // [m ,n]
    torch::Tensor& weight,       // [1 ,n]
    double epsilon);

// 2D RMS normalization variant that additionally takes a per-column `xscale`
// and a per-row `yscale` tensor (declaration only; defined elsewhere).
// NOTE(review): by its name this is the smooth-quant path; presumably `xscale`
// is an input scale and `yscale` receives the computed per-row quantization
// scale — confirm the direction of each tensor in the definition.
// NOTE(review): the previous trailing comment ("0: Use default RMSNorm; 1: Use
// T5-like implementation") described a 0/1 selector that is not a parameter of
// this declaration; the last argument is the double `epsilon`.
void rmsnorm2d_with_smoothquant(
    torch::Tensor& out,    // [m ,n]
    torch::Tensor& input,  // [m ,n]
    torch::Tensor& xscale, // [1 ,n]
    torch::Tensor& yscale, // [m ,1]
    torch::Tensor& weight, // [1 ,n]
    double epsilon);

// 2D RMS normalization variant combining the residual tensors of the "add"
// form with the `xscale`/`yscale` tensors of the smooth-quant form
// (declaration only; defined elsewhere).
// `out_before_quant` is an optional tensor passed by value.
// NOTE(review): presumably, when provided, `out_before_quant` receives the
// normalized values prior to quantization — confirm against the definition.
// NOTE(review): the previous trailing comment ("0: Use default RMSNorm; 1: Use
// T5-like implementation") was attached to `out_before_quant` but described a
// 0/1 selector; no such parameter exists in this declaration.
void rmsnorm2d_with_add_smoothquant(
    torch::Tensor& out,          // [m ,n]
    torch::Tensor& input,        // [m ,n]
    torch::Tensor& residual_in,  // [m ,n]
    torch::Tensor& residual_out, // [m ,n]
    torch::Tensor& xscale,       // [1 ,n]
    torch::Tensor& yscale,       // [m ,1]
    torch::Tensor& weight,       // [1 ,n]
    double epsilon,
    std::optional<torch::Tensor> out_before_quant);

// 2D RMS normalization variant that takes a per-row `yscale` tensor but no
// `xscale` (declaration only; defined elsewhere).
// NOTE(review): by its name this is the dynamic-quant path; presumably
// `yscale` receives the per-row scale computed at run time — confirm.
// NOTE(review): the previous trailing comment ("0: Use default RMSNorm; 1: Use
// T5-like implementation") described a 0/1 selector that is not a parameter of
// this declaration; the last argument is the double `epsilon`.
void rmsnorm2d_with_dynamicquant(
    torch::Tensor& out,    // [m ,n]
    torch::Tensor& input,  // [m ,n]
    torch::Tensor& yscale, // [m ,1]
    torch::Tensor& weight, // [1 ,n]
    double epsilon);

// 2D RMS normalization variant combining the residual tensors of the "add"
// form with the per-row `yscale` tensor of the dynamic-quant form
// (declaration only; defined elsewhere).
// NOTE(review): presumably residual_out = input + residual_in, and `yscale`
// receives the per-row scale computed at run time — confirm.
// NOTE(review): the previous trailing comment ("0: Use default RMSNorm; 1: Use
// T5-like implementation") described a 0/1 selector that is not a parameter of
// this declaration; the last argument is the double `epsilon`.
void rmsnorm2d_with_add_dynamicquant(
    torch::Tensor& out,          // [m ,n]
    torch::Tensor& input,        // [m ,n]
    torch::Tensor& residual_in,  // [m ,n]
    torch::Tensor& residual_out, // [m ,n]
    torch::Tensor& yscale,       // [m ,1]
    torch::Tensor& weight,       // [1 ,n]
    double epsilon);
75
76
77
78
79
80
81
    
// Head RMS norm: normalizes each head independently and returns the result as a
// torch::Tensor (declaration only; defined elsewhere).
// `norm_head_dim` is the per-head width over which normalization is applied.
// NOTE(review): presumably input.size(-1) must be a multiple of `norm_head_dim`
// — confirm what validation the definition performs.
torch::Tensor head_rms_norm(
    torch::Tensor& input,        // [num_tokens, num_heads * head_dim]
    torch::Tensor& weight,       // [num_heads * head_dim]
    double epsilon,
    int64_t norm_head_dim);      // head_dim