from typing import Optional

import torch.cuda
from diffusers.models.attention_processor import Attention
from torch import distributed as dist
from torch import nn
from torch.nn import functional as F

from ..base_module import BaseModule
from ...utils import DistriConfig


class DistriAttentionTP(BaseModule):
    def __init__(self, module: Attention, distri_config: DistriConfig):
        super(DistriAttentionTP, self).__init__(module, distri_config)

        # Partition the attention heads across the tensor-parallel ranks: the
        # first `remainder_heads` ranks receive one extra head so that every
        # head is owned by exactly one rank.
        heads = module.heads
        sliced_heads = heads // distri_config.n_device_per_batch
        remainder_heads = heads % distri_config.n_device_per_batch
        if distri_config.split_idx() < remainder_heads:
            sliced_heads += 1
        self.sliced_heads = sliced_heads

        if sliced_heads > 0:
            if distri_config.split_idx() < remainder_heads:
                start_head = distri_config.split_idx() * sliced_heads
            else:
                start_head = (
                    remainder_heads * (sliced_heads + 1)
                    + (distri_config.split_idx() - remainder_heads) * sliced_heads
                )
            end_head = start_head + sliced_heads
            dim = module.to_q.out_features // heads

            # The q/k/v projections are sliced along the output (head) dimension,
            # keeping only the rows that belong to this rank's heads.
            sharded_to_q = nn.Linear(
                module.to_q.in_features,
                sliced_heads * dim,
                bias=module.to_q.bias is not None,
                device=module.to_q.weight.device,
                dtype=module.to_q.weight.dtype,
            )
            sharded_to_q.weight.data.copy_(module.to_q.weight.data[start_head * dim : end_head * dim])
            if module.to_q.bias is not None:
                sharded_to_q.bias.data.copy_(module.to_q.bias.data[start_head * dim : end_head * dim])

            sharded_to_k = nn.Linear(
                module.to_k.in_features,
                sliced_heads * dim,
                bias=module.to_k.bias is not None,
                device=module.to_k.weight.device,
                dtype=module.to_k.weight.dtype,
            )
            sharded_to_k.weight.data.copy_(module.to_k.weight.data[start_head * dim : end_head * dim])
            if module.to_k.bias is not None:
                sharded_to_k.bias.data.copy_(module.to_k.bias.data[start_head * dim : end_head * dim])

            sharded_to_v = nn.Linear(
                module.to_v.in_features,
                sliced_heads * dim,
                bias=module.to_v.bias is not None,
                device=module.to_v.weight.device,
                dtype=module.to_v.weight.dtype,
            )
            sharded_to_v.weight.data.copy_(module.to_v.weight.data[start_head * dim : end_head * dim])
            if module.to_v.bias is not None:
                sharded_to_v.bias.data.copy_(module.to_v.bias.data[start_head * dim : end_head * dim])

            # The output projection is sliced along its input dimension; its bias
            # is kept whole and applied once after the cross-rank reduction in
            # forward().
            sharded_to_out = nn.Linear(
                sliced_heads * dim,
                module.to_out[0].out_features,
                bias=module.to_out[0].bias is not None,
                device=module.to_out[0].weight.device,
                dtype=module.to_out[0].weight.dtype,
            )
            sharded_to_out.weight.data.copy_(module.to_out[0].weight.data[:, start_head * dim : end_head * dim])
            if module.to_out[0].bias is not None:
                sharded_to_out.bias.data.copy_(module.to_out[0].bias.data)

            # Swap in the sharded projections and free the full-size weights.
            # `module.heads` now holds the local head count.
            del module.to_q
            del module.to_k
            del module.to_v
            old_to_out = module.to_out[0]

            module.to_q = sharded_to_q
            module.to_k = sharded_to_k
            module.to_v = sharded_to_v
            module.to_out[0] = sharded_to_out
            module.heads = sliced_heads

            del old_to_out
            torch.cuda.empty_cache()

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        **cross_attention_kwargs,
    ) -> torch.Tensor:
        distri_config = self.distri_config
        module = self.module
        residual = hidden_states

        # Each rank attends only with its local heads; a rank that owns no heads
        # contributes a zero tensor to the all-reduce below.
        if self.sliced_heads > 0:
            input_ndim = hidden_states.ndim
            assert input_ndim == 3

            batch_size, sequence_length, _ = (
                hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
            )

            if attention_mask is not None:
                attention_mask = module.prepare_attention_mask(attention_mask, sequence_length, batch_size)
                # scaled_dot_product_attention expects attention_mask shape to be
                # (batch, heads, source_length, target_length)
                attention_mask = attention_mask.view(batch_size, module.heads, -1, attention_mask.shape[-1])

            if module.group_norm is not None:
                hidden_states = module.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

            query = module.to_q(hidden_states)

            if encoder_hidden_states is None:
                encoder_hidden_states = hidden_states
            elif module.norm_cross:
                encoder_hidden_states = module.norm_encoder_hidden_states(encoder_hidden_states)

            key = module.to_k(encoder_hidden_states)
            value = module.to_v(encoder_hidden_states)

            inner_dim = key.shape[-1]
            head_dim = inner_dim // module.heads

            query = query.view(batch_size, -1, module.heads, head_dim).transpose(1, 2)
            key = key.view(batch_size, -1, module.heads, head_dim).transpose(1, 2)
            value = value.view(batch_size, -1, module.heads, head_dim).transpose(1, 2)

            # the output of sdp = (batch, num_heads, seq_len, head_dim)
            # TODO: add support for attn.scale when we move to Torch 2.1
            hidden_states = F.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )

            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, module.heads * head_dim)
            hidden_states = hidden_states.to(query.dtype)

            # linear proj (the bias is added once, after the all-reduce below)
            hidden_states = F.linear(hidden_states, module.to_out[0].weight, bias=None)
            # dropout
            hidden_states = module.to_out[1](hidden_states)
        else:
            hidden_states = torch.zeros(
                [hidden_states.shape[0], hidden_states.shape[1], module.to_out[0].out_features],
                device=hidden_states.device,
                dtype=hidden_states.dtype,
            )

        # Sum the partial output projections across ranks, then apply the output
        # bias exactly once.
        dist.all_reduce(hidden_states, op=dist.ReduceOp.SUM, group=distri_config.batch_group, async_op=False)
        if module.to_out[0].bias is not None:
            hidden_states = hidden_states + module.to_out[0].bias.view(1, 1, -1)

        if module.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / module.rescale_output_factor

        self.counter += 1

        return hidden_states
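

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): it replays the
# head-partitioning arithmetic from DistriAttentionTP.__init__ for every rank,
# so the per-rank [start_head, end_head) ranges can be inspected without any
# distributed setup. `heads` and `n_device_per_batch` are hypothetical example
# values, not taken from a real DistriConfig.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    heads, n_device_per_batch = 10, 4  # example values only

    for split_idx in range(n_device_per_batch):
        sliced_heads = heads // n_device_per_batch
        remainder_heads = heads % n_device_per_batch
        if split_idx < remainder_heads:
            # The first `remainder_heads` ranks take one extra head each.
            sliced_heads += 1
            start_head = split_idx * sliced_heads
        else:
            start_head = remainder_heads * (sliced_heads + 1) + (split_idx - remainder_heads) * sliced_heads
        end_head = start_head + sliced_heads
        print(f"rank {split_idx}: heads [{start_head}, {end_head}) -> {sliced_heads} local heads")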