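"""Custom torch.library ops wrapping lightop's fused MIOpen kernels.

Each op imports `lightop` at call time, and a matching `register_fake`
implementation supplies output shapes/dtypes so the ops can be traced
(e.g. under torch.compile) without running the real kernels.
"""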
import torch
import torch.nn as nn
from typing import Optional, Tuple, Union, List
from torch.nn.common_types import _size_2_t
from torch import Tensor

@torch.library.custom_op("lightop::conv_bias_add", mutates_args=())
def fuse_conv_bias_add(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    add: torch.Tensor,
    padding: List[int],
    stride: List[int],
    dilation: List[int],
) -> torch.Tensor:
    from lightop import miopen_conv_bias_add as conv_bias_add
    return conv_bias_add(input, weight, bias, add, padding, stride, dilation)


@fuse_conv_bias_add.register_fake
def conv_bias_add_fake(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    add: torch.Tensor,
    padding: List[int],
    stride: List[int],
    dilation: List[int],
) -> torch.Tensor:
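    # The fused kernel's output has the same shape and dtype as the residual `add` tensor.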
    return torch.empty_like(add)


class ConvBiasAdd(torch.nn.Conv2d):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: _size_2_t,
                 stride: _size_2_t = 1,
                 padding: Union[str, _size_2_t] = 0,
                 dilation: _size_2_t = 1,
                 groups: int = 1,
                 bias: bool = True,
                 padding_mode: str = "zeros",
                 device=None, dtype=None):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            device,
            dtype
        )

    def forward(self, input: torch.Tensor, add: Optional[torch.Tensor] = None) -> torch.Tensor:
        return fuse_conv_bias_add(
            input,
            self.weight,
            self.bias,
            add,
            self.padding, self.stride, self.dilation,
        )
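
# Usage sketch (illustrative shapes; assumes a build where the lightop MIOpen
# kernels are available):
#
#   conv = ConvBiasAdd(64, 64, kernel_size=3, padding=1, device="cuda")
#   x = torch.randn(2, 64, 32, 32, device="cuda")
#   residual = torch.randn(2, 64, 32, 32, device="cuda")
#   y = conv(x, residual)  # fused conv + bias + residual add, same shape as residual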



@torch.library.custom_op("lightop::conv_bias", mutates_args=())
def fuse_conv_bias(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    padding: List[int],
    stride: List[int],
    dilation: List[int],
) -> torch.Tensor:
    from lightop import miopen_conv_bias as conv_bias
    return conv_bias(input, weight, bias, padding, stride, dilation)

@fuse_conv_bias.register_fake
def conv_bias_fake(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    padding: List[int],
    stride: List[int],
    dilation: List[int],
) -> torch.Tensor:
    """Meta function that computes the fused conv+bias output shape."""
    # Ensure the input has the expected dimensionality.
    if input.dim() not in [4, 5]:
        raise ValueError(f"Input tensor must be 4D or 5D, got {input.dim()}D")

    # Normalize parameters to tuples.
    padding = tuple(padding) if isinstance(padding, list) else padding
    stride = tuple(stride) if isinstance(stride, list) else stride
    dilation = tuple(dilation) if isinstance(dilation, list) else dilation

    # Extract input spatial sizes and kernel sizes.
    if input.dim() == 4:  # 4D: [N, C, H, W]
        h_in = input.size(2)
        w_in = input.size(3)
        kH = weight.size(2)
        kW = weight.size(3)
    else:  # 5D: [N, C, D, H, W]
        h_in = input.size(3)
        w_in = input.size(4)
        kH = weight.size(3)
        kW = weight.size(4)

    # Expand scalar parameters to (H, W) pairs.
    padH, padW = padding if isinstance(padding, tuple) else (padding, padding)
    strideH, strideW = stride if isinstance(stride, tuple) else (stride, stride)
    dilationH, dilationW = dilation if isinstance(dilation, tuple) else (dilation, dilation)

    # Output spatial size (standard convolution formula).
    h_out = (h_in + 2 * padH - dilationH * (kH - 1) - 1) // strideH + 1
    w_out = (w_in + 2 * padW - dilationW * (kW - 1) - 1) // strideW + 1
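    # e.g. h_in=64, padH=1, dilationH=1, kH=3, strideH=1:
    #   h_out = (64 + 2*1 - 1*(3-1) - 1) // 1 + 1 = 64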

    # Assemble the full output shape.
    if input.dim() == 4:
        output_shape = (input.size(0), weight.size(0), h_out, w_out)
    else:
        output_shape = (input.size(0), weight.size(0), input.size(2), h_out, w_out)

    # Meta tensor mirroring the input's attributes. channels_last is only
    # valid for 4D tensors; use channels_last_3d for 5D input.
    memory_format = torch.channels_last if input.dim() == 4 else torch.channels_last_3d
    return torch.empty(
        output_shape,
        dtype=input.dtype,
        device=input.device,
        layout=input.layout,
        requires_grad=input.requires_grad,
        memory_format=memory_format
    )

class ConvBias(torch.nn.Conv2d):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: _size_2_t,
                 stride: _size_2_t = 1,
                 padding: Union[str, _size_2_t] = 0,
                 dilation: _size_2_t = 1,
                 groups: int = 1,
                 bias: bool = True,
                 padding_mode: str = "zeros",
                 device=None, dtype=None):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            device,
            dtype
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return fuse_conv_bias(input,
                              self.weight,
                              self.bias,
                              self.padding,
                              self.stride,
                              self.dilation)
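
# Usage sketch (illustrative; assumes the lightop MIOpen kernels are available):
#
#   conv = ConvBias(3, 16, kernel_size=3, stride=2, padding=1, device="cuda")
#   x = torch.randn(1, 3, 224, 224, device="cuda")
#   y = conv(x)  # fused conv + bias, output shape (1, 16, 112, 112)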


@torch.library.custom_op("lightop::miopenGroupNorm", mutates_args=())
def fuse_miopenGroupNorm(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    num_groups: int,
    epsilon: float,
    mode: int,
    ) -> torch.Tensor:
    #)-> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    from lightop import miopen_groupnorm as groupnorm
    return groupnorm(x, weight, bias, num_groups, epsilon, mode)

@fuse_miopenGroupNorm.register_fake
def fuse_miopenGroupNorm_fake(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    num_groups: int,
    epsilon: float,
    mode: int
) -> torch.Tensor:
#) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """计算输出形状的元函数"""
    # 输出形状与输入相同
    output_shape = x.shape
    batch_size = x.size(0)
    mean_rstd_len = [batch_size * num_groups, 1, 1, 1]

    if x.dim() == 5:
        mean_rstd_len.append(1)

    # 创建输出张量
    out_y = torch.empty_like(x)


    memory_format = torch.channels_last

    out_mean = torch.empty(
            mean_rstd_len,
            dtype=x.dtype,
            device=x.device,
            layout=x.layout,
            memory_format=memory_format
        )

    out_rstd = torch.empty(
        mean_rstd_len,
        dtype=x.dtype,
        device=x.device,
        layout=x.layout,
        memory_format=memory_format
    )

    return out_y #,out_mean,out_rstd

class miopenGroupNorm(torch.nn.Module):
    # mode = 0  -> MIOPEN_ELEMENTWISE_AFFINE
    # mode = 1  -> MIOPEN_WEIGHT_BIAS
    # mode = 10 -> MIOPEN_WEIGHT_BIAS_FUSION_SILU
    # mode = 11 -> MIOPEN_FUSION_SILU
    def __init__(self, num_groups: int, num_channels: int, mode: int, eps: float = 1e-5, device=None, dtype=None):
        super().__init__()
        self.eps = eps
        self.num_groups = num_groups
        self.num_channels = num_channels
        self.mode = mode
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.weight = torch.nn.Parameter(torch.empty(num_channels, **factory_kwargs))
        self.bias = torch.nn.Parameter(torch.empty(num_channels, **factory_kwargs))
        torch.nn.init.ones_(self.weight)
        torch.nn.init.zeros_(self.bias)

    def forward(self, x):
        return fuse_miopenGroupNorm(x, self.weight, self.bias, self.num_groups, self.eps, self.mode)

    def extra_repr(self):
        return f'num_groups={self.num_groups}, num_channels={self.num_channels}, eps={self.eps:.5f}, mode={self.mode}'
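
# Usage sketch (mode values follow the MIOPEN_* table above; assumes lightop is available):
#
#   gn = miopenGroupNorm(num_groups=32, num_channels=320, mode=10, device="cuda")
#   x = torch.randn(2, 320, 64, 64, device="cuda")
#   y = gn(x)  # affine group norm with fused SiLU when mode=10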




# Custom op wrapping the MIOpen fused scaled dot-product attention kernel.
@torch.library.custom_op("lightop::miopen_scaled_dot_product_attention", mutates_args=())
def fuse_miopen_scaled_dot_product_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_msk_: Optional[torch.Tensor] = None,
    droprate: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
) -> torch.Tensor:
    from lightop import miopen_scaled_dot_product_attention
    return miopen_scaled_dot_product_attention(query, key, value, attn_msk_, droprate, is_causal, scale, enable_gqa)

@fuse_miopen_scaled_dot_product_attention.register_fake
def _(query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_msk_: Optional[torch.Tensor] = None,
    droprate: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
) -> torch.Tensor:
    # Validate input dimensions before unpacking shapes.
    assert query.dim() == 4, "Query must be 4D [B, H, S, D]"
    B, H, S, D = query.shape
    _, H_k, S_k, D_v = value.shape

    assert key.shape == (B, H_k, S_k, key.size(3)), "Key shape mismatch"
    assert value.shape == (B, H_k, S_k, D_v), "Value shape mismatch"

    return torch.empty(
        (B, H, S, D_v),
        dtype=query.dtype,
        device=query.device,
    )
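
# Usage sketch (illustrative; assumes the lightop MIOpen attention kernel is available):
#
#   q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.half)
#   k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.half)
#   v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.half)
#   out = fuse_miopen_scaled_dot_product_attention(q, k, v, is_causal=True)
#   # out.shape == (2, 8, 128, 64)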