import torch
import torch.nn as nn
from typing import Optional, Tuple, Union, List
from torch.nn.common_types import _size_2_t
from torch import Tensor


@torch.library.custom_op("lightop::conv_bias_add", mutates_args=())
def fuse_conv_bias_add(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    add: torch.Tensor,
    padding: List[int],
    stride: List[int],
    dilation: List[int],
) -> torch.Tensor:
    # Fused MIOpen conv2d + bias + residual add.
    from lightop import miopen_conv_bias_add as conv_bias_add
    return conv_bias_add(input, weight, bias, add, padding, stride, dilation)


@fuse_conv_bias_add.register_fake
def conv_bias_add_fake(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    add: torch.Tensor,
    padding: List[int],
    stride: List[int],
    dilation: List[int],
):
    # The residual tensor already has the output shape, so mirror it.
    return torch.empty_like(add)


class ConvBiasAdd(torch.nn.Conv2d):
    """nn.Conv2d drop-in whose forward() runs the fused conv + bias + add kernel."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: _size_2_t,
                 stride: _size_2_t = 1, padding: Union[str, _size_2_t] = 0,
                 dilation: _size_2_t = 1, groups: int = 1, bias: bool = True,
                 padding_mode: str = "zeros", device=None, dtype=None):
        super().__init__(in_channels, out_channels, kernel_size, stride, padding,
                         dilation, groups, bias, padding_mode, device, dtype)

    def forward(self, input: torch.Tensor, add: Optional[torch.Tensor] = None) -> torch.Tensor:
        return fuse_conv_bias_add(input, self.weight, self.bias, add,
                                  self.padding, self.stride, self.dilation)


@torch.library.custom_op("lightop::conv_bias", mutates_args=())
def fuse_conv_bias(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    padding: List[int],
    stride: List[int],
    dilation: List[int],
) -> torch.Tensor:
    # Fused MIOpen conv2d + bias.
    from lightop import miopen_conv_bias as conv_bias
    return conv_bias(input, weight, bias, padding, stride, dilation)


@fuse_conv_bias.register_fake
def conv_bias_fake(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    padding: tuple,
    stride: tuple,
    dilation: tuple,
) -> torch.Tensor:
    """Meta (fake) kernel: compute only the output shape."""
    # Validate input rank.
    if input.dim() not in [4, 5]:
        raise ValueError(f"Input tensor must be 4D or 5D, got {input.dim()}D")

    # Normalize parameters to tuples.
    padding = tuple(padding) if isinstance(padding, list) else padding
    stride = tuple(stride) if isinstance(stride, list) else stride
    dilation = tuple(dilation) if isinstance(dilation, list) else dilation

    # Pick out the spatial extents and kernel sizes.
    if input.dim() == 4:  # 4D: [N, C, H, W]
        h_in, w_in = input.size(2), input.size(3)
        kH, kW = weight.size(2), weight.size(3)
    else:  # 5D: [N, C, D, H, W]
        h_in, w_in = input.size(3), input.size(4)
        kH, kW = weight.size(3), weight.size(4)

    # Unpack per-dimension parameters.
    padH, padW = padding if isinstance(padding, tuple) else (padding, padding)
    strideH, strideW = stride if isinstance(stride, tuple) else (stride, stride)
    dilationH, dilationW = dilation if isinstance(dilation, tuple) else (dilation, dilation)

    # Standard convolution output-size formula.
    h_out = (h_in + 2 * padH - dilationH * (kH - 1) - 1) // strideH + 1
    w_out = (w_in + 2 * padW - dilationW * (kW - 1) - 1) // strideW + 1

    # Assemble the full output shape (depth is left unchanged in the 5D case).
    if input.dim() == 4:
        output_shape = (input.size(0), weight.size(0), h_out, w_out)
    else:
        output_shape = (input.size(0), weight.size(0), input.size(2), h_out, w_out)

    # Allocate a meta tensor matching the input's properties. channels_last is only
    # defined for 4D tensors, so 5D inputs use channels_last_3d instead.
    memory_format = torch.channels_last if input.dim() == 4 else torch.channels_last_3d
    return torch.empty(
        output_shape,
        dtype=input.dtype,
        device=input.device,
        layout=input.layout,
        requires_grad=input.requires_grad,
        memory_format=memory_format,
    )
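# Hedged sanity check, not part of the lightop API (the helper name below is
# illustrative): it compares the shape produced by conv_bias_fake with the shape
# PyTorch itself infers for F.conv2d on the "meta" device, where no real memory is
# allocated. Handy when adjusting the fake kernel; it is never called at import time.
def _check_conv_bias_fake_shape() -> None:
    x = torch.empty(2, 16, 32, 32, device="meta")
    w = torch.empty(8, 16, 3, 3, device="meta")
    b = torch.empty(8, device="meta")
    # PyTorch's own shape inference for the same convolution.
    ref = torch.nn.functional.conv2d(x, w, b, stride=[2, 2], padding=[1, 1], dilation=[1, 1])
    out = conv_bias_fake(x, w, b, padding=[1, 1], stride=[2, 2], dilation=[1, 1])
    assert out.shape == ref.shape  # expected: (2, 8, 16, 16)
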
class ConvBias(torch.nn.Conv2d):
    """nn.Conv2d drop-in whose forward() runs the fused conv + bias kernel."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: _size_2_t,
                 stride: _size_2_t = 1, padding: Union[str, _size_2_t] = 0,
                 dilation: _size_2_t = 1, groups: int = 1, bias: bool = True,
                 padding_mode: str = "zeros", device=None, dtype=None):
        super().__init__(in_channels, out_channels, kernel_size, stride, padding,
                         dilation, groups, bias, padding_mode, device, dtype)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return fuse_conv_bias(input, self.weight, self.bias,
                              self.padding, self.stride, self.dilation)


@torch.library.custom_op("lightop::miopenGroupNorm", mutates_args=())
def fuse_miopenGroupNorm(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    num_groups: int,
    epsilon: float,
    mode: int,
) -> torch.Tensor:
    # ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:  # multi-output variant (y, mean, rstd)
    from lightop import miopen_groupnorm as groupnorm
    return groupnorm(x, weight, bias, num_groups, epsilon, mode)


@fuse_miopenGroupNorm.register_fake
def fuse_miopenGroupNorm_fake(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    num_groups: int,
    epsilon: float,
    mode: int,
) -> torch.Tensor:
    # ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:  # multi-output variant (y, mean, rstd)
    """Meta (fake) kernel: compute only the output shapes."""
    # The normalized output has the same shape as the input.
    out_y = torch.empty_like(x)

    # Per-(sample, group) statistics; kept for the multi-output variant above even
    # though only out_y is returned for now. channels_last is only defined for 4D
    # tensors, so 5D inputs use channels_last_3d instead.
    batch_size = x.size(0)
    mean_rstd_len = [batch_size * num_groups, 1, 1, 1]
    if x.dim() == 5:
        mean_rstd_len.append(1)
    memory_format = torch.channels_last if x.dim() == 4 else torch.channels_last_3d
    out_mean = torch.empty(
        mean_rstd_len, dtype=x.dtype, device=x.device, layout=x.layout,
        memory_format=memory_format,
    )
    out_rstd = torch.empty(
        mean_rstd_len, dtype=x.dtype, device=x.device, layout=x.layout,
        memory_format=memory_format,
    )
    return out_y  # , out_mean, out_rstd


class miopenGroupNorm(torch.nn.Module):
    """GroupNorm module backed by the fused MIOpen group-norm kernel.

    mode = 0  -> MIOPEN_ELEMENTWISE_AFFINE
    mode = 1  -> MIOPEN_WEIGHT_BIAS
    mode = 10 -> MIOPEN_WEIGHT_BIAS_FUSION_SILU
    mode = 11 -> MIOPEN_FUSION_SILU
    """

    def __init__(self, num_groups: int, num_channels: int, mode: int,
                 eps: float = 1e-5, device=None, dtype=None):
        super().__init__()
        self.eps = eps
        self.num_groups = num_groups
        self.num_channels = num_channels
        self.mode = mode
        factory_kwargs = {"device": device, "dtype": dtype}
        self.weight = torch.nn.Parameter(torch.empty(num_channels, **factory_kwargs))
        self.bias = torch.nn.Parameter(torch.empty(num_channels, **factory_kwargs))
        torch.nn.init.ones_(self.weight)
        torch.nn.init.zeros_(self.bias)

    def forward(self, x):
        return fuse_miopenGroupNorm(x, self.weight, self.bias, self.num_groups, self.eps, self.mode)

    def extra_repr(self):
        return (f"num_groups={self.num_groups}, num_channels={self.num_channels}, "
                f"eps={self.eps:.5f}, mode={self.mode}")


# Fused scaled-dot-product-attention custom op.
@torch.library.custom_op("lightop::miopen_scaled_dot_product_attention", mutates_args=())
def fuse_miopen_scaled_dot_product_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_msk_: Optional[torch.Tensor] = None,
    droprate: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
) -> torch.Tensor:
    from lightop import miopen_scaled_dot_product_attention
    return miopen_scaled_dot_product_attention(
        query, key, value, attn_msk_, droprate, is_causal, scale, enable_gqa
    )


@fuse_miopen_scaled_dot_product_attention.register_fake
def _(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_msk_: Optional[torch.Tensor] = None,
    droprate: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
) -> torch.Tensor:
    # Validate input layout before unpacking.
    assert query.dim() == 4, "Query must be 4D [B, H, S, D]"
    B, H, S, D = query.shape
    _, H_k, S_k, D_v = value.shape
    assert key.shape == (B, H_k, S_k, key.size(3)), "Key shape mismatch"
    assert value.shape == (B, H_k, S_k, D_v), "Value shape mismatch"
    # Attention output shape is [B, H, S, D_v].
    return torch.empty(
        (B, H, S, D_v),
        dtype=query.dtype,
        device=query.device,
    )
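
# Minimal usage sketch, not a test: it assumes the compiled `lightop` extension is
# importable and a ROCm/HIP device is visible (torch.cuda.is_available() also covers
# HIP builds). The tensor sizes, dtypes and the channels_last conversion below are
# illustrative assumptions, not requirements documented by lightop.
if __name__ == "__main__":
    if torch.cuda.is_available():
        device = "cuda"
        # Fused conv + bias followed by fused group norm (mode=1: weight/bias affine).
        x = torch.randn(2, 64, 32, 32, device=device).to(memory_format=torch.channels_last)
        conv = ConvBias(64, 128, kernel_size=3, padding=1, device=device)
        gn = miopenGroupNorm(num_groups=32, num_channels=128, mode=1, device=device)
        y = gn(conv(x))
        print(y.shape)  # expected: torch.Size([2, 128, 32, 32])

        # Fused scaled dot-product attention; q/k/v are [B, H, S, D].
        q = torch.randn(2, 8, 128, 64, device=device, dtype=torch.float16)
        k = torch.randn(2, 8, 128, 64, device=device, dtype=torch.float16)
        v = torch.randn(2, 8, 128, 64, device=device, dtype=torch.float16)
        o = fuse_miopen_scaled_dot_product_attention(q, k, v, None, 0.0, True, None, False)
        print(o.shape)  # expected: torch.Size([2, 8, 128, 64])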