{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from src.pipelines.pipeline_superres_sdxl import StableDiffusionXLSuperResPipeline\n",
"from diffusers import AutoPipelineForText2Image\n",
"import torch\n",
"\n",
"from src.linfusion import LinFusion"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"model_ckpt = \"stabilityai/stable-diffusion-xl-base-1.0\"\n",
"device = torch.device('cuda')\n",
"\n",
"pipe = AutoPipelineForText2Image.from_pretrained(\n",
" model_ckpt, torch_dtype=torch.float16, variant=\"fp16\"\n",
").to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"prompt = \"An astronaut floating in space. Beautiful view of the stars and the universe in the background.\"\n",
"generator = torch.manual_seed(123)\n",
"image = pipe(\n",
" prompt, generator=generator\n",
").images[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"image"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"pipe = StableDiffusionXLSuperResPipeline.from_pretrained(\n",
" model_ckpt, torch_dtype=torch.float16, variant=\"fp16\"\n",
").to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"linfusion = LinFusion.construct_for(pipe)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"generator = torch.manual_seed(123)\n",
"pipe.enable_vae_tiling()\n",
"image = pipe(image=image, prompt=prompt,\n",
" height=2048, width=2048, device=device, \n",
" num_inference_steps=50, guidance_scale=7.5,\n",
" cosine_scale_1=3, cosine_scale_2=1, cosine_scale_3=1, gaussian_sigma=0.8,\n",
" generator=generator, upscale_strength=0.32).images[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"image"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from src.pipelines.pipeline_superres_sdxl import StableDiffusionXLSuperResPipeline\n",
"from diffusers import AutoPipelineForText2Image\n",
"import torch\n",
"\n",
"from src.tools import (\n",
" forward_unet_wrapper, \n",
" forward_resnet_wrapper, \n",
" forward_crossattndownblock2d_wrapper, \n",
" forward_crossattnupblock2d_wrapper,\n",
" forward_downblock2d_wrapper, \n",
" forward_upblock2d_wrapper,\n",
" forward_transformer_block_wrapper)\n",
"from src.linfusion import LinFusion"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model_ckpt = \"stabilityai/stable-diffusion-xl-base-1.0\"\n",
"device = torch.device('cuda:3')\n",
"\n",
"pipe = AutoPipelineForText2Image.from_pretrained(\n",
" model_ckpt, torch_dtype=torch.float16, variant=\"fp16\"\n",
").to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"prompt = \"An astronaut floating in space. Beautiful view of the stars and the universe in the background.\"\n",
"generator = torch.manual_seed(0)\n",
"image = pipe(\n",
" prompt, height=512, width=1024, generator=generator\n",
").images[0]\n",
"image"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pipe = StableDiffusionXLSuperResPipeline.from_pretrained(\n",
" model_ckpt, torch_dtype=torch.float16, variant=\"fp16\"\n",
").to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"linfusion = LinFusion.construct_for(pipe)\n",
"pipe.enable_vae_tiling()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"generator = torch.manual_seed(0)\n",
"image = pipe(image=image, prompt=prompt,\n",
" height=1024, width=2048, device=device, \n",
" num_inference_steps=50, guidance_scale=7.5,\n",
" cosine_scale_1=3, cosine_scale_2=1, cosine_scale_3=1, gaussian_sigma=0.8,\n",
" generator=generator, upscale_strength=0.32).images[0]\n",
"image"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"generator = torch.manual_seed(0)\n",
"image = pipe(image=image, prompt=prompt,\n",
" height=2048, width=4096, device=device, \n",
" num_inference_steps=50, guidance_scale=7.5,\n",
" cosine_scale_1=3, cosine_scale_2=1, cosine_scale_3=1, gaussian_sigma=0.8,\n",
" generator=generator, upscale_strength=0.24).images[0]\n",
"image"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"for _, _module in pipe.unet.named_modules():\n",
" if _module.__class__.__name__ == 'BasicTransformerBlock':\n",
" _module.set_chunk_feed_forward(16, 1)\n",
" _module.forward = forward_transformer_block_wrapper(_module)\n",
" elif _module.__class__.__name__ == 'ResnetBlock2D':\n",
" _module.nonlinearity.inplace = True\n",
" _module.forward = forward_resnet_wrapper(_module)\n",
" elif _module.__class__.__name__ == 'CrossAttnDownBlock2D':\n",
" _module.forward = forward_crossattndownblock2d_wrapper(_module)\n",
" elif _module.__class__.__name__ == 'DownBlock2D':\n",
" _module.forward = forward_downblock2d_wrapper(_module)\n",
" elif _module.__class__.__name__ == 'CrossAttnUpBlock2D':\n",
" _module.forward = forward_crossattnupblock2d_wrapper(_module)\n",
" elif _module.__class__.__name__ == 'UpBlock2D':\n",
" _module.forward = forward_upblock2d_wrapper(_module) \n",
"\n",
"pipe.unet.forward = forward_unet_wrapper(pipe.unet)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"generator = torch.manual_seed(0)\n",
"image = pipe(image=image, prompt=prompt,\n",
" height=4096, width=8192, device=device, \n",
" num_inference_steps=50, guidance_scale=7.5,\n",
" cosine_scale_1=3, cosine_scale_2=1, cosine_scale_3=1, gaussian_sigma=0.8,\n",
" generator=generator, upscale_strength=0.16).images[0]\n",
"image"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"generator = torch.manual_seed(0)\n",
"image = pipe(image=image, prompt=prompt,\n",
" height=8192, width=16384, device=device, \n",
" num_inference_steps=50, guidance_scale=7.5,\n",
" cosine_scale_1=3, cosine_scale_2=1, cosine_scale_3=1, gaussian_sigma=0.8,\n",
" generator=generator, upscale_strength=0.08).images[0]\n",
"image"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "pt2",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.19"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from src.pipelines.pipeline_highres_sdxl import StableDiffusionXLHighResPipeline\n",
"import torch\n",
"\n",
"from src.linfusion import LinFusion\n",
"from src.tools import seed_everything"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"model_ckpt = \"stabilityai/stable-diffusion-xl-base-1.0\"\n",
"device = torch.device('cuda')\n",
"pipe = StableDiffusionXLHighResPipeline.from_pretrained(\n",
" model_ckpt, torch_dtype=torch.float16, variant='fp16'\n",
").to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"linfusion = LinFusion.construct_for(pipe)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"prompt = \"An astronaut floating in space. Beautiful view of the stars and the universe in the background.\"\n",
"generator = torch.manual_seed(42)\n",
"pipe.enable_vae_tiling()\n",
"images = pipe(prompt,\n",
" height=1024, width=2048, device=device,\n",
" num_inference_steps=50, guidance_scale=7.5,\n",
" cosine_scale_1=3, cosine_scale_2=1, cosine_scale_3=1, gaussian_sigma=0.8,\n",
" show_image=True, generator=generator, upscale_strength=0.32)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "bbf4cab5-6bf8-40d9-9f1a-4ffd6d128cd2",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline\n",
"import torch\n",
"import os\n",
"import gc\n",
"from PIL import Image\n",
"\n",
"from src.linfusion import LinFusion"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3fa3e7e1-0413-4668-882f-e91bbf20512b",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"pipeline = StableDiffusionPipeline.from_pretrained(\n",
" \"Lykon/dreamshaper-8\", torch_dtype=torch.float16, variant=\"fp16\"\n",
").to(torch.device(\"cuda\"))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "825d6621-4cce-4f44-aec3-9d7cb68f75a9",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"linfusion = LinFusion.construct_for(pipeline, pretrained_model_name_or_path=\"Yuanshi/LinFusion-1-5\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0ee030a1-fe18-4457-a842-b088c3c5d48c",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"if not os.path.exists('results'):\n",
" os.mkdir('results')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "831a565e-009e-4cff-a903-5eafc2321724",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"generator = torch.manual_seed(3)\n",
"image = pipeline(\n",
" \"A photo of the Milky Way galaxy\",\n",
" height=512,\n",
" width=1024,\n",
" generator=generator\n",
").images[0]\n",
"image.save('results/output_1k.jpg')\n",
"image"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b3f4ae07-d6a3-425f-940f-6d87000abcb5",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(\n",
" \"Lykon/dreamshaper-8\", torch_dtype=torch.float16, variant=\"fp16\"\n",
").to(torch.device(\"cuda\"))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "50d23376-cc8f-4d03-9aa2-0a16d72a259d",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"linfusion = LinFusion.construct_for(pipeline, pretrained_model_name_or_path=\"Yuanshi/LinFusion-1-5\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "71e3f3e0-7b8a-404c-b3c1-ed0dfb4c69e9",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"init_image = image.resize((2048, 1024))\n",
"generator = torch.manual_seed(3)\n",
"image = pipeline(\n",
" \"A photo of the Milky Way galaxy\",\n",
" image=init_image, strength=0.4, generator=generator).images[0]\n",
"image.save('results/output_2k.jpg')\n",
"image"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f373053d-277e-4710-8a1a-9cf631c81229",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"pipeline.enable_vae_tiling()\n",
"pipeline.vae.tile_sample_min_size = 2048\n",
"pipeline.vae.tile_latent_min_size = 2048 // 8"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5ca97038-9e81-465e-ab91-a055a926d87e",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"init_image = image.resize((4096, 2048))\n",
"generator = torch.manual_seed(3)\n",
"image = pipeline(\n",
" \"A photo of the Milky Way galaxy\",\n",
" image=init_image, strength=0.3, generator=generator).images[0]\n",
"image.save('results/output_4k.jpg')\n",
"image"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "50350beb-5045-4951-a227-fb2269256b2a",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"init_image = image.resize((8192, 4096))\n",
"generator = torch.manual_seed(3)\n",
"image = pipeline(\n",
" \"A photo of the Milky Way galaxy\",\n",
" image=init_image, strength=0.2, generator=generator).images[0]\n",
"image.save('results/output_8k.jpg')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b65d6254-989c-4685-bf4b-39288b14a3c9",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"gc.collect()\n",
"torch.cuda.empty_cache()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "81bd0a32-e1c9-464d-84d9-42bb83d74f46",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"init_image = image.resize((16384, 8192))\n",
"generator = torch.manual_seed(3)\n",
"image = pipeline(\n",
" \"A photo of the Milky Way galaxy\",\n",
" image=init_image, strength=0.1, generator=generator).images[0]\n",
"image.save('results/output_16k.jpg')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ee57bcbf-d248-4d1b-b9cf-f7f3fb2b2ccc",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
accelerate launch --num_processes 8 --multi_gpu --mixed_precision "bf16" --main_process_port 29500 \
-m src.train.distill \
--pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1" \
--mixed_precision="bf16" \
--resolution=768 \
--train_batch_size=6 \
--gradient_accumulation_steps=2 \
--dataloader_num_workers=6 \
--learning_rate=1e-04 \
--weight_decay=0. \
--output_dir="ckpt/linfusion_sd2p1" \
--save_steps=10000
accelerate launch --num_processes 8 --multi_gpu --mixed_precision "fp16" --main_process_port 29500 \
-m src.train.distil_sdxl \
--pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
--mixed_precision="fp16" \
--resolution=1024 \
--train_batch_size=2 \
--gradient_accumulation_steps=4 \
--dataloader_num_workers=6 \
--learning_rate=1e-04 \
--weight_decay=0. \
--output_dir="ckpt/linfusion_sdxl" \
--save_steps=10000 \
--mid_dim_scale=16
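
# Usage sketch (not part of the training commands above): a minimal example of plugging a
# distilled checkpoint back into a pipeline. It assumes the run writes a LinFusion checkpoint
# under --output_dir and that LinFusion.construct_for accepts a local path through the same
# pretrained_model_name_or_path keyword the notebooks above use with the
# "Yuanshi/LinFusion-1-5" Hub checkpoint; adjust the base model and path for the SD 2.1 run.
import torch
from diffusers import AutoPipelineForText2Image

from src.linfusion import LinFusion

pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16"
).to("cuda")
linfusion = LinFusion.construct_for(pipe, pretrained_model_name_or_path="ckpt/linfusion_sdxl")
image = pipe("An astronaut floating in space.").images[0]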
# Unique model identifier
modelCode=1057
# Model name
modelName=linfusion_pytorch
# Model description
modelDescription=Efficient high-resolution image generation model supporting text-to-image, image-to-image, and related tasks
# Application scenarios
appScenario=inference,training,text-to-image,research,education,government,finance
# Framework type
frameType=Pytorch
__version__ = "0.0.1beta1"
from diffusers import ConfigMixin, ModelMixin
from torch import nn
from ..modules.base_module import BaseModule
from ..utils import PatchParallelismCommManager, DistriConfig
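# Base wrapper shared by every parallel UNet variant: it holds the wrapped diffusers model,
# the DistriConfig, the patch-parallelism communication manager, reusable gather/output
# buffers, a per-step counter, and the recorded CUDA-graph state.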
class BaseModel(ModelMixin, ConfigMixin):
def __init__(self, model: nn.Module, distri_config: DistriConfig):
super(BaseModel, self).__init__()
self.model = model
self.distri_config = distri_config
self.comm_manager = None
self.buffer_list = None
self.output_buffer = None
self.counter = 0
# for cuda graph
self.static_inputs = None
self.static_outputs = None
self.cuda_graphs = None
def forward(self, *args, **kwargs):
raise NotImplementedError
def set_counter(self, counter: int = 0):
self.counter = counter
for module in self.model.modules():
if isinstance(module, BaseModule):
module.set_counter(counter)
def set_comm_manager(self, comm_manager: PatchParallelismCommManager):
self.comm_manager = comm_manager
for module in self.model.modules():
if isinstance(module, BaseModule):
module.set_comm_manager(comm_manager)
def setup_cuda_graph(self, static_outputs, cuda_graphs):
self.static_outputs = static_outputs
self.cuda_graphs = cuda_graphs
@property
def config(self):
return self.model.config
def synchronize(self):
if self.comm_manager is not None and self.comm_manager.handles is not None:
for i in range(len(self.comm_manager.handles)):
if self.comm_manager.handles[i] is not None:
self.comm_manager.handles[i].wait()
self.comm_manager.handles[i] = None
import torch
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import Attention
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
from torch import distributed as dist, nn
from .base_model import BaseModel
from ..modules.pp.attn import DistriCrossAttentionPP, DistriSelfAttentionPP, DistriGeneralizedLinearAttentionPP
from ..modules.base_module import BaseModule
from ..modules.pp.conv2d import DistriConv2dPP
from ..modules.pp.groupnorm import DistriGroupNorm
from ..utils import DistriConfig
from ...linfusion import GeneralizedLinearAttention
from typing import Any, Dict, Optional, Tuple, Union
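# Patch Parallelism: each device in a batch group processes a horizontal slice of the latent.
# Conv, attention (including LinFusion's generalized linear attention), and group-norm
# submodules are swapped for distributed wrappers that exchange the activations they need,
# and forward() re-assembles the per-device noise predictions with all_gather along height.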
class DistriUNetPP(BaseModel): # for Patch Parallelism
def __init__(self, model: UNet2DConditionModel, distri_config: DistriConfig):
assert isinstance(model, UNet2DConditionModel)
if distri_config.world_size > 1 and distri_config.n_device_per_batch > 1:
for name, module in model.named_modules():
if isinstance(module, BaseModule):
continue
for subname, submodule in module.named_children():
if isinstance(submodule, nn.Conv2d):
kernel_size = submodule.kernel_size
if kernel_size == (1, 1) or kernel_size == 1:
continue
wrapped_submodule = DistriConv2dPP(
submodule, distri_config, is_first_layer=subname == "conv_in"
)
setattr(module, subname, wrapped_submodule)
elif isinstance(submodule, GeneralizedLinearAttention):
wrapped_submodule = DistriGeneralizedLinearAttentionPP(submodule, distri_config)
setattr(module, subname, wrapped_submodule)
elif isinstance(submodule, Attention):
if subname == "attn1": # self attention
wrapped_submodule = DistriSelfAttentionPP(submodule, distri_config)
else: # cross attention
assert subname == "attn2"
wrapped_submodule = DistriCrossAttentionPP(submodule, distri_config)
setattr(module, subname, wrapped_submodule)
elif isinstance(submodule, nn.GroupNorm):
wrapped_submodule = DistriGroupNorm(submodule, distri_config)
setattr(module, subname, wrapped_submodule)
super(DistriUNetPP, self).__init__(model, distri_config)
def forward(
self,
sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
record: bool = False,
):
distri_config = self.distri_config
b, c, h, w = sample.shape
assert (
class_labels is None
and timestep_cond is None
and attention_mask is None
and cross_attention_kwargs is None
and down_block_additional_residuals is None
and mid_block_additional_residual is None
and down_intrablock_additional_residuals is None
and encoder_attention_mask is None
)
if distri_config.use_cuda_graph and not record:
static_inputs = self.static_inputs
if distri_config.world_size > 1 and distri_config.do_classifier_free_guidance and distri_config.split_batch:
assert b == 2
batch_idx = distri_config.batch_idx()
sample = sample[batch_idx : batch_idx + 1]
timestep = (
timestep[batch_idx : batch_idx + 1] if torch.is_tensor(timestep) and timestep.ndim > 0 else timestep
)
encoder_hidden_states = encoder_hidden_states[batch_idx : batch_idx + 1]
if added_cond_kwargs is not None:
for k in added_cond_kwargs:
added_cond_kwargs[k] = added_cond_kwargs[k][batch_idx : batch_idx + 1]
assert static_inputs["sample"].shape == sample.shape
static_inputs["sample"].copy_(sample)
if torch.is_tensor(timestep):
if timestep.ndim == 0:
for b in range(static_inputs["timestep"].shape[0]):
static_inputs["timestep"][b] = timestep.item()
else:
assert static_inputs["timestep"].shape == timestep.shape
static_inputs["timestep"].copy_(timestep)
else:
for b in range(static_inputs["timestep"].shape[0]):
static_inputs["timestep"][b] = timestep
assert static_inputs["encoder_hidden_states"].shape == encoder_hidden_states.shape
static_inputs["encoder_hidden_states"].copy_(encoder_hidden_states)
if added_cond_kwargs is not None:
for k in added_cond_kwargs:
assert static_inputs["added_cond_kwargs"][k].shape == added_cond_kwargs[k].shape
static_inputs["added_cond_kwargs"][k].copy_(added_cond_kwargs[k])
if self.counter <= distri_config.warmup_steps:
graph_idx = 0
elif self.counter == distri_config.warmup_steps + 1:
graph_idx = 1
else:
graph_idx = 2
self.cuda_graphs[graph_idx].replay()
output = self.static_outputs[graph_idx]
else:
if distri_config.world_size == 1:
output = self.model(
sample,
timestep,
encoder_hidden_states,
class_labels=class_labels,
timestep_cond=timestep_cond,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
elif distri_config.do_classifier_free_guidance and distri_config.split_batch:
assert b == 2
batch_idx = distri_config.batch_idx()
sample = sample[batch_idx : batch_idx + 1]
timestep = (
timestep[batch_idx : batch_idx + 1] if torch.is_tensor(timestep) and timestep.ndim > 0 else timestep
)
encoder_hidden_states = encoder_hidden_states[batch_idx : batch_idx + 1]
if added_cond_kwargs is not None:
new_added_cond_kwargs = {}
for k in added_cond_kwargs:
new_added_cond_kwargs[k] = added_cond_kwargs[k][batch_idx : batch_idx + 1]
added_cond_kwargs = new_added_cond_kwargs
output = self.model(
sample,
timestep,
encoder_hidden_states,
class_labels=class_labels,
timestep_cond=timestep_cond,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
if self.output_buffer is None:
self.output_buffer = torch.empty((b, c, h, w), device=output.device, dtype=output.dtype)
if self.buffer_list is None:
self.buffer_list = [torch.empty_like(output) for _ in range(distri_config.world_size)]
dist.all_gather(self.buffer_list, output.contiguous(), async_op=False)
torch.cat(self.buffer_list[: distri_config.n_device_per_batch], dim=2, out=self.output_buffer[0:1])
torch.cat(self.buffer_list[distri_config.n_device_per_batch :], dim=2, out=self.output_buffer[1:2])
output = self.output_buffer
else:
output = self.model(
sample,
timestep,
encoder_hidden_states,
class_labels=class_labels,
timestep_cond=timestep_cond,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
if self.output_buffer is None:
self.output_buffer = torch.empty((b, c, h, w), device=output.device, dtype=output.dtype)
if self.buffer_list is None:
self.buffer_list = [torch.empty_like(output) for _ in range(distri_config.world_size)]
output = output.contiguous()
dist.all_gather(self.buffer_list, output, async_op=False)
torch.cat(self.buffer_list, dim=2, out=self.output_buffer)
output = self.output_buffer
if record:
if self.static_inputs is None:
self.static_inputs = {
"sample": sample,
"timestep": timestep,
"encoder_hidden_states": encoder_hidden_states,
"added_cond_kwargs": added_cond_kwargs,
}
self.synchronize()
if return_dict:
output = UNet2DConditionOutput(sample=output)
else:
output = (output,)
self.counter += 1
return output
@property
def add_embedding(self):
return self.model.add_embedding
import torch
from diffusers import UNet2DConditionModel
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.resnet import ResnetBlock2D
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
from torch import distributed as dist, nn
from typing import Any, Dict, Optional, Tuple, Union
from ..modules.base_module import BaseModule
from .base_model import BaseModel
from ..modules.tp.attention import DistriAttentionTP
from ..modules.tp.conv2d import DistriConv2dTP
from ..modules.tp.feed_forward import DistriFeedForwardTP
from ..modules.tp.resnet import DistriResnetBlock2DTP
from ..utils import DistriConfig
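# Tensor Parallelism: attention, feed-forward, resnet, and selected conv layers are sharded
# across devices, so each device already produces a full-resolution output; only
# classifier-free-guidance batch splitting needs an all_gather at the end of forward().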
class DistriUNetTP(BaseModel): # for Tensor Parallelism
def __init__(self, model: UNet2DConditionModel, distri_config: DistriConfig):
assert isinstance(model, UNet2DConditionModel)
if distri_config.world_size > 1 and distri_config.n_device_per_batch > 1:
for name, module in model.named_modules():
if isinstance(module, BaseModule):
continue
for subname, submodule in module.named_children():
if isinstance(submodule, Attention):
wrapped_submodule = DistriAttentionTP(submodule, distri_config)
setattr(module, subname, wrapped_submodule)
elif isinstance(submodule, FeedForward):
wrapped_submodule = DistriFeedForwardTP(submodule, distri_config)
setattr(module, subname, wrapped_submodule)
elif isinstance(submodule, ResnetBlock2D):
wrapped_submodule = DistriResnetBlock2DTP(submodule, distri_config)
setattr(module, subname, wrapped_submodule)
elif isinstance(submodule, nn.Conv2d) and (
subname == "conv_out" or "downsamplers" in name or "upsamplers" in name
):
wrapped_submodule = DistriConv2dTP(submodule, distri_config)
setattr(module, subname, wrapped_submodule)
super(DistriUNetTP, self).__init__(model, distri_config)
def forward(
self,
sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
record: bool = False,
):
distri_config = self.distri_config
b, c, h, w = sample.shape
assert (
class_labels is None
and timestep_cond is None
and attention_mask is None
and cross_attention_kwargs is None
and down_block_additional_residuals is None
and mid_block_additional_residual is None
and down_intrablock_additional_residuals is None
and encoder_attention_mask is None
)
if distri_config.use_cuda_graph and not record:
static_inputs = self.static_inputs
if distri_config.world_size > 1 and distri_config.do_classifier_free_guidance and distri_config.split_batch:
assert b == 2
batch_idx = distri_config.batch_idx()
sample = sample[batch_idx : batch_idx + 1]
timestep = (
timestep[batch_idx : batch_idx + 1] if torch.is_tensor(timestep) and timestep.ndim > 0 else timestep
)
encoder_hidden_states = encoder_hidden_states[batch_idx : batch_idx + 1]
if added_cond_kwargs is not None:
for k in added_cond_kwargs:
added_cond_kwargs[k] = added_cond_kwargs[k][batch_idx : batch_idx + 1]
assert static_inputs["sample"].shape == sample.shape
static_inputs["sample"].copy_(sample)
if torch.is_tensor(timestep):
if timestep.ndim == 0:
for b in range(static_inputs["timestep"].shape[0]):
static_inputs["timestep"][b] = timestep.item()
else:
assert static_inputs["timestep"].shape == timestep.shape
static_inputs["timestep"].copy_(timestep)
else:
for b in range(static_inputs["timestep"].shape[0]):
static_inputs["timestep"][b] = timestep
assert static_inputs["encoder_hidden_states"].shape == encoder_hidden_states.shape
static_inputs["encoder_hidden_states"].copy_(encoder_hidden_states)
if added_cond_kwargs is not None:
for k in added_cond_kwargs:
assert static_inputs["added_cond_kwargs"][k].shape == added_cond_kwargs[k].shape
static_inputs["added_cond_kwargs"][k].copy_(added_cond_kwargs[k])
graph_idx = 0
self.cuda_graphs[graph_idx].replay()
output = self.static_outputs[graph_idx]
else:
if distri_config.world_size == 1:
output = self.model(
sample,
timestep,
encoder_hidden_states,
class_labels=class_labels,
timestep_cond=timestep_cond,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
elif distri_config.do_classifier_free_guidance and distri_config.split_batch:
assert b == 2
batch_idx = distri_config.batch_idx()
sample = sample[batch_idx : batch_idx + 1]
timestep = (
timestep[batch_idx : batch_idx + 1] if torch.is_tensor(timestep) and timestep.ndim > 0 else timestep
)
encoder_hidden_states = encoder_hidden_states[batch_idx : batch_idx + 1]
if added_cond_kwargs is not None:
new_added_cond_kwargs = {}
for k in added_cond_kwargs:
new_added_cond_kwargs[k] = added_cond_kwargs[k][batch_idx : batch_idx + 1]
added_cond_kwargs = new_added_cond_kwargs
output = self.model(
sample,
timestep,
encoder_hidden_states,
class_labels=class_labels,
timestep_cond=timestep_cond,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
if self.output_buffer is None:
self.output_buffer = torch.empty((b, c, h, w), device=output.device, dtype=output.dtype)
if self.buffer_list is None:
self.buffer_list = [torch.empty_like(output) for _ in range(2)]
dist.all_gather(
self.buffer_list, output.contiguous(), group=distri_config.split_group(), async_op=False
)
torch.cat(self.buffer_list, dim=0, out=self.output_buffer)
output = self.output_buffer
else:
output = self.model(
sample,
timestep,
encoder_hidden_states,
class_labels=class_labels,
timestep_cond=timestep_cond,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
if self.output_buffer is None:
self.output_buffer = torch.empty_like(output)
self.output_buffer.copy_(output)
output = self.output_buffer
if record:
if self.static_inputs is None:
self.static_inputs = {
"sample": sample,
"timestep": timestep,
"encoder_hidden_states": encoder_hidden_states,
"added_cond_kwargs": added_cond_kwargs,
}
self.synchronize()
if return_dict:
output = UNet2DConditionOutput(sample=output)
else:
output = (output,)
self.counter += 1
return output
@property
def add_embedding(self):
return self.model.add_embedding
import torch
from diffusers import UNet2DConditionModel
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
from torch import distributed as dist
from typing import Any, Dict, Optional, Tuple, Union
from .base_model import BaseModel
from ..utils import DistriConfig
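# Naive patch baseline: the latent is sliced along height or width (per split_scheme), each
# device runs the unmodified UNet on its own slice with no intermediate communication, and
# the slices are re-assembled with a single all_gather per denoising step.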
class NaivePatchUNet(BaseModel): # for Patch Parallelism
def __init__(self, model: UNet2DConditionModel, distri_config: DistriConfig):
assert isinstance(model, UNet2DConditionModel)
super(NaivePatchUNet, self).__init__(model, distri_config)
def forward(
self,
sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
record: bool = False,
):
distri_config = self.distri_config
b, c, h, w = sample.shape
assert (
class_labels is None
and timestep_cond is None
and attention_mask is None
and cross_attention_kwargs is None
and down_block_additional_residuals is None
and mid_block_additional_residual is None
and down_intrablock_additional_residuals is None
and encoder_attention_mask is None
)
if distri_config.use_cuda_graph and not record:
static_inputs = self.static_inputs
if distri_config.world_size > 1 and distri_config.do_classifier_free_guidance and distri_config.split_batch:
assert b == 2
batch_idx = distri_config.batch_idx()
sample = sample[batch_idx : batch_idx + 1]
timestep = (
timestep[batch_idx : batch_idx + 1] if torch.is_tensor(timestep) and timestep.ndim > 0 else timestep
)
encoder_hidden_states = encoder_hidden_states[batch_idx : batch_idx + 1]
if added_cond_kwargs is not None:
for k in added_cond_kwargs:
added_cond_kwargs[k] = added_cond_kwargs[k][batch_idx : batch_idx + 1]
assert static_inputs["sample"].shape == sample.shape
static_inputs["sample"].copy_(sample)
if torch.is_tensor(timestep):
if timestep.ndim == 0:
for b in range(static_inputs["timestep"].shape[0]):
static_inputs["timestep"][b] = timestep.item()
else:
assert static_inputs["timestep"].shape == timestep.shape
static_inputs["timestep"].copy_(timestep)
else:
for b in range(static_inputs["timestep"].shape[0]):
static_inputs["timestep"][b] = timestep
assert static_inputs["encoder_hidden_states"].shape == encoder_hidden_states.shape
static_inputs["encoder_hidden_states"].copy_(encoder_hidden_states)
if added_cond_kwargs is not None:
for k in added_cond_kwargs:
assert static_inputs["added_cond_kwargs"][k].shape == added_cond_kwargs[k].shape
static_inputs["added_cond_kwargs"][k].copy_(added_cond_kwargs[k])
graph_idx = 0
if distri_config.split_scheme == "alternate":
graph_idx = self.counter % 2
self.cuda_graphs[graph_idx].replay()
output = self.static_outputs[graph_idx]
else:
if distri_config.world_size == 1:
output = self.model(
sample,
timestep,
encoder_hidden_states,
class_labels=class_labels,
timestep_cond=timestep_cond,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
elif distri_config.do_classifier_free_guidance and distri_config.split_batch:
assert b == 2
batch_idx = distri_config.batch_idx()
sample = sample[batch_idx : batch_idx + 1]
timestep = (
timestep[batch_idx : batch_idx + 1] if torch.is_tensor(timestep) and timestep.ndim > 0 else timestep
)
encoder_hidden_states = encoder_hidden_states[batch_idx : batch_idx + 1]
if added_cond_kwargs is not None:
new_added_cond_kwargs = {}
for k in added_cond_kwargs:
new_added_cond_kwargs[k] = added_cond_kwargs[k][batch_idx : batch_idx + 1]
added_cond_kwargs = new_added_cond_kwargs
if distri_config.split_scheme == "row":
split_dim = 2
elif distri_config.split_scheme == "col":
split_dim = 3
elif distri_config.split_scheme == "alternate":
split_dim = 2 if self.counter % 2 == 0 else 3
else:
raise NotImplementedError
if split_dim == 2:
sample = sample.view(1, c, distri_config.n_device_per_batch, -1, w)[:, :, distri_config.split_idx()]
else:
assert split_dim == 3
sample = sample.view(1, c, h, distri_config.n_device_per_batch, -1)[
..., distri_config.split_idx(), :
]
output = self.model(
sample,
timestep,
encoder_hidden_states,
class_labels=class_labels,
timestep_cond=timestep_cond,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
if self.output_buffer is None:
self.output_buffer = torch.empty((b, c, h, w), device=output.device, dtype=output.dtype)
if self.buffer_list is None:
self.buffer_list = [torch.empty_like(output.view(-1)) for _ in range(distri_config.world_size)]
dist.all_gather(self.buffer_list, output.contiguous().view(-1), async_op=False)
buffer_list = [buffer.view(output.shape) for buffer in self.buffer_list]
torch.cat(buffer_list[: distri_config.n_device_per_batch], dim=split_dim, out=self.output_buffer[0:1])
torch.cat(buffer_list[distri_config.n_device_per_batch :], dim=split_dim, out=self.output_buffer[1:2])
output = self.output_buffer
else:
if distri_config.split_scheme == "row":
split_dim = 2
elif distri_config.split_scheme == "col":
split_dim = 3
elif distri_config.split_scheme == "alternate":
split_dim = 2 if self.counter % 2 == 0 else 3
else:
raise NotImplementedError
if split_dim == 2:
sliced_sample = sample.view(b, c, distri_config.n_device_per_batch, -1, w)[
:, :, distri_config.split_idx()
]
else:
assert split_dim == 3
sliced_sample = sample.view(b, c, h, distri_config.n_device_per_batch, -1)[
..., distri_config.split_idx(), :
]
output = self.model(
sliced_sample,
timestep,
encoder_hidden_states,
class_labels=class_labels,
timestep_cond=timestep_cond,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
if self.output_buffer is None:
self.output_buffer = torch.empty((b, c, h, w), device=output.device, dtype=output.dtype)
if self.buffer_list is None:
self.buffer_list = [torch.empty_like(output.view(-1)) for _ in range(distri_config.world_size)]
dist.all_gather(self.buffer_list, output.contiguous().view(-1), async_op=False)
buffer_list = [buffer.view(output.shape) for buffer in self.buffer_list]
torch.cat(buffer_list, dim=split_dim, out=self.output_buffer)
output = self.output_buffer
if record:
if self.static_inputs is None:
self.static_inputs = {
"sample": sample,
"timestep": timestep,
"encoder_hidden_states": encoder_hidden_states,
"added_cond_kwargs": added_cond_kwargs,
}
self.synchronize()
if return_dict:
output = UNet2DConditionOutput(sample=output)
else:
output = (output,)
self.counter += 1
return output
@property
def add_embedding(self):
return self.model.add_embedding
from torch import nn
from ..utils import DistriConfig
class BaseModule(nn.Module):
def __init__(
self,
module: nn.Module,
distri_config: DistriConfig,
):
super(BaseModule, self).__init__()
self.module = module
self.distri_config = distri_config
self.comm_manager = None
self.counter = 0
self.buffer_list = None
self.idx = None
def forward(self, *args, **kwargs):
raise NotImplementedError
def set_counter(self, counter: int = 0):
self.counter = counter
def set_comm_manager(self, comm_manager):
self.comm_manager = comm_manager
from typing import Optional

import torch
from diffusers.models.attention import Attention
from torch import distributed as dist
from torch import nn
from torch.nn import functional as F
from ..base_module import BaseModule
from ...utils import DistriConfig
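# Fuses the wrapped Attention layer's to_k and to_v projections into a single to_kv linear,
# so the key and value activations can be exchanged between devices in one buffer.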
class DistriAttentionPP(BaseModule):
def __init__(self, module: Attention, distri_config: DistriConfig):
super(DistriAttentionPP, self).__init__(module, distri_config)
to_k = module.to_k
to_v = module.to_v
assert isinstance(to_k, nn.Linear)
assert isinstance(to_v, nn.Linear)
assert (to_k.bias is None) == (to_v.bias is None)
assert to_k.weight.shape == to_v.weight.shape
in_size, out_size = to_k.in_features, to_k.out_features
to_kv = nn.Linear(
in_size,
out_size * 2,
bias=to_k.bias is not None,
device=to_k.weight.device,
dtype=to_k.weight.dtype,
)
to_kv.weight.data[:out_size].copy_(to_k.weight.data)
to_kv.weight.data[out_size:].copy_(to_v.weight.data)
if to_k.bias is not None:
assert to_v.bias is not None
to_kv.bias.data[:out_size].copy_(to_k.bias.data)
to_kv.bias.data[out_size:].copy_(to_v.bias.data)
self.to_kv = to_kv
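# Cross-attention keys/values depend only on the text embeddings, which do not change across
# denoising steps, so they are computed once on the first step and cached in kv_cache.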
class DistriCrossAttentionPP(DistriAttentionPP):
def __init__(self, module: Attention, distri_config: DistriConfig):
super(DistriCrossAttentionPP, self).__init__(module, distri_config)
self.kv_cache = None
def forward(
self,
hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
*args,
**kwargs,
):
assert encoder_hidden_states is not None
recompute_kv = self.counter == 0
attn = self.module
assert isinstance(attn, Attention)
residual = hidden_states
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
if recompute_kv or self.kv_cache is None:
kv = self.to_kv(encoder_hidden_states)
self.kv_cache = kv
else:
kv = self.kv_cache
key, value = torch.split(kv, kv.shape[-1] // 2, dim=-1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
self.counter += 1
return hidden_states
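# Self-attention keys/values of the local patch are all_gathered across devices during the
# warmup steps (or in full_sync mode); afterwards, cached activations from the other devices
# are reused and the fresh local kv is enqueued for asynchronous exchange.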
class DistriSelfAttentionPP(DistriAttentionPP):
def __init__(self, module: Attention, distri_config: DistriConfig):
super(DistriSelfAttentionPP, self).__init__(module, distri_config)
def _forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0):
attn = self.module
distri_config = self.distri_config
assert isinstance(attn, Attention)
residual = hidden_states
batch_size, sequence_length, _ = hidden_states.shape
query = attn.to_q(hidden_states)
encoder_hidden_states = hidden_states
kv = self.to_kv(encoder_hidden_states)
if distri_config.n_device_per_batch == 1:
full_kv = kv
else:
if self.buffer_list is None: # buffer not created
full_kv = torch.cat([kv for _ in range(distri_config.n_device_per_batch)], dim=1)
elif distri_config.mode == "full_sync" or self.counter <= distri_config.warmup_steps:
dist.all_gather(self.buffer_list, kv, group=distri_config.batch_group, async_op=False)
full_kv = torch.cat(self.buffer_list, dim=1)
else:
new_buffer_list = [buffer for buffer in self.buffer_list]
new_buffer_list[distri_config.split_idx()] = kv
full_kv = torch.cat(new_buffer_list, dim=1)
if distri_config.mode != "no_sync":
self.comm_manager.enqueue(self.idx, kv)
key, value = torch.split(full_kv, full_kv.shape[-1] // 2, dim=-1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
def forward(
self,
hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
*args,
**kwargs,
) -> torch.FloatTensor:
distri_config = self.distri_config
if self.comm_manager is not None and self.comm_manager.handles is not None and self.idx is not None:
if self.comm_manager.handles[self.idx] is not None:
self.comm_manager.handles[self.idx].wait()
self.comm_manager.handles[self.idx] = None
b, l, c = hidden_states.shape
if distri_config.n_device_per_batch > 1 and self.buffer_list is None:
if self.comm_manager.buffer_list is None:
self.idx = self.comm_manager.register_tensor(
shape=(b, l, self.to_kv.out_features), torch_dtype=hidden_states.dtype, layer_type="attn"
)
else:
self.buffer_list = self.comm_manager.get_buffer_list(self.idx)
output = self._forward(hidden_states, scale=scale)
self.counter += 1
return output
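# LinFusion's generalized linear attention: each device builds a compact per-patch summary
# (the key-value matrix plus the key mean, concatenated into kv), and the summaries are
# averaged across devices rather than concatenated, since linear attention pools over tokens.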
class DistriGeneralizedLinearAttentionPP(BaseModule):
def __init__(self, module: Attention, distri_config: DistriConfig):
super(DistriGeneralizedLinearAttentionPP, self).__init__(module, distri_config)
def _forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0):
attn = self.module
distri_config = self.distri_config
assert isinstance(attn, Attention)
residual = hidden_states
batch_size, sequence_length, _ = hidden_states.shape
query = attn.to_q(hidden_states + attn.to_q_(hidden_states))
encoder_hidden_states = hidden_states
key = attn.to_k(encoder_hidden_states + attn.to_k_(encoder_hidden_states))
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
query = F.elu(query) + 1.0
key = F.elu(key) + 1.0
z = key.mean(dim=-2, keepdim=True).transpose(-2, -1)
kv = (key.transpose(-2, -1) * (sequence_length**-0.5)) @ (
value * (sequence_length**-0.5)
)
kv = torch.cat([kv, z], dim=-1)
if distri_config.n_device_per_batch == 1:
full_kv = kv
else:
if self.buffer_list is None: # buffer not created
full_kv = kv
elif distri_config.mode == "full_sync" or self.counter <= distri_config.warmup_steps:
dist.all_gather(self.buffer_list, kv, group=distri_config.batch_group, async_op=False)
full_kv = sum(self.buffer_list) / len(self.buffer_list)
else:
new_buffer_list = [buffer for buffer in self.buffer_list]
new_buffer_list[distri_config.split_idx()] = kv
full_kv = sum(new_buffer_list) / len(new_buffer_list)
if distri_config.mode != "no_sync":
self.comm_manager.enqueue(self.idx, kv)
z = full_kv[:, :, -1:]
z = query @ z + 1e-4
kv = full_kv[:, :, :-1]
hidden_states = query @ kv / z
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
def forward(
self,
hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
*args,
**kwargs,
) -> torch.FloatTensor:
distri_config = self.distri_config
if self.comm_manager is not None and self.comm_manager.handles is not None and self.idx is not None:
if self.comm_manager.handles[self.idx] is not None:
self.comm_manager.handles[self.idx].wait()
self.comm_manager.handles[self.idx] = None
b, l, c = hidden_states.shape
if distri_config.n_device_per_batch > 1 and self.buffer_list is None:
if self.comm_manager.buffer_list is None:
self.idx = self.comm_manager.register_tensor(
shape=(b * self.module.heads, c // self.module.heads, c // self.module.heads + 1),
torch_dtype=hidden_states.dtype, layer_type="attn"
)
else:
self.buffer_list = self.comm_manager.get_buffer_list(self.idx)
output = self._forward(hidden_states, scale=scale)
self.counter += 1
return output