Unverified commit 4c0a9a0d authored by Gu Shiqiao, committed by GitHub

Fix device bugs (#527)

parent fbb19ffc
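The diff below consistently renames the scheduler-side device attribute from self.device to self.run_device and resolves it from the run_device config key. As a minimal sketch of that pattern (the class and config plumbing here are illustrative, not the repo's actual code):

    import torch

    class SchedulerSketch:
        """Illustrative only: mirrors the run_device pattern this commit applies."""

        def __init__(self, config):
            self.config = config
            # "cuda" is the fallback when run_device is absent from the config
            self.run_device = torch.device(self.config.get("run_device", "cuda"))
            self.dtype = torch.bfloat16

        def prepare_latents(self, seed, latent_shape):
            # keep the RNG and the sampled tensor on the same device
            generator = torch.Generator(device=self.run_device).manual_seed(seed)
            return torch.randn(*latent_shape, dtype=self.dtype, device=self.run_device, generator=generator)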
@@ -133,7 +133,7 @@ class QwenImageScheduler(BaseScheduler):
self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(os.path.join(config["model_path"], "scheduler"))
with open(os.path.join(config["model_path"], "scheduler", "scheduler_config.json"), "r") as f:
self.scheduler_config = json.load(f)
- self.device = torch.device(self.config.get("run_device", "cuda"))
+ self.run_device = torch.device(self.config.get("run_device", "cuda"))
self.dtype = torch.bfloat16
self.guidance_scale = 1.0
@@ -176,9 +176,9 @@ class QwenImageScheduler(BaseScheduler):
shape = input_info.target_shape
width, height = shape[-1], shape[-2]
- latents = randn_tensor(shape, generator=self.generator, device=self.device, dtype=self.dtype)
+ latents = randn_tensor(shape, generator=self.generator, device=self.run_device, dtype=self.dtype)
latents = self._pack_latents(latents, self.config["batchsize"], self.config["num_channels_latents"], height, width)
- latent_image_ids = self._prepare_latent_image_ids(self.config["batchsize"], height // 2, width // 2, self.device, self.dtype)
+ latent_image_ids = self._prepare_latent_image_ids(self.config["batchsize"], height // 2, width // 2, self.run_device, self.dtype)
self.latents = latents
self.latent_image_ids = latent_image_ids
@@ -198,7 +198,7 @@ class QwenImageScheduler(BaseScheduler):
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
- self.device,
+ self.run_device,
sigmas=sigmas,
mu=mu,
)
@@ -213,7 +213,7 @@ class QwenImageScheduler(BaseScheduler):
def prepare_guidance(self):
# handle guidance
if self.config["guidance_embeds"]:
- guidance = torch.full([1], self.guidance_scale, device=self.device, dtype=torch.float32)
+ guidance = torch.full([1], self.guidance_scale, device=self.run_device, dtype=torch.float32)
guidance = guidance.expand(self.latents.shape[0])
else:
guidance = None
@@ -223,7 +223,7 @@ class QwenImageScheduler(BaseScheduler):
if self.config["task"] == "i2i":
self.generator = torch.Generator().manual_seed(input_info.seed)
elif self.config["task"] == "t2i":
- self.generator = torch.Generator(device=self.device).manual_seed(input_info.seed)
+ self.generator = torch.Generator(device=self.run_device).manual_seed(input_info.seed)
self.prepare_latents(input_info)
self.prepare_guidance()
self.set_timesteps()
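For the seeding logic above: i2i keeps the generator on CPU, while t2i creates it on run_device so sampling happens on the accelerator. A small sketch with hypothetical task and seed values:

    import torch

    run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    task, seed = "t2i", 42  # hypothetical values

    if task == "i2i":
        generator = torch.Generator().manual_seed(seed)  # CPU generator
    elif task == "t2i":
        generator = torch.Generator(device=run_device).manual_seed(seed)  # generator on the run device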
@@ -58,14 +58,14 @@ class EulerScheduler(WanScheduler):
)
def prepare_latents(self, seed, latent_shape, dtype=torch.float32):
- self.generator = torch.Generator(device=self.device).manual_seed(seed)
+ self.generator = torch.Generator(device=self.run_device).manual_seed(seed)
self.latents = torch.randn(
latent_shape[0],
latent_shape[1],
latent_shape[2],
latent_shape[3],
dtype=dtype,
- device=self.device,
+ device=self.run_device,
generator=self.generator,
)
if self.config["model_cls"] == "wan2.2_audio":
@@ -77,7 +77,7 @@ class EulerScheduler(WanScheduler):
self.prepare_latents(seed, latent_shape, dtype=torch.float32)
timesteps = np.linspace(self.num_train_timesteps, 0, self.infer_steps + 1, dtype=np.float32)
- self.timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=self.device)
+ self.timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=self.run_device)
self.timesteps_ori = self.timesteps.clone()
self.sigmas = self.timesteps_ori / self.num_train_timesteps
@@ -20,7 +20,7 @@ class WanScheduler4ChangingResolution:
assert len(config["resolution_rate"]) == len(config["changing_resolution_steps"])
def prepare_latents(self, seed, latent_shape, dtype=torch.float32):
- self.generator = torch.Generator(device=self.device).manual_seed(seed)
+ self.generator = torch.Generator(device=self.run_device).manual_seed(seed)
self.latents_list = []
for i in range(len(self.config["resolution_rate"])):
self.latents_list.append(
@@ -30,7 +30,7 @@ class WanScheduler4ChangingResolution:
int(latent_shape[2] * self.config["resolution_rate"][i]) // 2 * 2,
int(latent_shape[3] * self.config["resolution_rate"][i]) // 2 * 2,
dtype=dtype,
- device=self.device,
+ device=self.run_device,
generator=self.generator,
)
)
@@ -43,7 +43,7 @@ class WanScheduler4ChangingResolution:
latent_shape[2],
latent_shape[3],
dtype=dtype,
- device=self.device,
+ device=self.run_device,
generator=self.generator,
)
)
@@ -83,7 +83,7 @@ class WanScheduler4ChangingResolution:
# self.disable_corrector = [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37] # maybe not needed
# 5. update timesteps using shift + self.changing_resolution_index + 1 for more aggressive denoising
- self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift + self.changing_resolution_index + 1)
+ self.set_timesteps(self.infer_steps, device=self.run_device, shift=self.sample_shift + self.changing_resolution_index + 1)
def add_noise(self, original_samples, noise, timesteps):
sigma = self.sigmas[self.step_index]
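For reference, the changing-resolution scheduler derives one latent shape per entry in resolution_rate, scaling the spatial dims and rounding them down to even numbers. A sketch with hypothetical rates and base shape (the real values come from the config JSON):

    resolution_rate = [0.75, 1.0]          # hypothetical
    latent_shape = (16, 21, 60, 104)       # hypothetical (C, T, H, W)

    for rate in resolution_rate:
        h = int(latent_shape[2] * rate) // 2 * 2  # round down to an even height
        w = int(latent_shape[3] * rate) // 2 * 2  # round down to an even width
        print(rate, (latent_shape[0], latent_shape[1], h, w))
    # 0.75 -> (16, 21, 44, 78); 1.0 -> (16, 21, 60, 104)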
@@ -10,11 +10,10 @@ from lightx2v.utils.utils import masks_like
class WanScheduler(BaseScheduler):
def __init__(self, config):
super().__init__(config)
- self.device = torch.device(self.config.get("run_device", "cuda"))
+ self.run_device = torch.device(self.config.get("run_device", "cuda"))
self.infer_steps = self.config["infer_steps"]
self.target_video_length = self.config["target_video_length"]
self.sample_shift = self.config["sample_shift"]
- self.run_device = self.config.get("run_device", "cuda")
self.patch_size = (1, 2, 2)
self.shift = 1
self.num_train_timesteps = 1000
@@ -65,7 +64,7 @@ class WanScheduler(BaseScheduler):
self.sigma_min = self.sigmas[-1].item()
self.sigma_max = self.sigmas[0].item()
- self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift)
+ self.set_timesteps(self.infer_steps, device=self.run_device, shift=self.sample_shift)
self.cos_sin = self.prepare_cos_sin((latent_shape[1] // self.patch_size[0], latent_shape[2] // self.patch_size[1], latent_shape[3] // self.patch_size[2]))
@@ -93,14 +92,14 @@ class WanScheduler(BaseScheduler):
return cos_sin
def prepare_latents(self, seed, latent_shape, dtype=torch.float32):
- self.generator = torch.Generator(device=self.device).manual_seed(seed)
+ self.generator = torch.Generator(device=self.run_device).manual_seed(seed)
self.latents = torch.randn(
latent_shape[0],
latent_shape[1],
latent_shape[2],
latent_shape[3],
dtype=dtype,
- device=self.device,
+ device=self.run_device,
generator=self.generator,
)
if self.config["model_cls"] == "wan2.2" and self.config["task"] in ["i2v", "s2v"]:
@@ -7,7 +7,7 @@ from lightx2v.utils.envs import *
class WanSFScheduler(WanScheduler):
def __init__(self, config):
super().__init__(config)
self.device = torch.device("cuda")
self.run_device = torch.device(config.get("run_device"), "cuda")
self.dtype = torch.bfloat16
self.num_frame_per_block = self.config["sf_config"]["num_frame_per_block"]
self.num_output_frames = self.config["sf_config"]["num_output_frames"]
@@ -27,20 +27,20 @@ class WanSFScheduler(WanScheduler):
self.context_noise = 0
def prepare(self, seed, latent_shape, image_encoder_output=None):
- self.latents = torch.randn(latent_shape, device=self.device, dtype=self.dtype)
+ self.latents = torch.randn(latent_shape, device=self.run_device, dtype=self.dtype)
timesteps = []
for frame_block_idx, current_num_frames in enumerate(self.all_num_frames):
frame_steps = []
for step_index, current_timestep in enumerate(self.denoising_step_list):
- timestep = torch.ones([self.num_frame_per_block], device=self.device, dtype=torch.int64) * current_timestep
+ timestep = torch.ones([self.num_frame_per_block], device=self.run_device, dtype=torch.int64) * current_timestep
frame_steps.append(timestep)
timesteps.append(frame_steps)
self.timesteps = timesteps
- self.noise_pred = torch.zeros(latent_shape, device=self.device, dtype=self.dtype)
+ self.noise_pred = torch.zeros(latent_shape, device=self.run_device, dtype=self.dtype)
sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * self.denoising_strength
if self.extra_one_step:
@@ -52,10 +52,10 @@ class WanSFScheduler(WanScheduler):
self.sigmas_sf = self.sf_shift * self.sigmas_sf / (1 + (self.sf_shift - 1) * self.sigmas_sf)
if self.reverse_sigmas:
self.sigmas_sf = 1 - self.sigmas_sf
- self.sigmas_sf = self.sigmas_sf.to(self.device)
+ self.sigmas_sf = self.sigmas_sf.to(self.run_device)
self.timesteps_sf = self.sigmas_sf * self.num_train_timesteps
- self.timesteps_sf = self.timesteps_sf.to(self.device)
+ self.timesteps_sf = self.timesteps_sf.to(self.run_device)
self.stream_output = None
@@ -93,7 +93,7 @@ class WanSFScheduler(WanScheduler):
# add noise
if self.step_index < self.infer_steps - 1:
- timestep_next = self.timesteps[self.seg_index][self.step_index + 1] * torch.ones(self.num_frame_per_block, device=self.device, dtype=torch.long)
+ timestep_next = self.timesteps[self.seg_index][self.step_index + 1] * torch.ones(self.num_frame_per_block, device=self.run_device, dtype=torch.long)
timestep_id_next = torch.argmin((self.timesteps_sf.unsqueeze(0) - timestep_next.unsqueeze(1)).abs(), dim=1)
sigma_next = self.sigmas_sf[timestep_id_next].reshape(-1, 1, 1, 1)
noise_next = torch.randn_like(x0_pred)
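The add-noise step above maps each per-frame timestep to the nearest entry in the precomputed timesteps_sf table and reads off the matching sigma. A self-contained sketch of that lookup (values are hypothetical):

    import torch

    timesteps_sf = torch.tensor([999.0, 750.0, 500.0, 250.0, 0.0])  # hypothetical schedule
    sigmas_sf = timesteps_sf / 1000.0                                # matching sigmas
    timestep_next = torch.tensor([740.0, 260.0, 0.0])                # one value per frame block

    # pairwise |timesteps_sf - timestep_next|, then pick the closest schedule entry per frame
    idx = torch.argmin((timesteps_sf.unsqueeze(0) - timestep_next.unsqueeze(1)).abs(), dim=1)
    sigma_next = sigmas_sf[idx]  # tensor([0.7500, 0.2500, 0.0000])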
@@ -19,7 +19,7 @@ class WanStepDistillScheduler(WanScheduler):
def prepare(self, seed, latent_shape, image_encoder_output=None):
self.prepare_latents(seed, latent_shape, dtype=torch.float32)
- self.set_denoising_timesteps(device=self.device)
+ self.set_denoising_timesteps(device=self.run_device)
def set_denoising_timesteps(self, device: Union[str, torch.device] = None):
sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min)
@@ -1330,9 +1330,9 @@ class WanVAE:
def device_synchronize(
self,
):
if "cuda" in str(self.device):
if "cuda" in str(self.run_device):
torch.cuda.synchronize()
elif "mlu" in str(self.device):
elif "mlu" in str(self.run_device):
torch.mlu.synchronize()
elif "npu" in str(self.device):
elif "npu" in str(self.run_device):
torch.npu.synchronize()
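The same dispatch can also be written as a free function. Note that torch.mlu and torch.npu only exist when the corresponding vendor extensions (e.g. torch_mlu, torch_npu) are installed, so the hasattr guards below are an added safety assumption rather than the repo's implementation:

    import torch

    def device_synchronize(device) -> None:
        """Block until all queued kernels on `device` have finished (sketch)."""
        name = str(device)
        if "cuda" in name:
            torch.cuda.synchronize()
        elif "mlu" in name and hasattr(torch, "mlu"):
            torch.mlu.synchronize()
        elif "npu" in name and hasattr(torch, "npu"):
            torch.npu.synchronize()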
#!/bin/bash
- lightx2v_path=/mtc/gushiqiao/llmc_workspace/lightx2v_latest2/LightX2V
- model_path=/data/nvme0/gushiqiao/models/Lightx2v_models/seko-new/SekoTalk-Distill-fp8/
+ lightx2v_path=/path/to/LightX2V
+ model_path=/path/to/SekoTalk-Distill-fp8/
export CUDA_VISIBLE_DEVICES=0
#!/bin/bash
lightx2v_path=/path/to/Lightx2v
model_path=/path/to/SekoTalk-Distill
export MLU_VISIBLE_DEVICES=0
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
export SENSITIVE_LAYER_DTYPE=None
python -m lightx2v.infer \
--model_cls seko_talk \
--task s2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/seko_talk/mlu/seko_talk_bf16.json \
--prompt "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze." \
--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--image_path ${lightx2v_path}/assets/inputs/audio/seko_input.png \
--audio_path ${lightx2v_path}/assets/inputs/audio/seko_input.mp3 \
--save_result_path ${lightx2v_path}/save_results/output_lightx2v_seko_talk.mp4