Unverified commit fcc2a411, authored by Kane, committed by GitHub

Mlu590 deployment (#453)

Features:
    1. Added MLU590 bfloat16 single-GPU and multi-GPU inference.
    2. Added MLU590 int8 inference.
parent 989a30a0
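Note: the changes below thread a configurable device through the runners, schedulers and VAE wrappers via a new `run_device` config key, falling back to "cuda" wherever the key is absent. A minimal sketch of a config enabling MLU inference (only "run_device" is introduced by this commit; the other keys already appear elsewhere in the diff and their values are illustrative assumptions):

# Sketch (assumption): values are placeholders, only "run_device" is new in this commit.
config = {
    "model_path": "/path/to/model",
    "infer_steps": 40,
    "cpu_offload": False,
    "run_device": "mlu",  # picked up by config.get("run_device", "cuda") in the code below
}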
......@@ -62,7 +62,7 @@ class DefaultRunner(BaseRunner):
if self.config["cpu_offload"]:
self.init_device = torch.device("cpu")
else:
self.init_device = torch.device("cuda")
self.init_device = torch.device(self.config.get("run_device", "cuda"))
def load_vfi_model(self):
if self.config["video_frame_interpolation"].get("algo", None) == "rife":
......
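The hunk above is the pattern repeated throughout the commit: CPU offload wins, otherwise the device comes from the config rather than a hard-coded "cuda". As a standalone sketch (assuming the Cambricon torch_mlu extension has been imported so that "mlu" is a registered device type):

import torch

def resolve_init_device(config: dict) -> torch.device:
    # CPU offload takes priority; otherwise use the configured accelerator,
    # defaulting to CUDA when no run_device is set.
    if config.get("cpu_offload"):
        return torch.device("cpu")
    return torch.device(config.get("run_device", "cuda"))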
......@@ -450,7 +450,7 @@ class WanAudioRunner(WanRunner): # type:ignore
ref_img = img_path
else:
ref_img = load_image(img_path)
-ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
+ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(0).to(self.init_device)
ref_img, h, w = resize_image(
ref_img,
......@@ -538,7 +538,7 @@ class WanAudioRunner(WanRunner): # type:ignore
def prepare_prev_latents(self, prev_video: Optional[torch.Tensor], prev_frame_length: int) -> Optional[Dict[str, torch.Tensor]]:
"""Prepare previous latents for conditioning"""
device = torch.device("cuda")
device = self.init_device
dtype = GET_DTYPE()
tgt_h, tgt_w = self.input_info.target_shape[0], self.input_info.target_shape[1]
......@@ -835,7 +835,7 @@ class WanAudioRunner(WanRunner): # type:ignore
def load_audio_encoder(self):
audio_encoder_path = self.config.get("audio_encoder_path", os.path.join(self.config["model_path"], "TencentGameMate-chinese-hubert-large"))
audio_encoder_offload = self.config.get("audio_encoder_cpu_offload", self.config.get("cpu_offload", False))
-model = SekoAudioEncoderModel(audio_encoder_path, self.config["audio_sr"], audio_encoder_offload)
+model = SekoAudioEncoderModel(audio_encoder_path, self.config["audio_sr"], audio_encoder_offload, device=self.config.get("run_device", "cuda"))
return model
def load_audio_adapter(self):
......@@ -843,7 +843,7 @@ class WanAudioRunner(WanRunner): # type:ignore
if audio_adapter_offload:
device = torch.device("cpu")
else:
device = torch.device("cuda")
device = torch.device(self.config.get("run_device", "cuda"))
audio_adapter = AudioAdapter(
attention_head_dim=self.config["dim"] // self.config["num_heads"],
num_attention_heads=self.config["num_heads"],
......@@ -856,6 +856,7 @@ class WanAudioRunner(WanRunner): # type:ignore
quantized=self.config.get("adapter_quantized", False),
quant_scheme=self.config.get("adapter_quant_scheme", None),
cpu_offload=audio_adapter_offload,
+device=self.config.get("run_device", "cuda"),
)
audio_adapter.to(device)
......
......@@ -65,7 +65,7 @@ class WanRunner(DefaultRunner):
if clip_offload:
clip_device = torch.device("cpu")
else:
clip_device = torch.device("cuda")
clip_device = torch.device(self.init_device)
# quant_config
clip_quantized = self.config.get("clip_quantized", False)
if clip_quantized:
......@@ -101,7 +101,7 @@ class WanRunner(DefaultRunner):
if t5_offload:
t5_device = torch.device("cpu")
else:
t5_device = torch.device("cuda")
t5_device = torch.device(self.init_device)
# quant_config
t5_quantized = self.config.get("t5_quantized", False)
......@@ -142,7 +142,7 @@ class WanRunner(DefaultRunner):
if vae_offload:
vae_device = torch.device("cpu")
else:
vae_device = torch.device("cuda")
vae_device = torch.device(self.init_device)
vae_config = {
"vae_path": find_torch_model_path(self.config, "vae_path", self.vae_name),
......@@ -165,7 +165,7 @@ class WanRunner(DefaultRunner):
if vae_offload:
vae_device = torch.device("cpu")
else:
vae_device = torch.device("cuda")
vae_device = torch.device(self.init_device)
vae_config = {
"vae_path": find_torch_model_path(self.config, "vae_path", self.vae_name),
......
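The four hunks above apply the same rule per module: an encoder that is offloaded lives on the CPU, otherwise it follows the runner-wide init_device instead of a hard-coded "cuda". A one-line sketch of that shared choice (the helper name is illustrative):

import torch

def module_device(offloaded: bool, init_device) -> torch.device:
    # Offloaded modules stay on CPU; everything else follows the configured accelerator.
    return torch.device("cpu") if offloaded else torch.device(init_device)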
......@@ -133,7 +133,7 @@ class QwenImageScheduler(BaseScheduler):
self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(os.path.join(config["model_path"], "scheduler"))
with open(os.path.join(config["model_path"], "scheduler", "scheduler_config.json"), "r") as f:
self.scheduler_config = json.load(f)
self.device = torch.device("cuda")
self.device = torch.device(self.config.get("run_device", "cuda"))
self.dtype = torch.bfloat16
self.guidance_scale = 1.0
......@@ -223,7 +223,7 @@ class QwenImageScheduler(BaseScheduler):
if self.config["task"] == "i2i":
self.generator = torch.Generator().manual_seed(input_info.seed)
elif self.config["task"] == "t2i":
self.generator = torch.Generator(device="cuda").manual_seed(input_info.seed)
self.generator = torch.Generator(device=self.device).manual_seed(input_info.seed)
self.prepare_latents(input_info)
self.prepare_guidance()
self.set_timesteps()
......
......@@ -10,7 +10,7 @@ from lightx2v.utils.utils import masks_like
class WanScheduler(BaseScheduler):
def __init__(self, config):
super().__init__(config)
self.device = torch.device("cuda")
self.device = torch.device(self.config.get("run_device", "cuda"))
self.infer_steps = self.config["infer_steps"]
self.target_video_length = self.config["target_video_length"]
self.sample_shift = self.config["sample_shift"]
......
......@@ -33,7 +33,7 @@ class AutoencoderKLQwenImageVAE:
if self.cpu_offload:
self.device = torch.device("cpu")
else:
self.device = torch.device("cuda")
self.device = torch.device(self.config.get("run_device", "cuda"))
self.dtype = torch.bfloat16
self.latent_channels = config["vae_z_dim"]
self.load()
......
......@@ -1018,7 +1018,7 @@ class WanVAE:
full_encoded = [torch.empty_like(encoded_chunk) for _ in range(world_size)]
dist.all_gather(full_encoded, encoded_chunk)
-torch.cuda.synchronize()
+self.device_synchronize()
encoded = torch.cat(full_encoded, dim=split_dim)
......@@ -1100,7 +1100,7 @@ class WanVAE:
dist.all_gather(full_encoded, encoded_chunk)
-torch.cuda.synchronize()
+self.device_synchronize()
# Reconstruct the full encoded tensor
encoded_rows = []
......@@ -1197,7 +1197,7 @@ class WanVAE:
full_images = [torch.empty_like(images) for _ in range(world_size)]
dist.all_gather(full_images, images)
-torch.cuda.synchronize()
+self.device_synchronize()
images = torch.cat(full_images, dim=split_dim + 1)
......@@ -1271,7 +1271,7 @@ class WanVAE:
dist.all_gather(full_images, images_chunk)
-torch.cuda.synchronize()
+self.device_synchronize()
# Reconstruct the full image tensor
image_rows = []
......@@ -1324,3 +1324,11 @@ class WanVAE:
def decode_video(self, vid_enc):
return self.model.decode_video(vid_enc)
+def device_synchronize(
+    self,
+):
+    if "cuda" in str(self.device):
+        torch.cuda.synchronize()
+    elif "mlu" in str(self.device):
+        torch.mlu.synchronize()
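WanVAE.device_synchronize dispatches on the device string, while the profiler version below dispatches on backend availability; both reduce to the same idea. A minimal device-agnostic sketch (assuming torch_mlu provides torch.mlu.synchronize, as the code above already relies on):

import torch

def device_synchronize(device: torch.device) -> None:
    # Block the host until all queued work on the given accelerator has finished.
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    elif device.type == "mlu":
        torch.mlu.synchronize()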
......@@ -27,12 +27,12 @@ class _ProfilingContext:
self.metrics_labels = metrics_labels
def __enter__(self):
-torch.cuda.synchronize()
+self.device_synchronize()
self.start_time = time.perf_counter()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
-torch.cuda.synchronize()
+self.device_synchronize()
elapsed = time.perf_counter() - self.start_time
if self.enable_recorder and self.metrics_func:
if self.metrics_labels:
......@@ -44,12 +44,12 @@ class _ProfilingContext:
return False
async def __aenter__(self):
-torch.cuda.synchronize()
+self.device_synchronize()
self.start_time = time.perf_counter()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
-torch.cuda.synchronize()
+self.device_synchronize()
elapsed = time.perf_counter() - self.start_time
if self.enable_recorder and self.metrics_func:
if self.metrics_labels:
......@@ -78,6 +78,15 @@ class _ProfilingContext:
return sync_wrapper
+def device_synchronize(
+    self,
+):
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+    elif hasattr(torch, "mlu") and torch.mlu.is_available():
+        torch.mlu.synchronize()
+    return
class _NullContext:
# Context manager without decision branch logic overhead
......
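Synchronizing before reading perf_counter matters because CUDA and MLU kernels launch asynchronously; without it the timer would mostly capture kernel-launch latency rather than execution time. A self-contained sketch of the same idea outside the profiling context (the helper name is illustrative):

import time

import torch

def timed(fn, *args, **kwargs):
    def _sync():
        # Same availability-based dispatch as _ProfilingContext.device_synchronize above.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        elif hasattr(torch, "mlu") and torch.mlu.is_available():
            torch.mlu.synchronize()

    _sync()
    start = time.perf_counter()
    out = fn(*args, **kwargs)
    _sync()
    return out, time.perf_counter() - start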
......@@ -19,8 +19,12 @@ def seed_all(seed):
os.environ["PYTHONHASHSEED"] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
-torch.cuda.manual_seed(seed)
-torch.cuda.manual_seed_all(seed)
+if torch.cuda.is_available():
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+elif hasattr(torch, "mlu") and torch.mlu.is_available():
+    torch.mlu.manual_seed(seed)
+    torch.mlu.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
......
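seed_all now seeds whichever accelerator backend is actually present instead of assuming CUDA. The same guard, factored into a standalone sketch (the helper name is illustrative; torch.mlu.manual_seed_all is the torch_mlu counterpart already used above):

import torch

def seed_accelerator(seed: int) -> None:
    # Seed whichever accelerator backend is present; no-op on CPU-only builds.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    elif hasattr(torch, "mlu") and torch.mlu.is_available():
        torch.mlu.manual_seed_all(seed)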