Unverified Commit b50498fa authored by Yang Yong (雍洋), committed by GitHub

Add lightx2v_platform (#541)

parent 31da6925
@@ -8,6 +8,7 @@ import torch
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from lightx2v.models.schedulers.scheduler import BaseScheduler
+from lightx2v_platform.base.global_var import AI_DEVICE
def calculate_shift(
@@ -133,7 +134,6 @@ class QwenImageScheduler(BaseScheduler):
self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(os.path.join(config["model_path"], "scheduler"))
with open(os.path.join(config["model_path"], "scheduler", "scheduler_config.json"), "r") as f:
self.scheduler_config = json.load(f)
-self.run_device = torch.device(self.config.get("run_device", "cuda"))
self.dtype = torch.bfloat16
self.guidance_scale = 1.0
@@ -176,9 +176,9 @@ class QwenImageScheduler(BaseScheduler):
shape = input_info.target_shape
width, height = shape[-1], shape[-2]
-latents = randn_tensor(shape, generator=self.generator, device=self.run_device, dtype=self.dtype)
+latents = randn_tensor(shape, generator=self.generator, device=AI_DEVICE, dtype=self.dtype)
latents = self._pack_latents(latents, self.config["batchsize"], self.config["num_channels_latents"], height, width)
-latent_image_ids = self._prepare_latent_image_ids(self.config["batchsize"], height // 2, width // 2, self.run_device, self.dtype)
+latent_image_ids = self._prepare_latent_image_ids(self.config["batchsize"], height // 2, width // 2, AI_DEVICE, self.dtype)
self.latents = latents
self.latent_image_ids = latent_image_ids
@@ -198,7 +198,7 @@ class QwenImageScheduler(BaseScheduler):
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
-self.run_device,
+AI_DEVICE,
sigmas=sigmas,
mu=mu,
)
@@ -213,7 +213,7 @@ class QwenImageScheduler(BaseScheduler):
def prepare_guidance(self):
# handle guidance
if self.config["guidance_embeds"]:
-guidance = torch.full([1], self.guidance_scale, device=self.run_device, dtype=torch.float32)
+guidance = torch.full([1], self.guidance_scale, device=AI_DEVICE, dtype=torch.float32)
guidance = guidance.expand(self.latents.shape[0])
else:
guidance = None
@@ -223,7 +223,7 @@ class QwenImageScheduler(BaseScheduler):
if self.config["task"] == "i2i":
self.generator = torch.Generator().manual_seed(input_info.seed)
elif self.config["task"] == "t2i":
-self.generator = torch.Generator(device=self.run_device).manual_seed(input_info.seed)
+self.generator = torch.Generator(device=AI_DEVICE).manual_seed(input_info.seed)
self.prepare_latents(input_info)
self.prepare_guidance()
self.set_timesteps()
......
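The hunks above swap the scheduler's per-instance `self.run_device` (read from `config["run_device"]`) for a process-wide `AI_DEVICE` string imported from `lightx2v_platform.base.global_var`. That module is not part of this diff, so the following is only a plausible sketch of its shape:

```python
# Hypothetical sketch of lightx2v_platform/base/global_var.py (not shown in
# this diff). AI_DEVICE holds a plain device string ("cuda", "mlu", ...) and
# is rebound once at startup by init_ai_device().
AI_DEVICE = "cuda"
```

Note that consumers do `from lightx2v_platform.base.global_var import AI_DEVICE`, which copies the binding at import time; the rebinding in `init_ai_device()` therefore has to happen before these scheduler modules are imported, or they keep the default.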
@@ -7,6 +7,7 @@ from loguru import logger
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.utils.envs import *
from lightx2v.utils.utils import masks_like
+from lightx2v_platform.base.global_var import AI_DEVICE
class EulerScheduler(WanScheduler):
@@ -58,14 +59,14 @@ class EulerScheduler(WanScheduler):
)
def prepare_latents(self, seed, latent_shape, dtype=torch.float32):
-self.generator = torch.Generator(device=self.run_device).manual_seed(seed)
+self.generator = torch.Generator(device=AI_DEVICE).manual_seed(seed)
self.latents = torch.randn(
latent_shape[0],
latent_shape[1],
latent_shape[2],
latent_shape[3],
dtype=dtype,
-device=self.run_device,
+device=AI_DEVICE,
generator=self.generator,
)
if self.config["model_cls"] == "wan2.2_audio":
@@ -77,7 +78,7 @@ class EulerScheduler(WanScheduler):
self.prepare_latents(seed, latent_shape, dtype=torch.float32)
timesteps = np.linspace(self.num_train_timesteps, 0, self.infer_steps + 1, dtype=np.float32)
-self.timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=self.run_device)
+self.timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=AI_DEVICE)
self.timesteps_ori = self.timesteps.clone()
self.sigmas = self.timesteps_ori / self.num_train_timesteps
......
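`prepare_latents` above seeds a generator on the target device before sampling noise. The same pattern as a standalone helper (hypothetical name; assumes the chosen backend is actually available):

```python
import torch

def make_latents(seed: int, shape: tuple, device: str = "cuda") -> torch.Tensor:
    # Seeding a generator on the target device makes the noise reproducible
    # for a given (seed, device) pair; note that CPU and accelerator
    # generators do not emit identical streams, so results are
    # device-family-specific.
    generator = torch.Generator(device=device).manual_seed(seed)
    return torch.randn(*shape, dtype=torch.float32, device=device, generator=generator)
```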
import torch
+from lightx2v_platform.base.global_var import AI_DEVICE
class WanScheduler4ChangingResolutionInterface:
def __new__(cls, father_scheduler, config):
@@ -20,7 +22,7 @@ class WanScheduler4ChangingResolution:
assert len(config["resolution_rate"]) == len(config["changing_resolution_steps"])
def prepare_latents(self, seed, latent_shape, dtype=torch.float32):
-self.generator = torch.Generator(device=self.run_device).manual_seed(seed)
+self.generator = torch.Generator(device=AI_DEVICE).manual_seed(seed)
self.latents_list = []
for i in range(len(self.config["resolution_rate"])):
self.latents_list.append(
@@ -30,7 +32,7 @@ class WanScheduler4ChangingResolution:
int(latent_shape[2] * self.config["resolution_rate"][i]) // 2 * 2,
int(latent_shape[3] * self.config["resolution_rate"][i]) // 2 * 2,
dtype=dtype,
-device=self.run_device,
+device=AI_DEVICE,
generator=self.generator,
)
)
@@ -43,7 +45,7 @@ class WanScheduler4ChangingResolution:
latent_shape[2],
latent_shape[3],
dtype=dtype,
-device=self.run_device,
+device=AI_DEVICE,
generator=self.generator,
)
)
@@ -83,7 +85,7 @@ class WanScheduler4ChangingResolution:
# self.disable_corrector = [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37] # maybe not needed
# 5. update timesteps using shift + self.changing_resolution_index + 1 (more aggressive denoising)
-self.set_timesteps(self.infer_steps, device=self.run_device, shift=self.sample_shift + self.changing_resolution_index + 1)
+self.set_timesteps(self.infer_steps, device=AI_DEVICE, shift=self.sample_shift + self.changing_resolution_index + 1)
def add_noise(self, original_samples, noise, timesteps):
sigma = self.sigmas[self.step_index]
......
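For the changing-resolution path, each entry in `resolution_rate` scales the two spatial latent dims and floors them to even values with `// 2 * 2`. A worked sketch of that arithmetic (shapes are illustrative):

```python
def scaled_latent_shape(latent_shape, rate):
    # Only the spatial dims of the latent are scaled; channel and frame dims
    # stay fixed. `// 2 * 2` floors to an even size so the shape stays
    # divisible by the model's stride-2 stages.
    return (
        latent_shape[0],
        latent_shape[1],
        int(latent_shape[2] * rate) // 2 * 2,
        int(latent_shape[3] * rate) // 2 * 2,
    )

print(scaled_latent_shape((16, 21, 60, 104), 0.75))  # (16, 21, 44, 78)
```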
@@ -7,12 +7,12 @@ from torch.nn import functional as F
from lightx2v.models.schedulers.scheduler import BaseScheduler
from lightx2v.utils.utils import masks_like
+from lightx2v_platform.base.global_var import AI_DEVICE
class WanScheduler(BaseScheduler):
def __init__(self, config):
super().__init__(config)
-self.run_device = torch.device(self.config.get("run_device", "cuda"))
self.infer_steps = self.config["infer_steps"]
self.target_video_length = self.config["target_video_length"]
self.sample_shift = self.config["sample_shift"]
@@ -36,7 +36,7 @@ class WanScheduler(BaseScheduler):
self.rope_params(1024, 2 * (self.head_size // 6)),
],
dim=1,
-).to(torch.device(self.run_device))
+).to(torch.device(AI_DEVICE))
def rope_params(self, max_seq_len, dim, theta=10000):
assert dim % 2 == 0
@@ -70,7 +70,7 @@ class WanScheduler(BaseScheduler):
self.sigma_min = self.sigmas[-1].item()
self.sigma_max = self.sigmas[0].item()
-self.set_timesteps(self.infer_steps, device=self.run_device, shift=self.sample_shift)
+self.set_timesteps(self.infer_steps, device=AI_DEVICE, shift=self.sample_shift)
self.cos_sin = self.prepare_cos_sin((latent_shape[1] // self.patch_size[0], latent_shape[2] // self.patch_size[1], latent_shape[3] // self.patch_size[2]))
@@ -114,14 +114,14 @@ class WanScheduler(BaseScheduler):
return cos_sin
def prepare_latents(self, seed, latent_shape, dtype=torch.float32):
-self.generator = torch.Generator(device=self.run_device).manual_seed(seed)
+self.generator = torch.Generator(device=AI_DEVICE).manual_seed(seed)
self.latents = torch.randn(
latent_shape[0],
latent_shape[1],
latent_shape[2],
latent_shape[3],
dtype=dtype,
-device=self.run_device,
+device=AI_DEVICE,
generator=self.generator,
)
if self.config["model_cls"] == "wan2.2" and self.config["task"] in ["i2v", "s2v"]:
......
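The `rope_params` hunk moves the concatenated RoPE frequency table straight to `AI_DEVICE`. The function body is not shown in this diff; a sketch of the standard rotary-frequency table it presumably computes, consistent with the `dim % 2 == 0` assertion visible above:

```python
import torch

def rope_params(max_seq_len: int, dim: int, theta: float = 10000):
    # One rotation frequency per channel pair, evaluated at every position:
    # returns a complex tensor of shape (max_seq_len, dim // 2).
    assert dim % 2 == 0
    inv_freq = 1.0 / theta ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim)
    angles = torch.outer(torch.arange(max_seq_len, dtype=torch.float64), inv_freq)
    return torch.polar(torch.ones_like(angles), angles)
```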
@@ -2,12 +2,12 @@ import torch
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.utils.envs import *
+from lightx2v_platform.base.global_var import AI_DEVICE
class WanSFScheduler(WanScheduler):
def __init__(self, config):
super().__init__(config)
-self.run_device = torch.device(config.get("run_device", "cuda"))
self.dtype = torch.bfloat16
self.num_frame_per_block = self.config["sf_config"]["num_frame_per_block"]
self.num_output_frames = self.config["sf_config"]["num_output_frames"]
@@ -27,20 +27,20 @@ class WanSFScheduler(WanScheduler):
self.context_noise = 0
def prepare(self, seed, latent_shape, image_encoder_output=None):
-self.latents = torch.randn(latent_shape, device=self.run_device, dtype=self.dtype)
+self.latents = torch.randn(latent_shape, device=AI_DEVICE, dtype=self.dtype)
timesteps = []
for frame_block_idx, current_num_frames in enumerate(self.all_num_frames):
frame_steps = []
for step_index, current_timestep in enumerate(self.denoising_step_list):
-timestep = torch.ones([self.num_frame_per_block], device=self.run_device, dtype=torch.int64) * current_timestep
+timestep = torch.ones([self.num_frame_per_block], device=AI_DEVICE, dtype=torch.int64) * current_timestep
frame_steps.append(timestep)
timesteps.append(frame_steps)
self.timesteps = timesteps
-self.noise_pred = torch.zeros(latent_shape, device=self.run_device, dtype=self.dtype)
+self.noise_pred = torch.zeros(latent_shape, device=AI_DEVICE, dtype=self.dtype)
sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * self.denoising_strength
if self.extra_one_step:
@@ -52,10 +52,10 @@ class WanSFScheduler(WanScheduler):
self.sigmas_sf = self.sf_shift * self.sigmas_sf / (1 + (self.sf_shift - 1) * self.sigmas_sf)
if self.reverse_sigmas:
self.sigmas_sf = 1 - self.sigmas_sf
-self.sigmas_sf = self.sigmas_sf.to(self.run_device)
+self.sigmas_sf = self.sigmas_sf.to(AI_DEVICE)
self.timesteps_sf = self.sigmas_sf * self.num_train_timesteps
-self.timesteps_sf = self.timesteps_sf.to(self.run_device)
+self.timesteps_sf = self.timesteps_sf.to(AI_DEVICE)
self.stream_output = None
@@ -93,7 +93,7 @@ class WanSFScheduler(WanScheduler):
# add noise
if self.step_index < self.infer_steps - 1:
-timestep_next = self.timesteps[self.seg_index][self.step_index + 1] * torch.ones(self.num_frame_per_block, device=self.run_device, dtype=torch.long)
+timestep_next = self.timesteps[self.seg_index][self.step_index + 1] * torch.ones(self.num_frame_per_block, device=AI_DEVICE, dtype=torch.long)
timestep_id_next = torch.argmin((self.timesteps_sf.unsqueeze(0) - timestep_next.unsqueeze(1)).abs(), dim=1)
sigma_next = self.sigmas_sf[timestep_id_next].reshape(-1, 1, 1, 1)
noise_next = torch.randn_like(x0_pred)
......
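The `sigmas_sf` hunk applies the flow-matching time shift visible above: `shift * σ / (1 + (shift - 1) * σ)`. The remap fixes the endpoints (0 maps to 0, 1 maps to 1) and, for `shift > 1`, spends more of the step budget at high noise levels. A quick numeric check:

```python
import torch

def shift_sigmas(sigmas: torch.Tensor, shift: float) -> torch.Tensor:
    # Same remap as self.sigmas_sf above; monotone on [0, 1] with fixed
    # endpoints, biased toward high-noise steps when shift > 1.
    return shift * sigmas / (1 + (shift - 1) * sigmas)

print(shift_sigmas(torch.linspace(1.0, 0.0, 5), 3.0))
# tensor([1.0000, 0.9000, 0.7500, 0.5000, 0.0000])
```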
@@ -4,6 +4,7 @@ from typing import Union
import torch
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
+from lightx2v_platform.base.global_var import AI_DEVICE
class WanStepDistillScheduler(WanScheduler):
@@ -19,7 +20,7 @@ class WanStepDistillScheduler(WanScheduler):
def prepare(self, seed, latent_shape, image_encoder_output=None):
self.prepare_latents(seed, latent_shape, dtype=torch.float32)
-self.set_denoising_timesteps(device=self.run_device)
+self.set_denoising_timesteps(device=AI_DEVICE)
self.cos_sin = self.prepare_cos_sin((latent_shape[1] // self.patch_size[0], latent_shape[2] // self.patch_size[1], latent_shape[3] // self.patch_size[2]))
def set_denoising_timesteps(self, device: Union[str, torch.device] = None):
......
@@ -5,6 +5,8 @@ from typing import Optional
import torch
+from lightx2v_platform.base.global_var import AI_DEVICE
try:
from diffusers import AutoencoderKLQwenImage
from diffusers.image_processor import VaeImageProcessor
@@ -33,7 +35,7 @@ class AutoencoderKLQwenImageVAE:
if self.cpu_offload:
self.device = torch.device("cpu")
else:
-self.device = torch.device(self.config.get("run_device", "cuda"))
+self.device = torch.device(AI_DEVICE)
self.dtype = torch.bfloat16
self.latent_channels = config["vae_z_dim"]
self.load()
......
@@ -8,6 +8,10 @@ from einops import rearrange
from loguru import logger
from lightx2v.utils.utils import load_weights
+from lightx2v_platform.base.global_var import AI_DEVICE
+torch_device_module = getattr(torch, AI_DEVICE)
__all__ = [
"WanVAE",
@@ -821,11 +825,9 @@ class WanVAE:
use_2d_split=True,
load_from_rank0=False,
use_lightvae=False,
-run_device=torch.device("cuda"),
):
self.dtype = dtype
self.device = device
-self.run_device = run_device
self.parallel = parallel
self.use_tiling = use_tiling
self.cpu_offload = cpu_offload
@@ -955,11 +957,11 @@ class WanVAE:
self.scale = [self.mean, self.inv_std]
def to_cuda(self):
-self.model.encoder = self.model.encoder.to(self.run_device)
-self.model.decoder = self.model.decoder.to(self.run_device)
-self.model = self.model.to(self.run_device)
-self.mean = self.mean.cuda()
-self.inv_std = self.inv_std.cuda()
+self.model.encoder = self.model.encoder.to(AI_DEVICE)
+self.model.decoder = self.model.decoder.to(AI_DEVICE)
+self.model = self.model.to(AI_DEVICE)
+self.mean = self.mean.to(AI_DEVICE)
+self.inv_std = self.inv_std.to(AI_DEVICE)
self.scale = [self.mean, self.inv_std]
def encode_dist(self, video, world_size, cur_rank, split_dim):
@@ -1330,9 +1332,4 @@ class WanVAE:
def device_synchronize(
self,
):
-if "cuda" in str(self.run_device):
-torch.cuda.synchronize()
-elif "mlu" in str(self.run_device):
-torch.mlu.synchronize()
-elif "npu" in str(self.run_device):
-torch.npu.synchronize()
+torch_device_module.synchronize()
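The rewritten `device_synchronize` leans on the module-level `torch_device_module = getattr(torch, AI_DEVICE)` added at the top of the file: `torch.cuda`, `torch.mlu`, and `torch.npu` all expose the same `synchronize()` entry point, so a single attribute lookup replaces the old string-matching chain. A minimal demonstration (assumes the selected backend's torch extension is installed so the attribute exists):

```python
import torch
from lightx2v_platform.base.global_var import AI_DEVICE

# getattr(torch, "cuda") is torch.cuda; with torch_mlu / torch_npu imported,
# getattr(torch, "mlu") / getattr(torch, "npu") resolve the same way.
torch_device_module = getattr(torch, AI_DEVICE)
torch_device_module.synchronize()  # block until all queued kernels finish
```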
@@ -7,6 +7,9 @@ import torch.distributed as dist
from loguru import logger
from lightx2v.utils.envs import *
+from lightx2v_platform.base.global_var import AI_DEVICE
+torch_device_module = getattr(torch, AI_DEVICE)
class _ProfilingContext:
@@ -27,12 +30,12 @@ class _ProfilingContext:
self.metrics_labels = metrics_labels
def __enter__(self):
-self.device_synchronize()
+torch_device_module.synchronize()
self.start_time = time.perf_counter()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
-self.device_synchronize()
+torch_device_module.synchronize()
elapsed = time.perf_counter() - self.start_time
if self.enable_recorder and self.metrics_func:
if self.metrics_labels:
@@ -44,12 +47,12 @@ class _ProfilingContext:
return False
async def __aenter__(self):
-self.device_synchronize()
+torch_device_module.synchronize()
self.start_time = time.perf_counter()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
-self.device_synchronize()
+torch_device_module.synchronize()
elapsed = time.perf_counter() - self.start_time
if self.enable_recorder and self.metrics_func:
if self.metrics_labels:
@@ -78,17 +81,6 @@ class _ProfilingContext:
return sync_wrapper
-def device_synchronize(
-self,
-):
-if torch.cuda.is_available():
-torch.cuda.synchronize()
-elif hasattr(torch, "mlu") and torch.mlu.is_available():
-torch.mlu.synchronize()
-elif hasattr(torch, "npu") and torch.npu.is_available():
-torch.npu.synchronize()
-return
class _NullContext:
# Context manager without decision branch logic overhead
......
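`_ProfilingContext` now synchronizes through the same module handle on both entry and exit, which matters because kernel launches are asynchronous: without the sync, `perf_counter` would time only the Python-side dispatch. A self-contained sketch of the pattern (simplified; the real class also records metrics):

```python
import time

import torch

class SyncTimer:
    # Minimal sketch: synchronize before reading the clock on both ends so
    # the measured interval covers actual device work, not just launches.
    def __init__(self, device_module=torch.cuda):
        self.device_module = device_module

    def __enter__(self):
        self.device_module.synchronize()
        self.start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.device_module.synchronize()
        self.elapsed = time.perf_counter() - self.start
        return False  # never swallow exceptions
```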
+from lightx2v_platform.registry_factory import PLATFORM_ATTN_WEIGHT_REGISTER, PLATFORM_MM_WEIGHT_REGISTER
class Register(dict):
def __init__(self, *args, **kwargs):
super(Register, self).__init__(*args, **kwargs)
@@ -43,6 +46,15 @@ class Register(dict):
def items(self):
return self._dict.items()
+def get(self, key, default=None):
+return self._dict.get(key, default)
+def merge(self, other_register):
+for key, value in other_register.items():
+if key in self._dict:
+raise Exception(f"{key} already exists in target register.")
+self[key] = value
MM_WEIGHT_REGISTER = Register()
ATTN_WEIGHT_REGISTER = Register()
@@ -54,3 +66,6 @@ TENSOR_REGISTER = Register()
CONVERT_WEIGHT_REGISTER = Register()
EMBEDDING_WEIGHT_REGISTER = Register()
RUNNER_REGISTER = Register()
+ATTN_WEIGHT_REGISTER.merge(PLATFORM_ATTN_WEIGHT_REGISTER)
+MM_WEIGHT_REGISTER.merge(PLATFORM_MM_WEIGHT_REGISTER)
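The new `merge` refuses duplicate keys, so a platform register can extend but never silently shadow a built-in kernel, and the module-level `merge` calls fold the platform kernels in at import time. Roughly how that behaves (assuming `Register.__setitem__` writes through to the internal `_dict`, as `items`/`get` suggest; the entries are hypothetical):

```python
base = Register()
base["mm-default"] = "builtin-kernel"
platform = Register()
platform["mm-mlu"] = "platform-kernel"

base.merge(platform)  # ok: keys are disjoint
base.merge(platform)  # raises Exception: "mm-mlu already exists in target register."
```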
@@ -8,6 +8,7 @@ from torch.distributed.tensor.device_mesh import init_device_mesh
from lightx2v.utils.input_info import ALL_INPUT_INFO_KEYS
from lightx2v.utils.lockable_dict import LockableDict
+from lightx2v_platform.base.global_var import AI_DEVICE
def get_default_config():
@@ -92,8 +93,7 @@ def set_parallel_config(config):
cfg_p_size = config["parallel"].get("cfg_p_size", 1)
seq_p_size = config["parallel"].get("seq_p_size", 1)
assert cfg_p_size * seq_p_size == dist.get_world_size(), f"cfg_p_size * seq_p_size must be equal to world_size"
-device_str = config.get("run_device", "cuda")
-config["device_mesh"] = init_device_mesh(device_str, (cfg_p_size, seq_p_size), mesh_dim_names=("cfg_p", "seq_p"))
+config["device_mesh"] = init_device_mesh(AI_DEVICE, (cfg_p_size, seq_p_size), mesh_dim_names=("cfg_p", "seq_p"))
if config["parallel"] and config["parallel"].get("seq_p_size", False) and config["parallel"]["seq_p_size"] > 1:
config["seq_parallel"] = True
@@ -101,7 +101,7 @@ def set_parallel_config(config):
if config.get("enable_cfg", False) and config["parallel"] and config["parallel"].get("cfg_p_size", False) and config["parallel"]["cfg_p_size"] > 1:
config["cfg_parallel"] = True
# warmup dist
-_a = torch.zeros([1]).to(f"{device_str}:{dist.get_rank()}")
+_a = torch.zeros([1]).to(f"{AI_DEVICE}:{dist.get_rank()}")
dist.all_reduce(_a)
......
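`init_device_mesh` now takes the global `AI_DEVICE` as the mesh device type. For example, with `world_size == 8`, `cfg_p_size = 2` and `seq_p_size = 4` satisfy the assertion above, and the named axes give each rank a 2-way classifier-free-guidance group and a 4-way sequence-parallel group. Hypothetical sizes; requires an initialized 8-rank process group:

```python
# Public import path for recent torch; the repo imports the same function
# from torch.distributed.tensor.device_mesh.
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("cfg_p", "seq_p"))
cfg_group = mesh.get_group("cfg_p")  # 2 ranks that share one sequence shard
seq_group = mesh.get_group("seq_p")  # 4 ranks that shard one sequence
```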
@@ -13,18 +13,18 @@ import torchvision
from einops import rearrange
from loguru import logger
+from lightx2v_platform.base.global_var import AI_DEVICE
+torch_device_module = getattr(torch, AI_DEVICE)
def seed_all(seed):
random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
-if torch.cuda.is_available():
-torch.cuda.manual_seed(seed)
-torch.cuda.manual_seed_all(seed)
-elif hasattr(torch, "mlu") and torch.mlu.is_available():
-torch.mlu.manual_seed(seed)
-torch.mlu.manual_seed_all(seed)
+torch_device_module.manual_seed(seed)
+torch_device_module.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
......
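`seed_all` previously special-cased CUDA and MLU; the `torch_device_module` handle collapses that to two calls that work for any registered backend. The same idea as a standalone helper (hypothetical name):

```python
import os
import random

import numpy as np
import torch

def seed_all_portable(seed: int, device: str = "cuda"):
    # One getattr covers torch.cuda / torch.mlu / torch.npu, provided the
    # matching torch extension is imported so the attribute exists.
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    device_module = getattr(torch, device)
    device_module.manual_seed(seed)
    device_module.manual_seed_all(seed)
```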
from lightx2v_platform.base.base import check_ai_device, init_ai_device
from lightx2v_platform.base.cambricon_mlu import MluDevice
from lightx2v_platform.base.metax import MetaxDevice
from lightx2v_platform.base.nvidia import CudaDevice
__all__ = ["init_ai_device", "check_ai_device", "CudaDevice", "MluDevice", "MetaxDevice"]
......
from loguru import logger
from lightx2v_platform.base import global_var
from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER
def init_ai_device(platform="cuda"):
platform_device = PLATFORM_DEVICE_REGISTER.get(platform, None)
if platform_device is None:
available_platforms = list(PLATFORM_DEVICE_REGISTER.keys())
raise RuntimeError(f"Unsupported platform: {platform}. Available platforms: {available_platforms}")
global_var.AI_DEVICE = platform_device.get_device()
logger.info(f"Initialized AI_DEVICE: {global_var.AI_DEVICE}")
return global_var.AI_DEVICE
def check_ai_device(platform="cuda"):
platform_device = PLATFORM_DEVICE_REGISTER.get(platform, None)
if platform_device is None:
available_platforms = list(PLATFORM_DEVICE_REGISTER.keys())
raise RuntimeError(f"Unsupported platform: {platform}. Available platforms: {available_platforms}")
is_available = platform_device.is_available()
if not is_available:
raise RuntimeError(f"AI device for platform '{platform}' is not available. Please check your runtime environment.")
logger.info(f"AI device for platform '{platform}' is available.")
return True
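Both helpers look the platform up in `PLATFORM_DEVICE_REGISTER` and fail loudly on unknown or unavailable backends. One subtlety: since consumers do `from lightx2v_platform.base.global_var import AI_DEVICE`, they capture the binding at import time, so `init_ai_device` has to run before those modules are imported. A plausible startup sequence:

```python
# Hypothetical driver-script startup: validate, then bind, then import
# anything that reads AI_DEVICE at module import time.
from lightx2v_platform.base.base import check_ai_device, init_ai_device

check_ai_device("cuda")          # raises RuntimeError if CUDA is unusable
device = init_ai_device("cuda")  # rebinds global_var.AI_DEVICE, returns "cuda"
```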
import torch
import torch.distributed as dist
from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER
@PLATFORM_DEVICE_REGISTER("mlu")
class MluDevice:
name = "mlu"
@staticmethod
def is_available() -> bool:
try:
import torch_mlu
return torch_mlu.mlu.is_available()
except ImportError:
return False
@staticmethod
def get_device() -> str:
return "mlu"
@staticmethod
def init_parallel_env():
dist.init_process_group(backend="cncl")
torch.mlu.set_device(dist.get_rank())
......
from lightx2v_platform.base.nvidia import CudaDevice
from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER
@PLATFORM_DEVICE_REGISTER("metax")
class MetaxDevice(CudaDevice):
name = "cuda"
......
import torch
import torch.distributed as dist
from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER
try:
from torch.distributed import ProcessGroupNCCL
except ImportError:
ProcessGroupNCCL = None
@PLATFORM_DEVICE_REGISTER("cuda")
class CudaDevice:
name = "cuda"
@staticmethod
def is_available() -> bool:
try:
import torch
return torch.cuda.is_available()
except ImportError:
return False
@staticmethod
def get_device() -> str:
return "cuda"
@staticmethod
def init_parallel_env():
if ProcessGroupNCCL is None:
raise RuntimeError("ProcessGroupNCCL is not available. Please check your runtime environment.")
pg_options = ProcessGroupNCCL.Options()
pg_options.is_high_priority_stream = True
dist.init_process_group(backend="nccl", pg_options=pg_options)
torch.cuda.set_device(dist.get_rank())
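The registry makes new backends additive: a platform ships one class with `is_available`, `get_device`, and `init_parallel_env`, registered under its platform key. A hypothetical fourth backend, sketched for Ascend NPUs (the class name, `torch_npu` probing, and the HCCL backend choice are assumptions, not part of this PR):

```python
import torch
import torch.distributed as dist

from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER

@PLATFORM_DEVICE_REGISTER("npu")
class NpuDevice:
    name = "npu"

    @staticmethod
    def is_available() -> bool:
        try:
            import torch_npu  # registers torch.npu as a side effect
            return torch_npu.npu.is_available()
        except ImportError:
            return False

    @staticmethod
    def get_device() -> str:
        return "npu"

    @staticmethod
    def init_parallel_env():
        # HCCL is the collective backend on Ascend hardware.
        dist.init_process_group(backend="hccl")
        torch.npu.set_device(dist.get_rank())
```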
from lightx2v_platform.base.global_var import AI_DEVICE
if AI_DEVICE == "mlu":
from .attn.cambricon_mlu import *
from .mm.cambricon_mlu import *
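The star-imports above exist for their side effects: importing a kernel module runs its registration decorators, which is what populates `PLATFORM_ATTN_WEIGHT_REGISTER` and `PLATFORM_MM_WEIGHT_REGISTER` before the `merge` calls seen earlier. A hypothetical kernel module would look roughly like this (assuming these registers support the same decorator protocol as `PLATFORM_DEVICE_REGISTER`; the key and class are illustrative):

```python
# Hypothetical lightx2v_platform/ops/attn/cambricon_mlu.py
from lightx2v_platform.registry_factory import PLATFORM_ATTN_WEIGHT_REGISTER

@PLATFORM_ATTN_WEIGHT_REGISTER("mlu_flash_attn2")
class MluFlashAttn2Weight:
    """Placeholder MLU attention implementation; registered on import."""
```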