Commit 63cefae8 authored by liumg

first modification for dcu

parent 9380e588
import torch
import torch.nn as nn
from typing import Optional, Tuple, Union, List
from torch.nn.common_types import _size_2_t
from torch import Tensor


@torch.library.custom_op("lightop::conv_bias_add", mutates_args=())
def fuse_conv_bias_add(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    add: torch.Tensor,
    padding: List[int],
    stride: List[int],
    dilation: List[int],
) -> torch.Tensor:
    # Fused conv2d + bias + elementwise add, backed by the DCU MIOpen kernel.
    from lightop import miopen_conv_bias_add as conv_bias_add
    return conv_bias_add(input, weight, bias, add, padding, stride, dilation)


@fuse_conv_bias_add.register_fake
def conv_bias_add_fake(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    add: torch.Tensor,
    padding: List[int],
    stride: List[int],
    dilation: List[int],
) -> torch.Tensor:
    # The fused result has the same shape and layout as the `add` operand.
    return torch.empty_like(add)
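

# Illustrative sanity check (an editor's sketch, not part of the lightop API):
# only stock PyTorch is needed, because FakeTensorMode dispatches to the fake
# kernel registered above and never reaches the MIOpen implementation.
def _check_conv_bias_add_fake():
    from torch._subclasses.fake_tensor import FakeTensorMode
    with FakeTensorMode():
        x = torch.empty(1, 64, 32, 32)
        w = torch.empty(64, 64, 3, 3)
        b = torch.empty(64)
        add = torch.empty(1, 64, 32, 32)
        out = fuse_conv_bias_add(x, w, b, add, [1, 1], [1, 1], [1, 1])
        assert out.shape == add.shape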


class ConvBiasAdd(torch.nn.Conv2d):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: _size_2_t,
                 stride: _size_2_t = 1,
                 padding: Union[str, _size_2_t] = 0,
                 dilation: _size_2_t = 1,
                 groups: int = 1,
                 bias: bool = True,
                 padding_mode: str = "zeros",
                 device=None, dtype=None):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            device,
            dtype,
        )

    def forward(self, input: torch.Tensor, add: torch.Tensor) -> torch.Tensor:
        # `add` is required: the fused op schema has no optional add operand.
        return fuse_conv_bias_add(input,
                                  self.weight,
                                  self.bias,
                                  add,
                                  self.padding, self.stride, self.dilation)
@torch.library.custom_op("lightop::conv_bias", mutates_args=())
def fuse_conv_bias(
input: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
padding: List[int],
stride: List[int],
dilation: List[int],
) -> torch.Tensor:
from lightop import miopen_conv_bias as conv_bias
return conv_bias(input, weight, bias, padding, stride, dilation)


@fuse_conv_bias.register_fake
def conv_bias_fake(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    padding: List[int],
    stride: List[int],
    dilation: List[int],
) -> torch.Tensor:
    """Meta function: compute the output shape without running the kernel."""
    # Validate input rank
    if input.dim() not in [4, 5]:
        raise ValueError(f"Input tensor must be 4D or 5D, got {input.dim()}D")
    # Normalize parameter formats
    padding = tuple(padding) if isinstance(padding, list) else padding
    stride = tuple(stride) if isinstance(stride, list) else stride
    dilation = tuple(dilation) if isinstance(dilation, list) else dilation
    # Read spatial and kernel sizes
    if input.dim() == 4:  # 4D: [N, C, H, W]
        h_in = input.size(2)
        w_in = input.size(3)
        kH = weight.size(2)
        kW = weight.size(3)
    else:  # 5D: [N, C, D, H, W]
        h_in = input.size(3)
        w_in = input.size(4)
        kH = weight.size(3)
        kW = weight.size(4)
    # Unpack per-dimension parameters
    padH, padW = padding if isinstance(padding, tuple) else (padding, padding)
    strideH, strideW = stride if isinstance(stride, tuple) else (stride, stride)
    dilationH, dilationW = dilation if isinstance(dilation, tuple) else (dilation, dilation)
    # Standard convolution output-size formula
    h_out = (h_in + 2 * padH - dilationH * (kH - 1) - 1) // strideH + 1
    w_out = (w_in + 2 * padW - dilationW * (kW - 1) - 1) // strideW + 1
    # Assemble the output shape
    if input.dim() == 4:
        output_shape = (input.size(0), weight.size(0), h_out, w_out)
    else:
        output_shape = (input.size(0), weight.size(0), input.size(2), h_out, w_out)
    # Create a meta tensor with the same attributes as the input.
    # channels_last is only valid for 4D tensors; 5D needs the 3D variant.
    memory_format = torch.channels_last if input.dim() == 4 else torch.channels_last_3d
    return torch.empty(
        output_shape,
        dtype=input.dtype,
        device=input.device,
        layout=input.layout,
        requires_grad=input.requires_grad,
        memory_format=memory_format,
    )
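

# Illustrative cross-check (editor's sketch, 4D case only): the meta shape
# computed above should match eager F.conv2d on CPU.
def _check_conv_bias_meta_shape():
    import torch.nn.functional as F
    x = torch.randn(2, 8, 17, 13)
    w = torch.randn(4, 8, 3, 3)
    b = torch.randn(4)
    ref = F.conv2d(x, w, b, stride=(2, 2), padding=(1, 1), dilation=(1, 1))
    meta = conv_bias_fake(x, w, b, [1, 1], [2, 2], [1, 1])
    assert meta.shape == ref.shape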


class ConvBias(torch.nn.Conv2d):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: _size_2_t,
                 stride: _size_2_t = 1,
                 padding: Union[str, _size_2_t] = 0,
                 dilation: _size_2_t = 1,
                 groups: int = 1,
                 bias: bool = True,
                 padding_mode: str = "zeros",
                 device=None, dtype=None):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            padding_mode,
            device,
            dtype,
        )

    def forward(self, input):
        return fuse_conv_bias(input,
                              self.weight,
                              self.bias,
                              self.padding,
                              self.stride,
                              self.dilation)
@torch.library.custom_op("lightop::miopenGroupNorm", mutates_args=())
def fuse_miopenGroupNorm(
x: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
num_groups: int,
epsilon: float,
mode: int,
) -> torch.Tensor:
#)-> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
from lightop import miopen_groupnorm as groupnorm
return groupnorm(x, weight, bias, num_groups, epsilon, mode)


@fuse_miopenGroupNorm.register_fake
def fuse_miopenGroupNorm_fake(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    num_groups: int,
    epsilon: float,
    mode: int,
) -> torch.Tensor:
    # ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:  # variant that also returns mean/rstd
    """Meta function: the output has the same shape as the input."""
    batch_size = x.size(0)
    mean_rstd_len = [batch_size * num_groups, 1, 1, 1]
    if x.dim() == 5:
        mean_rstd_len.append(1)
    # Output tensor
    out_y = torch.empty_like(x)
    # mean/rstd are kept for the Tuple-returning variant, currently unused.
    memory_format = torch.channels_last if x.dim() == 4 else torch.channels_last_3d
    out_mean = torch.empty(
        mean_rstd_len,
        dtype=x.dtype,
        device=x.device,
        layout=x.layout,
        memory_format=memory_format,
    )
    out_rstd = torch.empty(
        mean_rstd_len,
        dtype=x.dtype,
        device=x.device,
        layout=x.layout,
        memory_format=memory_format,
    )
    return out_y  # , out_mean, out_rstd


class miopenGroupNorm(torch.nn.Module):
    # mode = 0  -> MIOPEN_ELEMENTWISE_AFFINE
    # mode = 1  -> MIOPEN_WEIGHT_BIAS
    # mode = 10 -> MIOPEN_WEIGHT_BIAS_FUSION_SILU
    # mode = 11 -> MIOPEN_FUSION_SILU
    def __init__(self, num_groups: int, num_channels: int, mode: int, eps: float = 1e-5, device=None, dtype=None):
        super().__init__()
        self.eps = eps
        self.num_groups = num_groups
        self.num_channels = num_channels
        self.mode = mode
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.weight = torch.nn.Parameter(torch.empty(num_channels, **factory_kwargs))
        self.bias = torch.nn.Parameter(torch.empty(num_channels, **factory_kwargs))
        torch.nn.init.ones_(self.weight)
        torch.nn.init.zeros_(self.bias)

    def forward(self, x):
        return fuse_miopenGroupNorm(x, self.weight, self.bias, self.num_groups, self.eps, self.mode)

    def extra_repr(self):
        return f'num_groups={self.num_groups}, num_channels={self.num_channels}, eps={self.eps}, mode={self.mode}'
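

# Illustrative pure-PyTorch reference for the fused modes (editor's sketch,
# useful for validating lightop output on DCU against eager execution).
# Assumption, per the MIOPEN_* comments above: modes 10 and 11 additionally
# apply SiLU after the normalization.
def _groupnorm_reference(x, weight, bias, num_groups, eps=1e-5, mode=1):
    y = torch.nn.functional.group_norm(x, num_groups, weight, bias, eps)
    if mode in (10, 11):
        y = torch.nn.functional.silu(y)
    return y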


# Define the custom scaled-dot-product-attention op
@torch.library.custom_op("lightop::miopen_scaled_dot_product_attention", mutates_args=())
def fuse_miopen_scaled_dot_product_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_msk_: Optional[torch.Tensor] = None,
    droprate: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
) -> torch.Tensor:
    from lightop import miopen_scaled_dot_product_attention
    return miopen_scaled_dot_product_attention(query, key, value, attn_msk_, droprate, is_causal, scale, enable_gqa)


@fuse_miopen_scaled_dot_product_attention.register_fake
def _(query: torch.Tensor,
      key: torch.Tensor,
      value: torch.Tensor,
      attn_msk_: Optional[torch.Tensor] = None,
      droprate: float = 0.0,
      is_causal: bool = False,
      scale: Optional[float] = None,
      enable_gqa: bool = False,
      ) -> torch.Tensor:
    B, H, S, D = query.shape
    _, H_k, S_k, D_v = value.shape
    # Validate input shapes
    assert query.dim() == 4, "Query must be 4D [B, H, S, D]"
    assert key.shape == (B, H_k, S_k, key.size(3)), "Key shape mismatch"
    assert value.shape == (B, H_k, S_k, D_v), "Value shape mismatch"
    # Attention output is [B, H, S_q, D_v]
    return torch.empty(
        (B, H, S, D_v),
        dtype=query.dtype,
        device=query.device,
    )
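

# Illustrative check (editor's sketch): the fake kernel above follows the same
# shape contract as torch.nn.functional.scaled_dot_product_attention,
# i.e. [B, H, S_q, D_v], with S_k and D_v free to differ from the query's.
def _check_sdpa_output_shape():
    q = torch.randn(2, 8, 128, 64)
    k = torch.randn(2, 8, 130, 64)
    v = torch.randn(2, 8, 130, 32)
    ref = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    assert ref.shape == (2, 8, 128, 32)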
@@ -114,7 +114,9 @@ class GEGLU(nn.Module):
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
-        hidden_states = self.proj(hidden_states)
+        # hidden_states = self.proj(hidden_states)
+        # DCU OPT: TN->NN
+        hidden_states = torch.matmul(hidden_states, self.proj.weight.data) + self.proj.bias.data
         if is_torch_npu_available():
             # using torch_npu.npu_geglu can run faster and save memory on NPU.
             return torch_npu.npu_geglu(hidden_states, dim=-1, approximate=1)[0]
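
These TN->NN rewrites only stay numerically correct because load_model_dict_into_meta (see the hunk further down in this commit) pre-transposes the matching Linear weights from [out, in] to [in, out] at load time. A minimal equivalence sketch in stock PyTorch, independent of DCU:

    import torch
    lin = torch.nn.Linear(320, 1280)
    x = torch.randn(4, 77, 320)
    w_nn = lin.weight.data.permute(1, 0).contiguous()  # the one-time load-time transpose
    assert torch.allclose(lin(x), torch.matmul(x, w_nn) + lin.bias.data, atol=1e-5)

Every projection then runs as a plain NN GEMM instead of the TN GEMM behind F.linear, which is the layout this commit's DCU path prefers.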
......
@@ -1737,6 +1737,10 @@ class FeedForward(nn.Module):
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
-        for module in self.net:
-            hidden_states = module(hidden_states)
+        # for module in self.net:
+        #     hidden_states = module(hidden_states)
+        # DCU OPT: TN->NN
+        hidden_states = self.net[0](hidden_states)
+        hidden_states = self.net[1](hidden_states)
+        hidden_states = torch.matmul(hidden_states, self.net[2].weight.data) + self.net[2].bias.data
         return hidden_states
@@ -24,6 +24,7 @@ from ..utils import deprecate, is_torch_xla_available, logging
 from ..utils.import_utils import is_torch_npu_available, is_torch_xla_version, is_xformers_available
 from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph
+from lightop import miopenGroupNorm

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -2737,23 +2738,51 @@ class AttnProcessor2_0:
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

-        query = attn.to_q(hidden_states)
+        # query = attn.to_q(hidden_states)
+        # DCU OPT: TN->NN
+        if attn.to_q.bias is not None:
+            query = torch.matmul(hidden_states, attn.to_q.weight.data) + attn.to_q.bias.data
+        else:
+            query = torch.matmul(hidden_states, attn.to_q.weight.data)

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
+        # key = attn.to_k(encoder_hidden_states)
+        # value = attn.to_v(encoder_hidden_states)
+        # DCU OPT: TN->NN
+        if attn.to_k.bias is not None:
+            key = torch.matmul(encoder_hidden_states, attn.to_k.weight.data) + attn.to_k.bias.data
+        else:
+            key = torch.matmul(encoder_hidden_states, attn.to_k.weight.data)
+        if attn.to_v.bias is not None:
+            value = torch.matmul(encoder_hidden_states, attn.to_v.weight.data) + attn.to_v.bias.data
+        else:
+            value = torch.matmul(encoder_hidden_states, attn.to_v.weight.data)

         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads

-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        # query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        # key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        # value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        # DCU OPT: TN->NN
+        if isinstance(attn.to_q.bias, torch.Tensor):
+            query = torch.matmul(hidden_states, attn.to_q.weight.data) + attn.to_q.bias.data
+        else:
+            query = torch.matmul(hidden_states, attn.to_q.weight.data)
+        if isinstance(attn.to_k.bias, torch.Tensor):
+            key = torch.matmul(hidden_states, attn.to_k.weight.data) + attn.to_k.bias.data
+        else:
+            key = torch.matmul(hidden_states, attn.to_k.weight.data)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        if isinstance(attn.to_v.bias, torch.Tensor):
+            value = torch.matmul(hidden_states, attn.to_v.weight.data) + attn.to_v.bias.data
+        else:
+            value = torch.matmul(hidden_states, attn.to_v.weight.data)

         if attn.norm_q is not None:
             query = attn.norm_q(query)
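
A note on the bias checks above: a Linear's bias is either None or a Parameter, and a multi-element tensor must never be used directly as a truth value (PyTorch raises "Boolean value of Tensor with more than one element is ambiguous" at runtime). Both forms below are safe:

    bias = torch.nn.Linear(8, 8).bias
    has_bias = bias is not None                  # preferred
    has_bias = isinstance(bias, torch.Tensor)    # equivalent when bias is None or a Tensor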
@@ -2770,7 +2799,9 @@ class AttnProcessor2_0:
         hidden_states = hidden_states.to(query.dtype)

         # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
+        # hidden_states = attn.to_out[0](hidden_states)
+        # DCU OPT: TN->NN
+        hidden_states = torch.matmul(hidden_states, attn.to_out[0].weight.data) + attn.to_out[0].bias.data
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
......
@@ -34,6 +34,8 @@ from ..modeling_outputs import AutoencoderKLOutput
 from ..modeling_utils import ModelMixin
 from .vae import AutoencoderMixin, Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
+from lightop import miopenConvBiasAdd as ConvBiasAdd
+from lightop import miopenConvBias as ConvBias


 class AutoencoderKL(
     ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin
......
@@ -28,7 +28,10 @@ from ..unets.unet_2d_blocks import (
     get_down_block,
     get_up_block,
 )
+from torch.nn import FuseGroupNorm as GroupNorm
+from lightop import miopenConvBiasAdd as ConvBiasAdd
+from ...custom_op import ConvBias as ConvBias
+from ...custom_op import miopenGroupNorm


 @dataclass
 class EncoderOutput(BaseOutput):
@@ -216,7 +219,15 @@ class Decoder(nn.Module):
         super().__init__()
         self.layers_per_block = layers_per_block

-        self.conv_in = nn.Conv2d(
+        # self.conv_in = nn.Conv2d(
+        #     in_channels,
+        #     block_out_channels[-1],
+        #     kernel_size=3,
+        #     stride=1,
+        #     padding=1,
+        # )
+        # DCU OPT: conv_bias
+        self.conv_in = ConvBias(
             in_channels,
             block_out_channels[-1],
             kernel_size=3,
@@ -271,9 +282,13 @@ class Decoder(nn.Module):
         if norm_type == "spatial":
             self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
         else:
-            self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
+            # self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
+            # DCU OPT: gn_silu
+            self.conv_norm_out = miopenGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6, mode=10)
         self.conv_act = nn.SiLU()
-        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
+        # self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
+        # DCU OPT: conv_bias
+        self.conv_out = ConvBias(block_out_channels[0], out_channels, 3, padding=1)

         self.gradient_checkpointing = False
@@ -303,10 +318,12 @@ class Decoder(nn.Module):
         # post-process
         if latent_embeds is None:
+            # sample = self.conv_norm_out(sample)
+            # DCU OPT: gn_silu
             sample = self.conv_norm_out(sample)
         else:
             sample = self.conv_norm_out(sample, latent_embeds)
-        sample = self.conv_act(sample)
+            sample = self.conv_act(sample)
         sample = self.conv_out(sample)

         return sample
......
@@ -21,7 +21,8 @@ import torch.nn.functional as F
 from ..utils import deprecate
 from .normalization import RMSNorm
 from .upsampling import upfirdn2d_native
+from lightop import miopenConvBiasAdd as ConvBiasAdd
+from ..custom_op import ConvBias as ConvBias


 class Downsample1D(nn.Module):
     """A 1D downsampling layer with an optional convolution.
@@ -113,9 +114,11 @@ class Downsample2D(nn.Module):
             raise ValueError(f"unknown norm_type: {norm_type}")

         if use_conv:
-            conv = nn.Conv2d(
-                self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
-            )
+            # conv = nn.Conv2d(
+            #     self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
+            # )
+            # DCU OPT: conv_bias
+            conv = ConvBias(self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
         else:
             assert self.channels == self.out_channels
             conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
......
@@ -1295,12 +1295,16 @@ class TimestepEmbedding(nn.Module):
     def forward(self, sample, condition=None):
         if condition is not None:
             sample = sample + self.cond_proj(condition)
-        sample = self.linear_1(sample)
+        # sample = self.linear_1(sample)
+        # DCU OPT: TN->NN
+        sample = torch.matmul(sample, self.linear_1.weight.data) + self.linear_1.bias.data

         if self.act is not None:
             sample = self.act(sample)

-        sample = self.linear_2(sample)
+        # sample = self.linear_2(sample)
+        # DCU OPT: TN->NN
+        sample = torch.matmul(sample, self.linear_2.weight.data) + self.linear_2.bias.data

         if self.post_act is not None:
             sample = self.post_act(sample)
......
@@ -305,6 +305,49 @@ def load_model_dict_into_meta(
             )
         else:
             set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs)
+
+    # DCU OPT: TN->NN
+    for param_name, param in model.named_parameters():
+        if 'weight' in param_name and 'add_embedding.linear_1' in param_name:
+            if param.data.dim() == 2:
+                param.data = param.data.permute(1, 0).contiguous()
+            else:
+                raise ValueError("rzc test error")
+        if 'weight' in param_name and 'add_embedding.linear_2' in param_name:
+            if param.data.dim() == 2:
+                param.data = param.data.permute(1, 0).contiguous()
+            else:
+                raise ValueError("rzc test error")
+        if 'weight' in param_name and 'ff.net' in param_name:
+            if param.data.dim() == 2:
+                param.data = param.data.permute(1, 0).contiguous()
+            else:
+                raise ValueError("lijian test error")
+        if 'weight' in param_name and 'time_emb_proj' in param_name:
+            if param.data.dim() == 2:
+                param.data = param.data.permute(1, 0).contiguous()
+            else:
+                raise ValueError("lijian test error")
+        if 'weight' in param_name and 'attn' in param_name and ('to_q' in param_name or 'to_k' in param_name or 'to_v' in param_name or 'to_out' in param_name):
+            if param.data.dim() == 2:
+                param.data = param.data.permute(1, 0).contiguous()
+            else:
+                raise ValueError("lijian test error")
+        if 'weight' in param_name and 'time_embedding' in param_name and ('linear_1' in param_name or 'linear_2' in param_name):
+            if param.data.dim() == 2:
+                param.data = param.data.permute(1, 0).contiguous()
+            else:
+                raise ValueError("transpose weight to NN error")
+        if 'weight' in param_name and 'attentions' in param_name and ('proj_in' in param_name or 'proj_out' in param_name):
+            if param.data.dim() == 2:
+                param.data = param.data.permute(1, 0).contiguous()
+            else:
+                raise ValueError("transpose weight to NN error")
+        if 'weight' in param_name and 'decoder.mid_block.attentions.0' in param_name and ('to_q' in param_name or 'to_k' in param_name or 'to_v' in param_name or 'to_out' in param_name):
+            if param.data.dim() == 2:
+                param.data = param.data.permute(1, 0).contiguous()
+            else:
+                raise ValueError("lijian test error")
+
     return offload_index, state_dict_index
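
The repeated pattern above (match a parameter name, require a 2D weight, transpose it) could be collapsed into a single table-driven pass. A hypothetical refactor sketch; `_NN_KEYS` and the helper name are invented for illustration, and the matching is looser than the original conjunctions:

    _NN_KEYS = ('add_embedding.linear', 'ff.net', 'time_emb_proj', 'time_embedding',
                'proj_in', 'proj_out', 'to_q', 'to_k', 'to_v', 'to_out')

    def _transpose_linear_weights_to_nn(model):
        # Pre-transpose matching 2D Linear weights to [in, out] so the forward
        # passes rewritten as "DCU OPT: TN->NN" can run plain NN GEMMs.
        for name, param in model.named_parameters():
            if 'weight' in name and any(key in name for key in _NN_KEYS) and param.data.dim() == 2:
                param.data = param.data.permute(1, 0).contiguous()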
......
@@ -1730,6 +1730,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                     f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
                     " able to use it for predictions and inference."
                 )
+        # DCU OPT: NHWC
+        model = model.to(memory_format=torch.channels_last)

         return model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs
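
Converting the whole model to channels_last pairs with the channels_last latents set in the pipeline hunks below, so MIOpen convolutions see NHWC activations end to end without layout transposes. A minimal sketch of how the format propagates through a convolution:

    conv = torch.nn.Conv2d(3, 8, 3).to(memory_format=torch.channels_last)
    x = torch.randn(1, 3, 64, 64).contiguous(memory_format=torch.channels_last)
    assert conv(x).is_contiguous(memory_format=torch.channels_last)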
......
@@ -14,7 +14,7 @@
 # limitations under the License.

 from functools import partial
-from typing import Optional, Tuple, Union
+from typing import Optional, Tuple, Union, List

 import torch
 import torch.nn as nn
@@ -40,6 +40,10 @@ from .upsampling import (  # noqa
     upsample_2d,
 )
+from torch.nn import FuseGroupNorm as GroupNorm
+from ..custom_op import ConvBiasAdd as ConvBiasAdd
+from ..custom_op import ConvBias as ConvBias
+from ..custom_op import miopenGroupNorm


 class ResnetBlockCondNorm2D(nn.Module):
     r"""
@@ -264,9 +268,13 @@ class ResnetBlock2D(nn.Module):
         if groups_out is None:
             groups_out = groups

-        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
-        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        # self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+        # self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        # DCU OPT: gn_silu
+        self.norm1 = miopenGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, mode=10)
+        # DCU OPT: conv_bias
+        self.conv1 = ConvBias(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

         if temb_channels is not None:
             if self.time_embedding_norm == "default":
@@ -278,11 +286,18 @@ class ResnetBlock2D(nn.Module):
         else:
             self.time_emb_proj = None

-        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+        # self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+        # DCU OPT: gn_silu
+        if self.time_embedding_norm == "scale_shift" or self.time_embedding_norm == "default":
+            self.norm2 = miopenGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, mode=1)
+        else:
+            self.norm2 = miopenGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, mode=10)

         self.dropout = torch.nn.Dropout(dropout)
         conv_2d_out_channels = conv_2d_out_channels or out_channels
-        self.conv2 = nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
+        # self.conv2 = nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
+        # DCU OPT: conv_bias_add
+        self.conv2 = ConvBiasAdd(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)

         self.nonlinearity = get_activation(non_linearity)
@@ -308,7 +323,16 @@ class ResnetBlock2D(nn.Module):
             self.conv_shortcut = None

         if self.use_in_shortcut:
-            self.conv_shortcut = nn.Conv2d(
+            # self.conv_shortcut = nn.Conv2d(
+            #     in_channels,
+            #     conv_2d_out_channels,
+            #     kernel_size=1,
+            #     stride=1,
+            #     padding=0,
+            #     bias=conv_shortcut_bias,
+            # )
+            # DCU OPT: conv_bias_add
+            self.conv_shortcut = ConvBiasAdd(
                 in_channels,
                 conv_2d_out_channels,
                 kernel_size=1,
@@ -324,8 +348,11 @@ class ResnetBlock2D(nn.Module):
         hidden_states = input_tensor

+        # hidden_states = self.norm1(hidden_states)
+        # hidden_states = self.nonlinearity(hidden_states)
+        # DCU OPT: gn_silu
         hidden_states = self.norm1(hidden_states)
-        hidden_states = self.nonlinearity(hidden_states)

         if self.upsample is not None:
             # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
@@ -343,32 +370,54 @@ class ResnetBlock2D(nn.Module):
         if self.time_emb_proj is not None:
             if not self.skip_time_act:
                 temb = self.nonlinearity(temb)
-            temb = self.time_emb_proj(temb)[:, :, None, None]
+            # temb = self.time_emb_proj(temb)[:, :, None, None]
+            # DCU OPT: TN->NN
+            x = torch.matmul(temb, self.time_emb_proj.weight.data) + self.time_emb_proj.bias.data
+            temb = x[:, :, None, None]

         if self.time_embedding_norm == "default":
             if temb is not None:
                 hidden_states = hidden_states + temb
+            # DCU OPT: ori
             hidden_states = self.norm2(hidden_states)
             hidden_states = self.nonlinearity(hidden_states)
         elif self.time_embedding_norm == "scale_shift":
             if temb is None:
                 raise ValueError(
                     f" `temb` should not be None when `time_embedding_norm` is {self.time_embedding_norm}"
                 )
             time_scale, time_shift = torch.chunk(temb, 2, dim=1)
+            # DCU OPT: ori
             hidden_states = self.norm2(hidden_states)
             hidden_states = hidden_states * (1 + time_scale) + time_shift
             hidden_states = self.nonlinearity(hidden_states)
         else:
+            # DCU OPT: gn_silu
             hidden_states = self.norm2(hidden_states)
-            hidden_states = self.nonlinearity(hidden_states)
+            # hidden_states = self.nonlinearity(hidden_states)

         hidden_states = self.dropout(hidden_states)
-        hidden_states = self.conv2(hidden_states)
-        if self.conv_shortcut is not None:
-            input_tensor = self.conv_shortcut(input_tensor.contiguous())
-        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
+        # origin
+        """
+        hidden_states = self.conv2(hidden_states)
+        if self.conv_shortcut is not None:
+            input_tensor = self.conv_shortcut(input_tensor.contiguous())
+        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
+        """
+        # DCU OPT: conv_bias_add
+        if self.conv_shortcut is not None:
+            tmp_add = torch.zeros_like(hidden_states, memory_format=torch.channels_last)
+            hidden_states = self.conv2(hidden_states, tmp_add)
+            input_tensor = self.conv_shortcut(input_tensor.contiguous(memory_format=torch.channels_last), hidden_states)
+            output_tensor = input_tensor / self.output_scale_factor
+        else:
+            hidden_states = self.conv2(hidden_states, input_tensor)
+            output_tensor = hidden_states / self.output_scale_factor

         return output_tensor
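
In eager terms the fused calls above compute convolution, bias, and the residual add in a single kernel launch. A stock-PyTorch reference of the assumed lightop semantics, where conv_bias_add(x, w, b, add) equals conv2d(x, w, b) + add:

    import torch.nn.functional as F

    def conv_bias_add_reference(x, w, b, add, stride=1, padding=1, dilation=1):
        # Assumed semantics of lightop's miopen_conv_bias_add.
        return F.conv2d(x, w, b, stride=stride, padding=padding, dilation=dilation) + add

With a shortcut present, conv2 is fed an all-zero add tensor and the shortcut conv then folds in the residual sum; note that torch.zeros_like on every forward pass costs an extra allocation that a dedicated no-add kernel path would avoid.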
......
@@ -25,6 +25,7 @@ from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import LegacyModelMixin
 from ..normalization import AdaLayerNormSingle
+from ...custom_op import miopenGroupNorm

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -173,9 +174,12 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
             self._init_patched_inputs(norm_type=norm_type)

     def _init_continuous_input(self, norm_type):
-        self.norm = torch.nn.GroupNorm(
-            num_groups=self.config.norm_num_groups, num_channels=self.in_channels, eps=1e-6, affine=True
-        )
+        # self.norm = torch.nn.GroupNorm(
+        #     num_groups=self.config.norm_num_groups, num_channels=self.in_channels, eps=1e-6, affine=True
+        # )
+        # DCU OPT: gn
+        self.norm = miopenGroupNorm(num_groups=self.config.norm_num_groups, num_channels=self.in_channels, eps=1e-6, mode=1)

         if self.use_linear_projection:
             self.proj_in = torch.nn.Linear(self.in_channels, self.inner_dim)
         else:
@@ -463,16 +467,21 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
     def _operate_on_continuous_inputs(self, hidden_states):
         batch, _, height, width = hidden_states.shape
+        # DCU OPT: gn
         hidden_states = self.norm(hidden_states)

         if not self.use_linear_projection:
-            hidden_states = self.proj_in(hidden_states)
+            # hidden_states = self.proj_in(hidden_states)
+            # DCU OPT: TN->NN
+            hidden_states = torch.matmul(hidden_states, self.proj_in.weight.data) + self.proj_in.bias.data
             inner_dim = hidden_states.shape[1]
             hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
         else:
             inner_dim = hidden_states.shape[1]
             hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
-            hidden_states = self.proj_in(hidden_states)
+            # hidden_states = self.proj_in(hidden_states)
+            # DCU OPT: TN->NN
+            hidden_states = torch.matmul(hidden_states, self.proj_in.weight.data) + self.proj_in.bias.data

         return hidden_states, inner_dim
@@ -498,14 +507,26 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
     def _get_output_for_continuous_inputs(self, hidden_states, residual, batch_size, height, width, inner_dim):
         if not self.use_linear_projection:
+            # hidden_states = (
+            #     hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+            # )
+            # DCU OPT: NHWC
             hidden_states = (
-                hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+                hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous(memory_format=torch.channels_last)
             )
-            hidden_states = self.proj_out(hidden_states)
+            # hidden_states = self.proj_out(hidden_states)
+            # DCU OPT: TN->NN
+            hidden_states = torch.matmul(hidden_states, self.proj_out.weight.data) + self.proj_out.bias.data
         else:
-            hidden_states = self.proj_out(hidden_states)
+            # hidden_states = self.proj_out(hidden_states)
+            # DCU OPT: TN->NN
+            hidden_states = torch.matmul(hidden_states, self.proj_out.weight.data) + self.proj_out.bias.data
+            # hidden_states = (
+            #     hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+            # )
+            # DCU OPT: NHWC
             hidden_states = (
-                hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+                hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous(memory_format=torch.channels_last)
             )

         output = hidden_states + residual
@@ -535,7 +556,9 @@ class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
             hidden_states = self.norm_out(hidden_states)
             # Modulation
             hidden_states = hidden_states * (1 + scale) + shift
-            hidden_states = self.proj_out(hidden_states)
+            # hidden_states = self.proj_out(hidden_states)
+            # DCU OPT: TN->NN
+            hidden_states = torch.matmul(hidden_states, self.proj_out.weight.data) + self.proj_out.bias.data
             hidden_states = hidden_states.squeeze(1)

         # unpatchify
......
@@ -51,6 +51,10 @@ from .unet_2d_blocks import (
 )
+from lightop import miopenConvBiasAdd as ConvBiasAdd
+from ...custom_op import ConvBias as ConvBias
+from ...custom_op import miopenGroupNorm

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -263,9 +267,11 @@ class UNet2DConditionModel(
         # input
         conv_in_padding = (conv_in_kernel - 1) // 2
-        self.conv_in = nn.Conv2d(
-            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
-        )
+        # self.conv_in = nn.Conv2d(
+        #     in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+        # )
+        # DCU OPT: conv_bias
+        self.conv_in = ConvBias(in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding)

         # time
         time_embed_dim, timestep_input_dim = self._set_time_proj(
@@ -472,9 +478,11 @@ class UNet2DConditionModel(
         # out
         if norm_num_groups is not None:
-            self.conv_norm_out = nn.GroupNorm(
-                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
-            )
+            # self.conv_norm_out = nn.GroupNorm(
+            #     num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
+            # )
+            # DCU OPT: gn_silu
+            self.conv_norm_out = miopenGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps, mode=10)

             self.conv_act = get_activation(act_fn)
@@ -483,9 +491,11 @@ class UNet2DConditionModel(
             self.conv_act = None

         conv_out_padding = (conv_out_kernel - 1) // 2
-        self.conv_out = nn.Conv2d(
-            block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
-        )
+        # self.conv_out = nn.Conv2d(
+        #     block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
+        # )
+        # DCU OPT: conv_bias
+        self.conv_out = ConvBias(block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding)

         self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim)
@@ -1235,8 +1245,10 @@ class UNet2DConditionModel(
         # 6. post-process
         if self.conv_norm_out:
+            # sample = self.conv_norm_out(sample)
+            # sample = self.conv_act(sample)
+            # DCU OPT: gn_silu
             sample = self.conv_norm_out(sample)
-            sample = self.conv_act(sample)
             sample = self.conv_out(sample)

         if USE_PEFT_BACKEND:
......
@@ -21,7 +21,8 @@ import torch.nn.functional as F
 from ..utils import deprecate
 from ..utils.import_utils import is_torch_version
 from .normalization import RMSNorm
+from lightop import miopenConvBiasAdd as ConvBiasAdd
+from ..custom_op import ConvBias as ConvBias


 class Upsample1D(nn.Module):
     """A 1D upsampling layer with an optional convolution.
@@ -131,7 +132,9 @@ class Upsample2D(nn.Module):
         elif use_conv:
             if kernel_size is None:
                 kernel_size = 3
-            conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
+            # conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
+            # DCU OPT: conv_bias
+            conv = ConvBias(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)

         # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
         if name == "conv":
......
@@ -1023,7 +1023,8 @@ class StableDiffusionPipeline(
             timestep_cond = self.get_guidance_scale_embedding(
                 guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
             ).to(device=device, dtype=latents.dtype)
+        # DCU OPT: NHWC
+        latents = latents.contiguous(memory_format=torch.channels_last)

         # 7. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
......
@@ -279,6 +279,9 @@ class StableDiffusionXLPipeline(
             self.watermark = StableDiffusionXLWatermarker()
         else:
             self.watermark = None
+        # DCU OPT: NHWC
+        self.vae.to(memory_format=torch.channels_last)
+        self.unet.to(memory_format=torch.channels_last)

     def encode_prompt(
         self,
@@ -1110,6 +1113,8 @@ class StableDiffusionXLPipeline(
             generator,
             latents,
         )
+        # DCU OPT: NHWC
+        latents = latents.to(memory_format=torch.channels_last)

         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
......