import gc
import os

import numpy as np
import requests
import torch
import torch.distributed as dist
import torchvision.transforms.functional as TF
from PIL import Image
from loguru import logger
from requests.exceptions import RequestException

from lightx2v.models.runners.base_runner import BaseRunner
from lightx2v.server.metrics import monitor_cli
from lightx2v.utils.envs import *
from lightx2v.utils.generate_task_id import generate_task_id
from lightx2v.utils.global_paras import CALIB
from lightx2v.utils.profiler import *
from lightx2v.utils.utils import get_optimal_patched_size_with_sp, isotropic_crop_resize, save_to_video, wan_vae_to_comfy
from lightx2v_platform.base.global_var import AI_DEVICE

torch_device_module = getattr(torch, AI_DEVICE)


def resize_image(img, resolution, bucket_shape=None):
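    """Resize an image tensor to the closest aspect-ratio bucket for the given resolution.

    The bucket table maps aspect-ratio keys (height/width) to (height, width) pairs for
    480p, 540p and 720p. The image is matched to the bucket whose ratio is closest to its
    own and then isotropically cropped and resized to that shape.
    """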
    assert resolution in ["480p", "540p", "720p"]
    if bucket_shape is None:
        bucket_config = {
            0.667: np.array([[480, 832], [544, 960], [720, 1280]], dtype=np.int64),
            1.500: np.array([[832, 480], [960, 544], [1280, 720]], dtype=np.int64),
            1.000: np.array([[480, 480], [576, 576], [960, 960]], dtype=np.int64),
        }
    else:
        bucket_config = {}
        for ratio, resolutions in bucket_shape.items():
            bucket_config[float(ratio)] = np.array(resolutions, dtype=np.int64)
    ori_height = img.shape[-2]
    ori_width = img.shape[-1]
    ori_ratio = ori_height / ori_width

    aspect_ratios = np.array(list(bucket_config.keys()))
    closest_aspect_idx = np.argmin(np.abs(aspect_ratios - ori_ratio))
    closest_ratio = aspect_ratios[closest_aspect_idx]
    if resolution == "480p":
        target_h, target_w = bucket_config[closest_ratio][0]
    elif resolution == "540p":
        target_h, target_w = bucket_config[closest_ratio][1]
    elif resolution == "720p":
        target_h, target_w = bucket_config[closest_ratio][2]

    cropped_img = isotropic_crop_resize(img, (target_h, target_w))
    logger.info(f"resize_image: {img.shape} -> {cropped_img.shape}, target_h: {target_h}, target_w: {target_w}")
    return cropped_img, target_h, target_w


class DefaultRunner(BaseRunner):
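    """Default end-to-end generation runner.

    Wires together the text/image encoders, the DiT transformer, the scheduler and the
    VAE, and drives the per-segment denoising loop for the configured task
    (t2v / i2v / flf2v / vace / animate / s2v / ...).
    """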
    def __init__(self, config):
        super().__init__(config)
        self.has_prompt_enhancer = False
        self.progress_callback = None
        if self.config["task"] == "t2v" and self.config.get("sub_servers", {}).get("prompt_enhancer") is not None:
            self.has_prompt_enhancer = True
            if not self.check_sub_servers("prompt_enhancer"):
                self.has_prompt_enhancer = False
                logger.warning("No prompt enhancer server available, disable prompt enhancer.")
        if not self.has_prompt_enhancer:
            self.config["use_prompt_enhancer"] = False
        self.set_init_device()
        self.init_scheduler()

    def init_modules(self):
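        """Load the models (unless lazy_load/unload_modules is set), attach the scheduler
        to the transformer, pick the input-encoder routine for the configured task, lock
        the config, and optionally pre-compile the model for the configured shapes."""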
        logger.info("Initializing runner modules...")
        if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
            self.load_model()
        elif self.config.get("lazy_load", False):
            assert self.config.get("cpu_offload", False)
        if hasattr(self, "model"):
            self.model.set_scheduler(self.scheduler)  # set scheduler to model
        if self.config["task"] == "i2v":
            self.run_input_encoder = self._run_input_encoder_local_i2v
        elif self.config["task"] == "flf2v":
            self.run_input_encoder = self._run_input_encoder_local_flf2v
        elif self.config["task"] == "t2v":
            self.run_input_encoder = self._run_input_encoder_local_t2v
        elif self.config["task"] == "vace":
            self.run_input_encoder = self._run_input_encoder_local_vace
        elif self.config["task"] == "animate":
            self.run_input_encoder = self._run_input_encoder_local_animate
        elif self.config["task"] in ["s2v", "rs2v"]:
            self.run_input_encoder = self._run_input_encoder_local_s2v
        elif self.config["task"] == "t2av":
            self.run_input_encoder = self._run_input_encoder_local_t2av
        elif self.config["task"] == "i2av":
            self.run_input_encoder = self._run_input_encoder_local_i2av
        self.config.lock()  # lock config to avoid modification
        if self.config.get("compile", False) and hasattr(self.model, "compile"):
            logger.info(f"[Compile] Compile all shapes: {self.config.get('compile_shapes', [])}")
            self.model.compile(self.config.get("compile_shapes", []))

    def set_init_device(self):
        if self.config["cpu_offload"]:
            self.init_device = torch.device("cpu")
        else:
            self.init_device = torch.device(AI_DEVICE)

    def load_vfi_model(self):
        if self.config["video_frame_interpolation"].get("algo", None) == "rife":
            from lightx2v.models.vfi.rife.rife_comfyui_wrapper import RIFEWrapper

            logger.info("Loading RIFE model...")
            return RIFEWrapper(self.config["video_frame_interpolation"]["model_path"])
        else:
            raise ValueError(f"Unsupported VFI model: {self.config['video_frame_interpolation']['algo']}")

    def load_vsr_model(self):
        if "video_super_resolution" in self.config:
            from lightx2v.models.runners.vsr.vsr_wrapper import VSRWrapper

            logger.info("Loading VSR model...")
            return VSRWrapper(self.config["video_super_resolution"]["model_path"])
        else:
            return None

    @ProfilingContext4DebugL2("Load models")
    def load_model(self):
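        """Load all sub-models: transformer, text encoders, image encoder, VAE encoder/decoder,
        and the optional frame-interpolation (VFI) and super-resolution (VSR) models."""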
        self.model = self.load_transformer()
        self.text_encoders = self.load_text_encoder()
        self.image_encoder = self.load_image_encoder()
        self.vae_encoder, self.vae_decoder = self.load_vae()
        self.vfi_model = self.load_vfi_model() if "video_frame_interpolation" in self.config else None
        self.vsr_model = self.load_vsr_model() if "video_super_resolution" in self.config else None

    def check_sub_servers(self, task_type):
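        """Probe the configured sub-servers for ``task_type``, keep only the reachable ones
        in the config, and return True if at least one server is available."""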
        urls = self.config.get("sub_servers", {}).get(task_type, [])
        available_servers = []
        for url in urls:
            try:
                status_url = f"{url}/v1/local/{task_type}/generate/service_status"
                response = requests.get(status_url, timeout=2)
                if response.status_code == 200:
                    available_servers.append(url)
                else:
                    logger.warning(f"Service {url} returned status code {response.status_code}")

            except RequestException as e:
                logger.warning(f"Failed to connect to {url}: {str(e)}")
                continue
        logger.info(f"{task_type} available servers: {available_servers}")
        self.config["sub_servers"][task_type] = available_servers
        return len(available_servers) > 0

    def set_inputs(self, inputs):
        self.input_info.seed = inputs.get("seed", 42)
        self.input_info.prompt = inputs.get("prompt", "")
        if self.config["use_prompt_enhancer"]:
            self.input_info.prompt_enhanced = inputs.get("prompt_enhanced", "")
        self.input_info.negative_prompt = inputs.get("negative_prompt", "")
        if "image_path" in self.input_info.__dataclass_fields__:
            self.input_info.image_path = inputs.get("image_path", "")
        if "audio_path" in self.input_info.__dataclass_fields__:
            self.input_info.audio_path = inputs.get("audio_path", "")
        if "video_path" in self.input_info.__dataclass_fields__:
            self.input_info.video_path = inputs.get("video_path", "")
        self.input_info.save_result_path = inputs.get("save_result_path", "")

    def set_config(self, config_modify):
        logger.info(f"modify config: {config_modify}")
        with self.config.temporarily_unlocked():
            self.config.update(config_modify)

    def set_progress_callback(self, callback):
        self.progress_callback = callback

    def run_segment(self, segment_idx=0):
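        """Run the denoising loop for one video segment.

        For every scheduler step: check the stop signal (single-segment runs only),
        run step_pre, the DiT forward pass and step_post, and report progress through
        the optional progress callback. Returns the final latents of this segment.
        """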
        infer_steps = self.model.scheduler.infer_steps

        for step_index in range(infer_steps):
            # only for single segment, check stop signal every step
            with ProfilingContext4DebugL1(
                f"Run Dit every step",
                recorder_mode=GET_RECORDER_MODE(),
                metrics_func=monitor_cli.lightx2v_run_per_step_dit_duration,
                metrics_labels=[step_index + 1, infer_steps],
            ):
                if self.video_segment_num == 1:
                    self.check_stop()
                logger.info(f"==> step_index: {step_index + 1} / {infer_steps}")

                with ProfilingContext4DebugL1("step_pre"):
                    self.model.scheduler.step_pre(step_index=step_index)

                with ProfilingContext4DebugL1("🚀 infer_main"):
                    self.model.infer(self.inputs)

                with ProfilingContext4DebugL1("step_post"):
                    self.model.scheduler.step_post()

                if self.progress_callback:
                    current_step = segment_idx * infer_steps + step_index + 1
                    total_all_steps = self.video_segment_num * infer_steps
                    self.progress_callback((current_step / total_all_steps) * 100, 100)

        if segment_idx is not None and segment_idx == self.video_segment_num - 1:
            del self.inputs
            torch_device_module.empty_cache()

        return self.model.scheduler.latents

    def run_step(self):
        self.inputs = self.run_input_encoder()
        if hasattr(self, "sr_version") and self.sr_version is not None is not None:
            self.config_sr["is_sr_running"] = True
            self.inputs_sr = self.run_input_encoder()
            self.config_sr["is_sr_running"] = False

        self.run_main()

    def end_run(self):
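        """Release per-run state: clear the scheduler, drop cached inputs, unload the
        transformer when lazy_load/unload_modules is enabled, optionally dump calibration
        data, and free device memory."""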
        self.model.scheduler.clear()
        if hasattr(self, "inputs"):
            del self.inputs
        self.input_info = None
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            if hasattr(self.model, "model") and len(self.model.model) == 2:  # MultiModelStruct
                for model in self.model.model:
                    if hasattr(model.transformer_infer, "offload_manager"):
                        del model.transformer_infer.offload_manager
                        torch_device_module.empty_cache()
                        gc.collect()
                    del model
            else:
                if hasattr(self.model.transformer_infer, "offload_manager"):
                    del self.model.transformer_infer.offload_manager
                    torch_device_module.empty_cache()
                    gc.collect()
                del self.model
        if self.config.get("do_mm_calib", False):
            calib_path = os.path.join(os.getcwd(), "calib.pt")
            torch.save(CALIB, calib_path)
            logger.info(f"[CALIB] Saved calibration data successfully to: {calib_path}")
        torch_device_module.empty_cache()
        gc.collect()

    def read_image_input(self, img_path):
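        """Load an image (path or PIL.Image), normalize it to [-1, 1] on the init device,
        and, when resize_mode is "adaptive", resize it to a bucketed shape while recording
        the resulting latent/target shapes in input_info. Returns (tensor, original PIL image)."""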
        if isinstance(img_path, Image.Image):
            img_ori = img_path
        else:
            img_ori = Image.open(img_path).convert("RGB")

        if GET_RECORDER_MODE():
            width, height = img_ori.size
            monitor_cli.lightx2v_input_image_len.observe(width * height)
        img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).to(self.init_device)
        self.input_info.original_size = img_ori.size

        if self.config.get("resize_mode", None) == "adaptive":
            img, h, w = resize_image(img, self.config.get("resolution", "480p"), self.config.get("bucket_shape", None))
            logger.info(f"resize_image target_h: {h}, target_w: {w}")
            patched_h = h // self.config["vae_stride"][1] // self.config["patch_size"][1]
            patched_w = w // self.config["vae_stride"][2] // self.config["patch_size"][2]

            patched_h, patched_w = get_optimal_patched_size_with_sp(patched_h, patched_w, 1)

            latent_h = patched_h * self.config["patch_size"][1]
            latent_w = patched_w * self.config["patch_size"][2]

            latent_shape = self.get_latent_shape_with_lat_hw(latent_h, latent_w)
            target_shape = [latent_h * self.config["vae_stride"][1], latent_w * self.config["vae_stride"][2]]

            logger.info(f"target_h: {target_shape[0]}, target_w: {target_shape[1]}, latent_h: {latent_h}, latent_w: {latent_w}")

            img = torch.nn.functional.interpolate(img, size=(target_shape[0], target_shape[1]), mode="bicubic")
            self.input_info.latent_shape = latent_shape  # Important: set latent_shape in input_info
            self.input_info.target_shape = target_shape  # Important: set target_shape in input_info

        return img, img_ori

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_i2v(self):
        img, img_ori = self.read_image_input(self.input_info.image_path)
        clip_encoder_out = self.run_image_encoder(img) if self.config.get("use_image_encoder", True) else None
        vae_encode_out, latent_shape = self.run_vae_encoder(img_ori if self.vae_encoder_need_img_original else img)
        self.input_info.latent_shape = latent_shape  # Important: set latent_shape in input_info
        text_encoder_output = self.run_text_encoder(self.input_info)
        torch_device_module.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output, img)

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_t2v(self):
        self.input_info.latent_shape = self.get_latent_shape_with_target_hw()  # Important: set latent_shape in input_info
        text_encoder_output = self.run_text_encoder(self.input_info)
        torch_device_module.empty_cache()
        gc.collect()
        return {
            "text_encoder_output": text_encoder_output,
            "image_encoder_output": None,
        }

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_flf2v(self):
        first_frame, _ = self.read_image_input(self.input_info.image_path)
        last_frame, _ = self.read_image_input(self.input_info.last_frame_path)
        clip_encoder_out = self.run_image_encoder(first_frame, last_frame) if self.config.get("use_image_encoder", True) else None
        vae_encode_out, latent_shape = self.run_vae_encoder(first_frame, last_frame)
        self.input_info.latent_shape = latent_shape  # Important: set latent_shape in input_info
        text_encoder_output = self.run_text_encoder(self.input_info)
        torch_device_module.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output)

    @ProfilingContext4DebugL2("Run Encoders")
    def _run_input_encoder_local_vace(self):
        src_video = self.input_info.src_video
        src_mask = self.input_info.src_mask
        src_ref_images = self.input_info.src_ref_images
        src_video, src_mask, src_ref_images = self.prepare_source(
            [src_video],
            [src_mask],
            [None if src_ref_images is None else src_ref_images.split(",")],
            (self.config["target_width"], self.config["target_height"]),
        )
        self.src_ref_images = src_ref_images

        vae_encoder_out, latent_shape = self.run_vae_encoder(src_video, src_ref_images, src_mask)
        self.input_info.latent_shape = latent_shape  # Important: set latent_shape in input_info
        text_encoder_output = self.run_text_encoder(self.input_info)
        torch_device_module.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(None, vae_encoder_out, text_encoder_output)

    @ProfilingContext4DebugL2("Run Text Encoder")
    def _run_input_encoder_local_animate(self):
        text_encoder_output = self.run_text_encoder(self.input_info)
        torch_device_module.empty_cache()
        gc.collect()
        return self.get_encoder_output_i2v(None, None, text_encoder_output, None)

    def _run_input_encoder_local_s2v(self):
        pass

    def init_run(self):
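        """Prepare a new run: compute the number of video segments, (re)load the transformer
        when lazy loading is enabled, prepare the scheduler with the seed and latent shape,
        and set up the optional super-resolution pass."""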
        self.gen_video_final = None
        self.get_video_segment_num()

        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.model = self.load_transformer()
            self.model.set_scheduler(self.scheduler)

        self.model.scheduler.prepare(seed=self.input_info.seed, latent_shape=self.input_info.latent_shape, image_encoder_output=self.inputs["image_encoder_output"])
        if self.config.get("model_cls") == "wan2.2" and self.config["task"] in ["i2v", "s2v", "rs2v"]:
            self.inputs["image_encoder_output"]["vae_encoder_out"] = None

        if hasattr(self, "sr_version") and self.sr_version is not None:
            self.lq_latents_shape = self.model.scheduler.latents.shape
            self.model_sr.set_scheduler(self.scheduler_sr)
            self.config_sr["is_sr_running"] = True
            self.inputs_sr = self.run_input_encoder()
            self.config_sr["is_sr_running"] = False

    @ProfilingContext4DebugL2("Run DiT")
    def run_main(self):
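        """Main generation loop: initialize the run, then for every segment run the DiT
        denoising loop, decode the latents with the VAE (optionally streaming), and finally
        post-process the decoded frames and clean up. Returns the post-processed result."""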
        self.init_run()
        if self.config.get("compile", False) and hasattr(self.model, "comple"):
            self.model.select_graph_for_compile(self.input_info)
        for segment_idx in range(self.video_segment_num):
            logger.info(f"🔄 start segment {segment_idx + 1}/{self.video_segment_num}")
            with ProfilingContext4DebugL1(
                f"segment end2end {segment_idx + 1}/{self.video_segment_num}",
                recorder_mode=GET_RECORDER_MODE(),
                metrics_func=monitor_cli.lightx2v_run_segments_end2end_duration,
                metrics_labels=["DefaultRunner"],
            ):
                self.check_stop()
                # 1. per-segment setup (no-op by default)
                self.init_run_segment(segment_idx)
                # 2. main inference loop
                latents = self.run_segment(segment_idx)
                # 3. vae decoder
                if self.config.get("use_stream_vae", False):
                    frames = []
                    for frame_segment in self.run_vae_decoder_stream(latents):
                        frames.append(frame_segment)
                        logger.info(f"frame sagment: {len(frames)} done")
                    self.gen_video = torch.cat(frames, dim=2)
                else:
                    self.gen_video = self.run_vae_decoder(latents)
                # 4. per-segment teardown (no-op by default)
                self.end_run_segment(segment_idx)
        gen_video_final = self.process_images_after_vae_decoder()
        self.end_run()
        return gen_video_final

    @ProfilingContext4DebugL1("Run VAE Decoder", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_run_vae_decode_duration, metrics_labels=["DefaultRunner"])
    def run_vae_decoder(self, latents):
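        """Decode latents into video frames, loading and unloading the VAE decoder on demand
        when lazy_load/unload_modules is enabled."""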
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae_decoder = self.load_vae_decoder()
        images = self.vae_decoder.decode(latents.to(GET_DTYPE()))
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae_decoder
            torch_device_module.empty_cache()
            gc.collect()
        return images

    @ProfilingContext4DebugL1("Run VAE Decoder Stream", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_run_vae_decode_duration, metrics_labels=["DefaultRunner"])
    def run_vae_decoder_stream(self, latents):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.vae_decoder = self.load_vae_decoder()

        for frame_segment in self.vae_decoder.decode_stream(latents.to(GET_DTYPE())):
            yield frame_segment

        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            del self.vae_decoder
            torch_device_module.empty_cache()
            gc.collect()

    def post_prompt_enhancer(self):
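        """Poll the configured prompt_enhancer sub-servers until one reports an idle status,
        submit the prompt to it, and return the enhanced prompt."""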
        while True:
            for url in self.config["sub_servers"]["prompt_enhancer"]:
                response = requests.get(f"{url}/v1/local/prompt_enhancer/generate/service_status").json()
                if response["service_status"] == "idle":
                    response = requests.post(
                        f"{url}/v1/local/prompt_enhancer/generate",
                        json={
                            "task_id": generate_task_id(),
                            "prompt": self.config["prompt"],
                        },
                    )
                    enhanced_prompt = response.json()["output"]
                    logger.info(f"Enhanced prompt: {enhanced_prompt}")
                    return enhanced_prompt

    def process_images_after_vae_decoder(self):
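        """Convert the decoded video to the expected output format, optionally interpolate
        frames to the target FPS, and either return the tensor or save it to disk (rank 0
        only when running distributed)."""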
        self.gen_video_final = wan_vae_to_comfy(self.gen_video_final)

        if "video_frame_interpolation" in self.config:
            assert self.vfi_model is not None and self.config["video_frame_interpolation"].get("target_fps", None) is not None
            target_fps = self.config["video_frame_interpolation"]["target_fps"]
            logger.info(f"Interpolating frames from {self.config.get('fps', 16)} to {target_fps}")
            self.gen_video_final = self.vfi_model.interpolate_frames(
                self.gen_video_final,
                source_fps=self.config.get("fps", 16),
                target_fps=target_fps,
            )

        if self.input_info.return_result_tensor:
            return {"video": self.gen_video_final}
        elif self.input_info.save_result_path is not None:
            if "video_frame_interpolation" in self.config and self.config["video_frame_interpolation"].get("target_fps"):
                fps = self.config["video_frame_interpolation"]["target_fps"]
            else:
                fps = self.config.get("fps", 16)

            if not dist.is_initialized() or dist.get_rank() == 0:
                logger.info(f"🎬 Start to save video 🎬")

                save_to_video(self.gen_video_final, self.input_info.save_result_path, fps=fps, method="ffmpeg")
                logger.info(f"✅ Video saved successfully to: {self.input_info.save_result_path} ✅")
            return {"video": None}

    @ProfilingContext4DebugL1("RUN pipeline", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_worker_request_duration, metrics_labels=["DefaultRunner"])
    def run_pipeline(self, input_info):
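        """Entry point for a full generation request: optionally enhance the prompt, run the
        input encoders, execute the main generation loop, and return the final result."""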
        if GET_RECORDER_MODE():
            monitor_cli.lightx2v_worker_request_count.inc()
        self.input_info = input_info

        if self.config["use_prompt_enhancer"]:
            self.input_info.prompt_enhanced = self.post_prompt_enhancer()

        self.inputs = self.run_input_encoder()

        gen_video_final = self.run_main()

        if GET_RECORDER_MODE():
            monitor_cli.lightx2v_worker_request_success.inc()
        return gen_video_final

    def switch_lora(self, lora_path: str, strength: float = 1.0):
        """
        Switch LoRA weights dynamically by calling the weight modules' update_lora method.
        If an empty lora_path is provided, LoRA weights are removed instead via the weight
        modules' remove_lora method.

        This method allows switching LoRA weights at runtime without reloading the model.
        It calls the model's _update_lora method, which updates the LoRA weights in the
        pre_weight, transformer_weights, and post_weight modules, or removes them when
        lora_path is empty.

        Args:
            lora_path: Path to the LoRA safetensors file
            strength: LoRA strength (default: 1.0)

        Returns:
            bool: True if LoRA was successfully switched, False otherwise
        """
        if not hasattr(self, "model") or self.model is None:
            logger.error("Model not loaded. Please load model first.")
            return False

        if not hasattr(self.model, "_update_lora"):
            logger.error("Model does not support LoRA switching")
            return False

        try:
            if lora_path == "":
                if hasattr(self.model, "_remove_lora"):
                    logger.info("Removing LoRA weights")
                    self.model._remove_lora()
                    logger.info("LoRA removed successfully")
                    return True
                else:
                    logger.error("Model does not support LoRA removal.")
                    return False
            else:
                logger.info(f"Switching LoRA to: {lora_path} with strength={strength}")
                self.model._update_lora(lora_path, strength)
                logger.info("LoRA switched successfully")
            return True
        except Exception as e:
            logger.error(f"Failed to switch LoRA: {e}")
            return False

    def __del__(self):
        if hasattr(self, "model"):
            del self.model
        if hasattr(self, "text_encoders"):
            del self.text_encoders
        if hasattr(self, "image_encoder"):
            del self.image_encoder
        if hasattr(self, "vae_encoder"):
            del self.vae_encoder
        if hasattr(self, "vae_decoder"):
            del self.vae_decoder
        torch_device_module.empty_cache()
        gc.collect()