import torch
import torch.cuda
from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D, USE_PEFT_BACKEND
from torch import distributed as dist
from torch import nn
from torch.nn import functional as F

from ..base_module import BaseModule
from ...utils import DistriConfig


class DistriResnetBlock2DTP(BaseModule):
    def __init__(self, module: ResnetBlock2D, distri_config: DistriConfig):
        super(DistriResnetBlock2DTP, self).__init__(module, distri_config)

        assert module.conv1.out_channels % distri_config.n_device_per_batch == 0
        mid_channels = module.conv1.out_channels // distri_config.n_device_per_batch

        # conv1 is sharded along its output channels (column parallelism): each
        # device keeps only its slice of mid_channels output filters.
        sharded_conv1 = nn.Conv2d(
            module.conv1.in_channels,
            mid_channels,
            module.conv1.kernel_size,
            module.conv1.stride,
            module.conv1.padding,
            module.conv1.dilation,
            module.conv1.groups,
            module.conv1.bias is not None,
            module.conv1.padding_mode,
            device=module.conv1.weight.device,
            dtype=module.conv1.weight.dtype,
        )
        sharded_conv1.weight.data.copy_(
            module.conv1.weight.data[
                distri_config.split_idx() * mid_channels : (distri_config.split_idx() + 1) * mid_channels
            ]
        )
        if module.conv1.bias is not None:
            sharded_conv1.bias.data.copy_(
                module.conv1.bias.data[
                    distri_config.split_idx() * mid_channels : (distri_config.split_idx() + 1) * mid_channels
                ]
            )

        # conv2 is sharded along its input channels (row parallelism); its partial
        # outputs are summed with an all-reduce in forward(), so the bias is kept
        # whole and added only after the reduction.
        sharded_conv2 = nn.Conv2d(
            mid_channels,
            module.conv2.out_channels,
            module.conv2.kernel_size,
            module.conv2.stride,
            module.conv2.padding,
            module.conv2.dilation,
            module.conv2.groups,
            module.conv2.bias is not None,
            module.conv2.padding_mode,
            device=module.conv2.weight.device,
            dtype=module.conv2.weight.dtype,
        )
        sharded_conv2.weight.data.copy_(
            module.conv2.weight.data[
                :, distri_config.split_idx() * mid_channels : (distri_config.split_idx() + 1) * mid_channels
            ]
        )
        if module.conv2.bias is not None:
            sharded_conv2.bias.data.copy_(module.conv2.bias.data)

        assert module.time_emb_proj is not None
        assert module.time_embedding_norm == "default"

        # The time-embedding projection feeds the sharded mid channels, so its
        # output features are split the same way as conv1's output channels.
        sharded_time_emb_proj = nn.Linear(
            module.time_emb_proj.in_features,
            mid_channels,
            bias=module.time_emb_proj.bias is not None,
            device=module.time_emb_proj.weight.device,
            dtype=module.time_emb_proj.weight.dtype,
        )
        sharded_time_emb_proj.weight.data.copy_(
            module.time_emb_proj.weight.data[
                distri_config.split_idx() * mid_channels : (distri_config.split_idx() + 1) * mid_channels
            ]
        )
        if module.time_emb_proj.bias is not None:
            sharded_time_emb_proj.bias.data.copy_(
                module.time_emb_proj.bias.data[
                    distri_config.split_idx() * mid_channels : (distri_config.split_idx() + 1) * mid_channels
                ]
            )

        # norm2 operates on the local shard of mid channels, so both its group
        # count and its affine parameters are partitioned.
        sharded_norm2 = nn.GroupNorm(
            module.norm2.num_groups // distri_config.n_device_per_batch,
            mid_channels,
            module.norm2.eps,
            module.norm2.affine,
            device=module.norm2.weight.device,
            dtype=module.norm2.weight.dtype,
        )
        if module.norm2.affine:
            sharded_norm2.weight.data.copy_(
                module.norm2.weight.data[
                    distri_config.split_idx() * mid_channels : (distri_config.split_idx() + 1) * mid_channels
                ]
            )
            sharded_norm2.bias.data.copy_(
                module.norm2.bias.data[
                    distri_config.split_idx() * mid_channels : (distri_config.split_idx() + 1) * mid_channels
                ]
            )

        # Swap the full-size submodules for their shards and free the originals.
        del module.conv1
        del module.conv2
        del module.time_emb_proj
        del module.norm2
        module.conv1 = sharded_conv1
        module.conv2 = sharded_conv2
        module.time_emb_proj = sharded_time_emb_proj
        module.norm2 = sharded_norm2
        torch.cuda.empty_cache()

    def forward(
        self,
        input_tensor: torch.FloatTensor,
        temb: torch.FloatTensor,
        scale: float = 1.0,
    ) -> torch.FloatTensor:
        assert scale == 1.0
        distri_config = self.distri_config
        module = self.module

        hidden_states = input_tensor

        hidden_states = module.norm1(hidden_states)
        hidden_states = module.nonlinearity(hidden_states)

        if module.upsample is not None:
            # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = (
                module.upsample(input_tensor, scale=scale)
                if isinstance(module.upsample, Upsample2D)
                else module.upsample(input_tensor)
            )
            hidden_states = (
                module.upsample(hidden_states, scale=scale)
                if isinstance(module.upsample, Upsample2D)
                else module.upsample(hidden_states)
            )
        elif module.downsample is not None:
            input_tensor = (
                module.downsample(input_tensor, scale=scale)
                if isinstance(module.downsample, Downsample2D)
                else module.downsample(input_tensor)
            )
            hidden_states = (
                module.downsample(hidden_states, scale=scale)
                if isinstance(module.downsample, Downsample2D)
                else module.downsample(hidden_states)
            )

        # conv1 produces only this device's shard of the mid channels.
        hidden_states = module.conv1(hidden_states)

        if module.time_emb_proj is not None:
            if not module.skip_time_act:
                temb = module.nonlinearity(temb)
            temb = module.time_emb_proj(temb)[:, :, None, None]

        if temb is not None and module.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        hidden_states = module.norm2(hidden_states)

        if temb is not None and module.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = module.nonlinearity(hidden_states)

        hidden_states = module.dropout(hidden_states)

        # Apply the row-sharded conv2 without its bias, all-reduce the partial
        # sums across the tensor-parallel group, then add the bias exactly once.
        hidden_states = F.conv2d(
            hidden_states,
            module.conv2.weight,
            bias=None,
            stride=module.conv2.stride,
            padding=module.conv2.padding,
            dilation=module.conv2.dilation,
            groups=module.conv2.groups,
        )
        dist.all_reduce(hidden_states, op=dist.ReduceOp.SUM, group=distri_config.batch_group, async_op=False)
        if module.conv2.bias is not None:
            hidden_states = hidden_states + module.conv2.bias.view(1, -1, 1, 1)

        if module.conv_shortcut is not None:
            input_tensor = (
                module.conv_shortcut(input_tensor, scale)
                if not USE_PEFT_BACKEND
                else module.conv_shortcut(input_tensor)
            )

        output_tensor = (input_tensor + hidden_states) / module.output_scale_factor

        self.counter += 1

        return output_tensor
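

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). It assumes a
# `distri_config: DistriConfig` whose process groups are already initialized,
# and some diffusers model (e.g. a UNet) passed in as `root`; the helper name
# `shard_resnet_blocks_tp` is hypothetical. Because DistriResnetBlock2DTP keeps
# the same forward signature as ResnetBlock2D, it can be swapped in place.
# ---------------------------------------------------------------------------
def shard_resnet_blocks_tp(root: nn.Module, distri_config: DistriConfig) -> nn.Module:
    """Recursively replace every ResnetBlock2D under `root` with its
    tensor-parallel wrapper, then return `root`."""
    for name, child in root.named_children():
        if isinstance(child, ResnetBlock2D):
            setattr(root, name, DistriResnetBlock2DTP(child, distri_config))
        else:
            shard_resnet_blocks_tp(child, distri_config)
    return root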