".github/vscode:/vscode.git/clone" did not exist on "ac289b35d23a0e921fa2782bb8d29a513f2b91c0"
Unverified Commit fcc2a411, authored by Kane, committed by GitHub

Mlu590 deployment (#453)

Features:
    1. Added MLU590 bfloat16 single-GPU and multi-GPU inference.
    2. Added MLU590 int8 inference.
parent 989a30a0
@@ -62,7 +62,7 @@ class DefaultRunner(BaseRunner):
         if self.config["cpu_offload"]:
             self.init_device = torch.device("cpu")
         else:
-            self.init_device = torch.device("cuda")
+            self.init_device = torch.device(self.config.get("run_device", "cuda"))

     def load_vfi_model(self):
         if self.config["video_frame_interpolation"].get("algo", None) == "rife":
...
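The new "run_device" config key is the backbone of this change: each hard-coded "cuda" becomes a lookup that falls back to CUDA when the key is absent. A minimal sketch of the resolution above, assuming a hypothetical config dict; an "mlu" device only resolves when Cambricon's torch_mlu extension is installed:

import torch

# Hypothetical config; "run_device" is the key this commit introduces.
config = {"cpu_offload": False, "run_device": "mlu"}

if config["cpu_offload"]:
    init_device = torch.device("cpu")
else:
    # Falls back to CUDA when no run_device is configured.
    init_device = torch.device(config.get("run_device", "cuda"))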
@@ -450,7 +450,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
             ref_img = img_path
         else:
             ref_img = load_image(img_path)
-        ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
+        ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(0).to(self.init_device)
         ref_img, h, w = resize_image(
             ref_img,
@@ -538,7 +538,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
     def prepare_prev_latents(self, prev_video: Optional[torch.Tensor], prev_frame_length: int) -> Optional[Dict[str, torch.Tensor]]:
         """Prepare previous latents for conditioning"""
-        device = torch.device("cuda")
+        device = self.init_device
         dtype = GET_DTYPE()
         tgt_h, tgt_w = self.input_info.target_shape[0], self.input_info.target_shape[1]
@@ -835,7 +835,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
     def load_audio_encoder(self):
         audio_encoder_path = self.config.get("audio_encoder_path", os.path.join(self.config["model_path"], "TencentGameMate-chinese-hubert-large"))
         audio_encoder_offload = self.config.get("audio_encoder_cpu_offload", self.config.get("cpu_offload", False))
-        model = SekoAudioEncoderModel(audio_encoder_path, self.config["audio_sr"], audio_encoder_offload)
+        model = SekoAudioEncoderModel(audio_encoder_path, self.config["audio_sr"], audio_encoder_offload, device=self.config.get("run_device", "cuda"))
         return model

     def load_audio_adapter(self):
@@ -843,7 +843,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
         if audio_adapter_offload:
             device = torch.device("cpu")
         else:
-            device = torch.device("cuda")
+            device = torch.device(self.config.get("run_device", "cuda"))
         audio_adapter = AudioAdapter(
             attention_head_dim=self.config["dim"] // self.config["num_heads"],
             num_attention_heads=self.config["num_heads"],
@@ -856,6 +856,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
             quantized=self.config.get("adapter_quantized", False),
             quant_scheme=self.config.get("adapter_quant_scheme", None),
             cpu_offload=audio_adapter_offload,
+            device=self.config.get("run_device", "cuda"),
         )
         audio_adapter.to(device)
...
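Passing device= into SekoAudioEncoderModel and AudioAdapter lets those modules place their weights on the target device up front instead of assuming CUDA. A toy sketch of the same constructor plumbing, using a made-up module (AudioAdapter's real signature lives in lightx2v):

import torch

class ToyAdapter(torch.nn.Module):
    def __init__(self, dim: int, device: str = "cuda"):
        super().__init__()
        # Weights are allocated directly on the requested device,
        # skipping a CUDA allocation followed by a .to() copy.
        self.proj = torch.nn.Linear(dim, dim, device=torch.device(device))

adapter = ToyAdapter(64, device="cpu")  # "mlu" would require torch_mlu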
@@ -65,7 +65,7 @@ class WanRunner(DefaultRunner):
         if clip_offload:
             clip_device = torch.device("cpu")
         else:
-            clip_device = torch.device("cuda")
+            clip_device = torch.device(self.init_device)
         # quant_config
         clip_quantized = self.config.get("clip_quantized", False)
         if clip_quantized:
@@ -101,7 +101,7 @@ class WanRunner(DefaultRunner):
         if t5_offload:
             t5_device = torch.device("cpu")
         else:
-            t5_device = torch.device("cuda")
+            t5_device = torch.device(self.init_device)
         # quant_config
         t5_quantized = self.config.get("t5_quantized", False)
@@ -142,7 +142,7 @@ class WanRunner(DefaultRunner):
         if vae_offload:
             vae_device = torch.device("cpu")
         else:
-            vae_device = torch.device("cuda")
+            vae_device = torch.device(self.init_device)
         vae_config = {
             "vae_path": find_torch_model_path(self.config, "vae_path", self.vae_name),
@@ -165,7 +165,7 @@ class WanRunner(DefaultRunner):
         if vae_offload:
             vae_device = torch.device("cpu")
         else:
-            vae_device = torch.device("cuda")
+            vae_device = torch.device(self.init_device)
         vae_config = {
             "vae_path": find_torch_model_path(self.config, "vae_path", self.vae_name),
...
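The clip, t5, and both vae hunks repeat one decision: offloaded components live on CPU, everything else on the runner's init_device. A hypothetical helper that captures it:

import torch

def resolve_device(offload: bool, init_device) -> torch.device:
    # CPU for offloaded components, the configured device otherwise.
    return torch.device("cpu") if offload else torch.device(init_device)

# e.g. clip_device = resolve_device(clip_offload, self.init_device)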
@@ -133,7 +133,7 @@ class QwenImageScheduler(BaseScheduler):
         self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(os.path.join(config["model_path"], "scheduler"))
         with open(os.path.join(config["model_path"], "scheduler", "scheduler_config.json"), "r") as f:
             self.scheduler_config = json.load(f)
-        self.device = torch.device("cuda")
+        self.device = torch.device(self.config.get("run_device", "cuda"))
         self.dtype = torch.bfloat16
         self.guidance_scale = 1.0
@@ -223,7 +223,7 @@ class QwenImageScheduler(BaseScheduler):
         if self.config["task"] == "i2i":
             self.generator = torch.Generator().manual_seed(input_info.seed)
         elif self.config["task"] == "t2i":
-            self.generator = torch.Generator(device="cuda").manual_seed(input_info.seed)
+            self.generator = torch.Generator(device=self.device).manual_seed(input_info.seed)
         self.prepare_latents(input_info)
         self.prepare_guidance()
         self.set_timesteps()
...
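The t2i branch now builds its seeded generator on the scheduler's device rather than hard-coding "cuda", which keeps noise sampling on whatever backend run_device selects. A standalone sketch (seed and shape are illustrative):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = torch.Generator(device=device).manual_seed(42)
# The generator's device must match the tensor it samples into.
noise = torch.randn(1, 16, 64, 64, generator=generator, device=device)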
@@ -10,7 +10,7 @@ from lightx2v.utils.utils import masks_like
 class WanScheduler(BaseScheduler):
     def __init__(self, config):
         super().__init__(config)
-        self.device = torch.device("cuda")
+        self.device = torch.device(self.config.get("run_device", "cuda"))
         self.infer_steps = self.config["infer_steps"]
         self.target_video_length = self.config["target_video_length"]
         self.sample_shift = self.config["sample_shift"]
...
@@ -33,7 +33,7 @@ class AutoencoderKLQwenImageVAE:
         if self.cpu_offload:
             self.device = torch.device("cpu")
         else:
-            self.device = torch.device("cuda")
+            self.device = torch.device(self.config.get("run_device", "cuda"))
         self.dtype = torch.bfloat16
         self.latent_channels = config["vae_z_dim"]
         self.load()
...
@@ -1018,7 +1018,7 @@ class WanVAE:
         full_encoded = [torch.empty_like(encoded_chunk) for _ in range(world_size)]
         dist.all_gather(full_encoded, encoded_chunk)
-        torch.cuda.synchronize()
+        self.device_synchronize()
         encoded = torch.cat(full_encoded, dim=split_dim)
@@ -1100,7 +1100,7 @@ class WanVAE:
         dist.all_gather(full_encoded, encoded_chunk)
-        torch.cuda.synchronize()
+        self.device_synchronize()
         # Reconstruct the full encoded tensor
         encoded_rows = []
@@ -1197,7 +1197,7 @@ class WanVAE:
         full_images = [torch.empty_like(images) for _ in range(world_size)]
         dist.all_gather(full_images, images)
-        torch.cuda.synchronize()
+        self.device_synchronize()
         images = torch.cat(full_images, dim=split_dim + 1)
@@ -1271,7 +1271,7 @@ class WanVAE:
         dist.all_gather(full_images, images_chunk)
-        torch.cuda.synchronize()
+        self.device_synchronize()
         # Reconstruct the full image tensor
         image_rows = []
@@ -1324,3 +1324,11 @@ class WanVAE:
     def decode_video(self, vid_enc):
         return self.model.decode_video(vid_enc)
+
+    def device_synchronize(
+        self,
+    ):
+        if "cuda" in str(self.device):
+            torch.cuda.synchronize()
+        elif "mlu" in str(self.device):
+            torch.mlu.synchronize()
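WanVAE.device_synchronize dispatches on the device string, so indexed forms such as "cuda:0" or "mlu:1" also match the substring tests; the explicit synchronize after each dist.all_gather guards host-side code, since backend collectives like NCCL are stream-ordered rather than host-blocking. A quick check of the string matching (no accelerator required):

import torch

# str() preserves the index, so 'cuda' in str(...) covers "cuda:0" too.
for name in ("cuda", "cuda:0", "cpu"):
    s = str(torch.device(name))
    print(s, "-> cuda branch:", "cuda" in s)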
@@ -27,12 +27,12 @@ class _ProfilingContext:
         self.metrics_labels = metrics_labels

     def __enter__(self):
-        torch.cuda.synchronize()
+        self.device_synchronize()
         self.start_time = time.perf_counter()
         return self

     def __exit__(self, exc_type, exc_val, exc_tb):
-        torch.cuda.synchronize()
+        self.device_synchronize()
         elapsed = time.perf_counter() - self.start_time
         if self.enable_recorder and self.metrics_func:
             if self.metrics_labels:
@@ -44,12 +44,12 @@ class _ProfilingContext:
         return False

     async def __aenter__(self):
-        torch.cuda.synchronize()
+        self.device_synchronize()
         self.start_time = time.perf_counter()
         return self

     async def __aexit__(self, exc_type, exc_val, exc_tb):
-        torch.cuda.synchronize()
+        self.device_synchronize()
         elapsed = time.perf_counter() - self.start_time
         if self.enable_recorder and self.metrics_func:
             if self.metrics_labels:
@@ -78,6 +78,15 @@ class _ProfilingContext:
         return sync_wrapper

+    def device_synchronize(
+        self,
+    ):
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+        elif hasattr(torch, "mlu") and torch.mlu.is_available():
+            torch.mlu.synchronize()
+        return
+
 class _NullContext:
     # Context manager without decision branch logic overhead
...
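Unlike WanVAE, the profiler picks its backend by runtime availability rather than by the configured device; both paths end at the same synchronize call, which brackets the timer so perf_counter measures completed device work rather than just kernel launches. A self-contained sketch of the same timing pattern, under that assumption:

import time

import torch

class DeviceTimer:
    # Hypothetical helper: times a block including queued accelerator work.
    def _sync(self):
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        elif hasattr(torch, "mlu") and torch.mlu.is_available():
            torch.mlu.synchronize()

    def __enter__(self):
        self._sync()  # drain work queued before the timed block
        self.start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._sync()  # wait for the block's kernels to finish
        self.elapsed = time.perf_counter() - self.start
        return False

# usage: with DeviceTimer() as t: run_inference()
#        print(t.elapsed)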
@@ -19,8 +19,12 @@ def seed_all(seed):
     os.environ["PYTHONHASHSEED"] = str(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+    elif hasattr(torch, "mlu") and torch.mlu.is_available():
+        torch.mlu.manual_seed(seed)
+        torch.mlu.manual_seed_all(seed)
     torch.backends.cudnn.benchmark = False
     torch.backends.cudnn.deterministic = True
...
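A quick reproducibility check for the updated seed_all, assuming the function is in scope; the cudnn flags at the end of the function are safe to set even on builds without CUDA:

import torch

seed_all(1234)
a = torch.randn(3)
seed_all(1234)
b = torch.randn(3)
assert torch.equal(a, b)  # identical samples under the same seed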