animate_model.py 1.02 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from lightx2v.models.networks.wan.infer.animate.pre_infer import WanAnimatePreInfer
from lightx2v.models.networks.wan.infer.animate.transformer_infer import WanAnimateTransformerInfer
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.animate.transformer_weights import WanAnimateTransformerWeights
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights


class WanAnimateModel(WanModel):
    """Wan model variant wired up with the animate-specific components.

    Compared with the base ``WanModel`` this class selects
    ``WanAnimateTransformerWeights`` for the transformer weights and swaps in
    ``WanAnimatePreInfer`` / ``WanAnimateTransformerInfer`` as the inference
    classes. The pre-weights class is the regular ``WanPreWeights``.
    """

    # Weight container classes used by this variant.
    pre_weight_class = WanPreWeights
    transformer_weight_class = WanAnimateTransformerWeights

    def __init__(self, model_path, config, device):
        """Create the model, forwarding all arguments to ``WanModel``.

        Args:
            model_path: Passed through unchanged to the base initializer.
            config: Passed through unchanged to the base initializer.
            device: Passed through unchanged to the base initializer.
        """
        # Assigned *before* the base initializer runs. Presumably WanModel
        # consults ``remove_keys`` while loading weights so that the face and
        # motion encoder entries are skipped -- TODO confirm against WanModel.
        self.remove_keys = [
            "face_encoder",
            "motion_encoder",
        ]
        super().__init__(model_path, config, device)

    def _init_infer_class(self):
        """Select the inference classes, overriding the base-class choices."""
        # Let the base class set up its defaults first, then replace the two
        # entries that have animate-specific implementations.
        super()._init_infer_class()
        self.pre_infer_class = WanAnimatePreInfer
        self.transformer_infer_class = WanAnimateTransformerInfer

    def set_animate_encoders(self, motion_encoder, face_encoder):
        """Hand the motion and face encoders to the pre-infer stage.

        Args:
            motion_encoder: Forwarded as the first argument to
                ``self.pre_infer.set_animate_encoders``.
            face_encoder: Forwarded as the second argument to
                ``self.pre_infer.set_animate_encoders``.
        """
        # ``self.pre_infer`` is not created in this class; it is presumably
        # built by WanModel from ``pre_infer_class`` -- verify in the base.
        pre_infer_stage = self.pre_infer
        pre_infer_stage.set_animate_encoders(motion_encoder, face_encoder)