test_register.py 16.8 KB
Newer Older
zzg_666's avatar
zzg_666 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
from dataflow.utils.registry import OPERATOR_REGISTRY, PROMPT_REGISTRY
from inspect import signature
from pprint import pprint
import pytest
from inspect import isclass, getmembers, isfunction

def build_tree(type_dict):
    """
    Build a nested statistics tree from a mapping of name -> category path.

    Each value of ``type_dict`` is iterated as an ordered sequence of category
    names, outermost first; the keys themselves are not used. Every category
    level becomes a node ``{"__count__": int, "__children__": dict}``, where
    ``__count__`` is the number of entries whose path passes through that node
    and ``__children__`` holds the next level down.

    Args:
        type_dict: Mapping whose values are iterables of category names.

    Returns:
        dict: The top-level mapping of category name -> node.
    """
    root = {}
    for path in type_dict.values():
        level = root
        for category in path:
            entry = level.setdefault(category, {"__count__": 0, "__children__": {}})
            entry["__count__"] += 1
            level = entry["__children__"]
    return root


def print_tree(tree, indent=0):
    """
    Recursively print a statistics tree as an indented bullet list.

    Args:
        tree: Mapping of category name -> ``{"__count__": int,
            "__children__": dict}`` nodes, as produced by ``build_tree``.
        indent: Current depth; each level is indented by two spaces.

    Each node is printed as ``- <category> (<count>)`` followed by its
    children one level deeper.
    """
    pad = "  " * indent
    for category, node in tree.items():
        print(f"{pad}- {category} ({node['__count__']})")
        print_tree(node["__children__"], indent + 1)


@pytest.mark.cpu
def test_all_operator_registry():
    """
    Check the operator and prompt registries against the project's conventions.

    The test:
      1. Loads all operators, prints the registry, and reports (without
         failing) any operator name present in only one of the registry's
         object map and its type dict.
      2. Prints the operator category hierarchy and asserts the registry is
         not empty.
      3. Inspects every operator class:
           - ``__init__``: if it has a ``prompt_template`` parameter, that
             parameter must have a default value, must be type-annotated,
             its annotation must include ``DIYPromptABC`` (or a subclass),
             and every other annotated type must subclass ``PromptABC``.
             ``None`` (as in ``Optional[...]`` / ``X | None``) is accepted.
           - ``run()``: every parameter other than ``self``/``cls`` must be
             ``storage`` or start with ``input_`` / ``output_``, and
             ``storage`` must be present and come first.
      4. Prints the prompt registry and its category hierarchy, then checks
         that each prompt class defines no public methods other than
         ``build_prompt`` and ``build_system_prompt``.

    Each rule group gathers all offenders before calling ``pytest.fail``, so a
    failure lists every violating class at once. Groups are checked in the
    order above; the first failing group ends the test.
    """
    # Load every operator, then fetch the name -> class map.
    OPERATOR_REGISTRY._get_all()
    print(OPERATOR_REGISTRY)
    dataflow_obj_map = OPERATOR_REGISTRY.get_obj_map()

    print("\nTypedict of all operators:")
    type_dict = OPERATOR_REGISTRY.get_type_of_objects()

    # ---- Registry vs. type dict consistency (reported, not asserted) ----
    num_by_typedict = len(type_dict)   # keys of the type dict are operator names
    num_by_registry = len(dataflow_obj_map)

    print(f"\n{num_by_typedict} operators in total by type dict.")
    print(f"{num_by_registry} operators registered in the registry.")

    # Compare the name sets rather than just the counts: equal counts can
    # still hide operators that are present on only one side.
    missing_in_registry = [op for op in type_dict if op not in dataflow_obj_map]
    missing_in_type_dict = [op for op in dataflow_obj_map if op not in type_dict]
    if missing_in_registry or missing_in_type_dict:
        print("Mismatch found:")
        for op in missing_in_registry:
            print(f"  [Missing in registry] {op}")
        for op in missing_in_type_dict:
            print(f"  [Missing in type_dict] {op}")

    # ---- Category hierarchy statistics ----
    tree = build_tree(type_dict)
    print("\nOperator Type Hierarchy Statistics:")
    print_tree(tree)

    assert len(dataflow_obj_map) > 0, "No operators found in the registry."

    # ================= Signature checks =================
    print("\n🔍 Checking Operator class __init__ and run signatures ...")

    # Collect every problem first so a single failure reports all offenders.
    init_issues = []           # __init__ problems (including prompt_template rules)
    invalid_run_param_ops = [] # run() parameter naming / ordering problems
    operator_signatures = {}   # parameter names per operator, for the summary printout

    # Types the prompt_template rules are checked against.
    from dataflow.core.prompt import DIYPromptABC, PromptABC
    import inspect
    from inspect import Signature
    from typing import get_origin, get_args, Union
    try:
        from types import UnionType  # Py3.10+
    except ImportError:
        UnionType = None

    # `None` inside an annotation (Optional[X] / X | None) only marks the
    # parameter as nullable; it is not a prompt type and must not be checked
    # against PromptABC.
    NoneType = type(None)

    def _iter_annotation_types(ann):
        """
        Flatten an annotation into the list of plain classes it names.

        Supports ``typing.Union[...]`` and PEP 604 ``A | B`` unions as well as
        a bare class. Non-class members of a union are dropped. Returns ``[]``
        for a missing annotation or for anything else (string/forward
        references, generics, TypedDicts, ...), which is not parsed further.
        """
        if ann is inspect._empty:
            return []
        origin = get_origin(ann)
        if origin is Union:
            return [t for t in get_args(ann) if isinstance(t, type)]
        if UnionType is not None and isinstance(ann, UnionType):
            # On Python 3.10-3.13, get_origin(A | B) is types.UnionType rather
            # than typing.Union, so PEP 604 unions need this separate branch.
            return [t for t in getattr(ann, "__args__", ()) if isinstance(t, type)]
        if isinstance(ann, type):
            return [ann]
        return []

    for name, cls in dataflow_obj_map.items():
        if not isclass(cls):
            continue

        cls_info = {"__init__": None, "run": None}

        # ---------- __init__ checks ----------
        # Rules:
        # 1) If a `prompt_template` parameter exists:
        #    a) it must have a default value (be optional);
        #    b) its annotation must include DIYPromptABC (or a subclass);
        #    c) every other annotated type must subclass PromptABC
        #       (None is ignored).
        # 2) Any failure to obtain the signature is recorded in init_issues.
        try:
            # Classes without their own __init__ inherit one (ultimately
            # object.__init__), which still has an inspectable signature.
            init_sig: Signature = signature(cls.__init__)
            params = init_sig.parameters
            cls_info["__init__"] = list(params.keys())

            if "prompt_template" in params:
                p = params["prompt_template"]

                # 1.a A default value is required.
                if p.default is inspect._empty:
                    init_issues.append(
                        (name, cls.__module__,
                        "Parameter 'prompt_template' must have a default value (be optional).")
                    )

                # 1.b / 1.c Annotation requirements.
                types_ = [
                    t for t in _iter_annotation_types(p.annotation)
                    if t is not NoneType
                ]
                if not types_:
                    init_issues.append(
                        (name, cls.__module__,
                        "Parameter 'prompt_template' must be type-annotated and include DIYPromptABC "
                        "(Union[DIYPromptABC, ...] or DIYPromptABC).")
                    )
                else:
                    has_diy = any(issubclass(t, DIYPromptABC) for t in types_)
                    if not has_diy:
                        init_issues.append(
                            (name, cls.__module__,
                            "Annotation of 'prompt_template' must include DIYPromptABC.")
                        )
                    for t in types_:
                        # DIYPromptABC and its subclasses satisfy rule 1.b, so
                        # they are exempt from the PromptABC requirement
                        # (consistent with how has_diy is computed above).
                        if issubclass(t, DIYPromptABC):
                            continue
                        if not issubclass(t, PromptABC):
                            init_issues.append(
                                (name, cls.__module__,
                                f"Invalid allowed type in 'prompt_template' annotation: {t}. "
                                f"All non-DIY types must subclass PromptABC.")
                            )

        except Exception as e:
            cls_info["__init__"] = f"Error: {e}"
            init_issues.append(
                (name, cls.__module__, f"Failed to inspect __init__: {e}")
            )

        # ---------- run() checks ----------
        # Rules:
        # - Apart from self/cls, every parameter name must start with input_
        #   or output_, or be exactly 'storage'.
        # - 'storage' must be present and be the first parameter after self/cls.
        if hasattr(cls, "run"):
            try:
                run_sig: Signature = signature(cls.run)
                params = list(run_sig.parameters.keys())
                cls_info["run"] = params

                logical_params = [p for p in params if p not in ("self", "cls")]

                invalid_params = [
                    p for p in logical_params
                    if p != "storage" and not (p.startswith("input_") or p.startswith("output_"))
                ]

                if "storage" not in logical_params:
                    invalid_params.append("'storage' parameter missing")
                elif logical_params[0] != "storage":
                    # logical_params is non-empty here because it contains 'storage'.
                    invalid_params.append(
                        f"'storage' should be the FIRST parameter after self/cls, "
                        f"but found at position {logical_params.index('storage')} "
                        f"with first logical param '{logical_params[0]}'"
                    )

                if invalid_params:
                    invalid_run_param_ops.append((name, cls.__module__, invalid_params))
            except Exception as e:
                cls_info["run"] = f"Error: {e}"
                # Failing to read run()'s signature also counts as a violation.
                invalid_run_param_ops.append((name, cls.__module__, [f"Failed to inspect run(): {e}"]))

        operator_signatures[name] = cls_info

    # ---------- Signature summary ----------
    print("\n📘 Operator signatures summary:")
    for op_name, info in operator_signatures.items():
        print(f"\nOperator: {op_name}")
        print(f"  __init__ params: {info['__init__']}")
        print(f"  run params: {info['run']}")

    # ---------- Report __init__ problems first ----------
    if init_issues:
        details = "\n".join(
            f"  • {name} ({module}) → {msg}"
            for name, module, msg in init_issues
        )
        rule_explanation = (
            "\n__init__ / prompt_template rules:\n"
            "  - If parameter 'prompt_template' exists:\n"
            "    • It MUST have a default value (be optional).\n"
            "    • Its type annotation MUST include DIYPromptABC.\n"
            "    • All other allowed types MUST subclass PromptABC (None, as in Optional[...], is allowed).\n"
            "  - Any failure to inspect __init__ is considered an error.\n"
            "  - See: dataflow.operators.reasoning.generate.ReasoningAnswerGenerator for a reference implementation.\n"
        )
        pytest.fail(
            f"❌ Found {len(init_issues)} operators violating __init__/prompt_template rules."
            f"{rule_explanation}\nDetails:\n{details}",
            pytrace=False,
        )

    # ---------- Then report run() parameter problems ----------
    if invalid_run_param_ops:
        details = "\n".join(
            f"  • {name} ({module}) → invalid run parameters: {invalids}"
            for name, module, invalids in invalid_run_param_ops
        )
        rule_explanation = (
            "\nOperator run() parameter naming rule:\n"
            "  - All parameters (excluding self/cls) must be explicitly named using:\n"
            "      • input_*   • output_*   • 'storage'\n"
            "  - 'storage' MUST be the first parameter after self/cls.\n"
            "Example:\n"
            "  def run(self, storage, input_text, input_image, output_result):\n"
        )
        pytest.fail(
            f"❌ Found {len(invalid_run_param_ops)} operators violating run() parameter rules."
            f"{rule_explanation}\nDetails:\n{details}",
            pytrace=False,
        )

    print("✅ All checks passed: __init__/prompt_template rules and run() parameter conventions.")

    # ======= Prompt registry checks ==============
    print("\nPrompt Registry:")
    # PROMPT_REGISTRY._get_all() is intentionally not called here: the original
    # author noted that it currently causes an error.
    print(PROMPT_REGISTRY)
    prompt_type_dict = PROMPT_REGISTRY.get_type_of_objects()
    print("\nPrompt Type Hierarchy Statistics:")
    print_tree(build_tree(prompt_type_dict))

    print("\n🔍 Checking Prompt class member functions ...")
    allowed_methods = {"build_prompt", "__init__", "build_system_prompt"}
    invalid_prompts = []

    prompt_map = PROMPT_REGISTRY.get_obj_map()
    for name, cls in prompt_map.items():
        if cls is None or not isclass(cls):
            continue

        # Functions defined directly on this class (inherited ones excluded).
        # Looking names up in the class's own namespace is exact. A
        # `__qualname__.startswith(cls.__name__)` test would also pick up
        # methods inherited from a base whose name merely shares the prefix
        # (e.g. "FooBase" for class "Foo"). It would also miss every method of
        # a class defined inside a function.
        own_namespace = vars(cls)
        member_funcs = [
            func_name for func_name, _ in getmembers(cls, predicate=isfunction)
            if func_name in own_namespace
        ]

        # Public methods outside the allow-list are violations; names starting
        # with "_" are treated as private helpers and always allowed.
        disallowed = [
            fn for fn in member_funcs
            if not (fn in allowed_methods or fn.startswith("_"))
        ]

        if disallowed:
            invalid_prompts.append((name, cls.__module__, disallowed))

    errors = []
    if invalid_prompts:
        print("\n❌ Check failed, invalid Prompt classes contain disallowed functions:")
        for name, module, funcs in invalid_prompts:
            print(f"- {name} ({module}) disallowed functions: {funcs}")

        rule_explanation = (
            "\nPrompt class naming rule (English):\n"
            "Each Prompt class is only allowed to define the following public methods:\n"
            "  - build_prompt\n"
            "  - build_system_prompt\n"
            "  - __init__\n"
            "Other methods are only allowed if they start with an underscore (_), "
            "indicating they are private helper methods.\n\n"
            "Please check all invalid Prompt classes.\n"
        )

        details = "\n".join(
            f"  • {name} ({module}) → invalid functions: {funcs}"
            for name, module, funcs in invalid_prompts
        )

        errors.append(
            f"❌ Found {len(invalid_prompts)} Prompt classes violating naming rules.\n"
            f"{rule_explanation}\n"
            f"Details:\n{details}"
        )

    else:
        print("✅ All Prompt class member functions comply with the conventions (only contain allowed functions or private functions)")

    if errors:
        pytest.fail("\n".join(errors), pytrace=False)

if __name__ == "__main__":
    # Allow running the registry checks directly (`python test_register.py`)
    # as well as through pytest. A failed check raises from inside the test
    # function, so the script ends with an uncaught-exception traceback;
    # reaching the end of the script means every check passed.
    test_all_operator_registry()