pipeline_registry.py 3.46 KB
Newer Older
hepj's avatar
hepj committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/models/registry.py
# and https://github.com/sgl-project/sglang/blob/v0.4.3/python/sglang/srt/models/registry.py

import importlib
import pkgutil
from dataclasses import dataclass, field
from functools import lru_cache
from typing import AbstractSet, Dict, Optional, Tuple, Type

from fastvideo.v1.logger import init_logger
from fastvideo.v1.pipelines.composed_pipeline_base import ComposedPipelineBase

# Module-level logger, created through the project's init_logger helper.
logger = init_logger(__name__)


@dataclass
class _PipelineRegistry:
    """Lookup table from a pipeline architecture name to its pipeline class."""

    # Maps pipeline_arch -> pipeline class; a stored value may be None.
    pipelines: Dict[str, Optional[Type[ComposedPipelineBase]]] = field(
        default_factory=dict)

    def get_supported_archs(self) -> AbstractSet[str]:
        """Return a view over every architecture name held in ``pipelines``."""
        return self.pipelines.keys()

    def _try_load_pipeline_cls(
            self, pipeline_arch: str) -> Optional[Type[ComposedPipelineBase]]:
        """Return the class stored for ``pipeline_arch``, or None if absent."""
        return self.pipelines.get(pipeline_arch)

    def resolve_pipeline_cls(
        self,
        architecture: str,
    ) -> Tuple[Type[ComposedPipelineBase], str]:
        """Resolve ``architecture`` to its registered pipeline class.

        A warning is logged when ``architecture`` is empty; the lookup is
        still attempted afterwards.

        Returns:
            A ``(pipeline_cls, architecture)`` tuple.

        Raises:
            ValueError: if nothing is registered under ``architecture`` or
                the registered value is None.
        """
        if not architecture:
            logger.warning("No pipeline architecture is specified")

        resolved_cls = self._try_load_pipeline_cls(architecture)
        if resolved_cls is None:
            supported_archs = self.get_supported_archs()
            raise ValueError(
                f"Pipeline architectures {architecture} are not supported for now. "
                f"Supported architectures: {supported_archs}")

        return (resolved_cls, architecture)


@lru_cache
def import_pipeline_classes():
    """Import and collect the pipeline classes under ``fastvideo.v1.pipelines``.

    Every direct sub-package of ``fastvideo.v1.pipelines`` except the one
    named ``stages`` is scanned. Each module found in such a sub-package is
    imported, and if it defines a module-level ``EntryClass`` (a single class,
    or a list/tuple of classes) those classes are registered under their
    ``__name__``. Modules sitting directly in ``fastvideo.v1.pipelines`` are
    not scanned.

    A module that fails to import is skipped with a warning so that one
    broken pipeline does not prevent the others from loading.

    Because of ``lru_cache`` the scan runs once; later calls return the same
    dict object.

    Returns:
        A dict mapping pipeline class name to pipeline class.

    Raises:
        ValueError: if two entry classes share the same ``__name__``.
        ImportError: if the pipelines package or one of its scanned
            sub-packages cannot be imported (only module imports inside a
            sub-package are tolerated).
    """
    pipeline_arch_name_to_cls = {}
    package_name = "fastvideo.v1.pipelines"
    package = importlib.import_module(package_name)
    for _, sub_package_name, is_pkg in pkgutil.iter_modules(
            package.__path__, package_name + "."):
        # Only sub-packages are scanned, and the "stages" one is excluded.
        if not is_pkg or sub_package_name.split(".")[-1] == "stages":
            continue
        sub_package = importlib.import_module(sub_package_name)
        for _, module_name, _ in pkgutil.iter_modules(sub_package.__path__,
                                                      sub_package_name + "."):
            try:
                module = importlib.import_module(module_name)
            except Exception as e:
                # Deliberate best effort: skip modules that cannot be loaded.
                logger.warning("Ignore import error when loading %s. %s",
                               module_name, e)
                continue
            if not hasattr(module, "EntryClass"):
                continue
            entry = module.EntryClass
            # To support multiple pipeline classes in one module
            entries = entry if isinstance(entry, (list, tuple)) else (entry, )
            for pipeline_cls in entries:
                arch_name = pipeline_cls.__name__
                # Explicit raise rather than assert: asserts are removed
                # under ``python -O`` and a duplicate would then silently
                # overwrite the earlier registration.
                if arch_name in pipeline_arch_name_to_cls:
                    raise ValueError(
                        f"Duplicated pipeline implementation for {arch_name}")
                pipeline_arch_name_to_cls[arch_name] = pipeline_cls
    return pipeline_arch_name_to_cls


# Shared registry instance, populated at import time with every pipeline
# class that import_pipeline_classes() discovers.
PipelineRegistry = _PipelineRegistry(pipelines=import_pipeline_classes())