test_prompt_restrict.py 9.08 KB
Newer Older
zzg_666's avatar
zzg_666 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
from dataflow.core.prompt import PromptABC, DIYPromptABC, prompt_restrict
import pytest
import typing
from typing import get_type_hints, get_origin, get_args, Union
import inspect

# ==== 测试用到的 Prompt 类 ====

class APrompt(PromptABC):
    """Prompt type that the decorated test classes list in ``prompt_restrict``."""

class BPrompt(PromptABC):
    """Second prompt type that the decorated test classes list in ``prompt_restrict``."""

class OtherPrompt(PromptABC):
    """PromptABC subclass never passed to ``prompt_restrict`` here; used as a rejected input."""

class CustomDIY(DIYPromptABC):
    """DIYPromptABC subclass; the tests expect its instances to be accepted."""

# ==== 被装饰的“算子/类” ====

@prompt_restrict(APrompt, BPrompt)
class OperatorKwOnly:
    """Decorated class under test; the tests in this file pass prompt_template to it by keyword."""
    def __init__(self, prompt_template=None):
        # Store the value unchanged so tests can inspect what was accepted.
        self.prompt_template = prompt_template

@prompt_restrict(APrompt, BPrompt)
class OperatorWithPositional:
    """Decorated class under test; used to exercise passing prompt_template positionally."""
    def __init__(self, prompt_template=None):
        # Store the value unchanged so tests can inspect what was accepted.
        self.prompt_template = prompt_template

# ==== 单元测试 ====
@pytest.mark.cpu
def test_accept_allowed_prompt_instances():
    """Instances of the whitelisted types (APrompt/BPrompt) must be accepted."""
    for allowed_cls in (APrompt, BPrompt):
        operator = OperatorKwOnly(prompt_template=allowed_cls())
        assert isinstance(operator.prompt_template, allowed_cls)
@pytest.mark.cpu
def test_accept_diy_subclass_instance():
    """An instance of a DIYPromptABC subclass must be accepted as well."""
    diy_prompt = CustomDIY()
    operator = OperatorKwOnly(prompt_template=diy_prompt)
    assert isinstance(operator.prompt_template, CustomDIY)
@pytest.mark.cpu
def test_accept_none():
    """Omitting prompt_template, or passing None explicitly, must be accepted."""
    implicit = OperatorKwOnly()
    explicit = OperatorKwOnly(prompt_template=None)
    for operator in (implicit, explicit):
        assert operator.prompt_template is None
@pytest.mark.cpu
def test_reject_other_prompt_subclass():
    """A PromptABC subclass that is neither whitelisted nor a DIYPromptABC subclass is rejected."""
    with pytest.raises(TypeError) as exc_info:
        OperatorKwOnly(prompt_template=OtherPrompt())
    message = str(exc_info.value)
    # The error text should name the decorated class, the offending type,
    # and the whitelisted types.
    expected_fragments = (
        "[OperatorKwOnly] Invalid prompt_template type:",
        "OtherPrompt",
        "APrompt",
        "BPrompt",
    )
    for fragment in expected_fragments:
        assert fragment in message
@pytest.mark.cpu
def test_reject_non_prompt_object():
    """A value that is not a prompt at all (a plain ``object()``) must be rejected."""
    not_a_prompt = object()
    with pytest.raises(TypeError):
        OperatorKwOnly(prompt_template=not_a_prompt)
@pytest.mark.cpu
def test_annotations_updated_union():
    """
    get_type_hints(OperatorKwOnly) should expose prompt_template as
    Union[APrompt, BPrompt, DIYPromptABC, NoneType].
    """
    resolved = get_type_hints(OperatorKwOnly)
    assert "prompt_template" in resolved
    annotation = resolved["prompt_template"]
    assert get_origin(annotation) is typing.Union
    members = set(get_args(annotation))
    # All four expected members must be present in the union.
    for expected in (APrompt, BPrompt, DIYPromptABC, type(None)):
        assert expected in members
@pytest.mark.cpu
def test_allowed_prompts_attribute():
    """The decorator should set ALLOWED_PROMPTS on the class, equal to the tuple of whitelisted types."""
    assert hasattr(OperatorKwOnly, "ALLOWED_PROMPTS")
    allowed = OperatorKwOnly.ALLOWED_PROMPTS
    assert allowed == (APrompt, BPrompt)

@pytest.mark.cpu
def test_kwargs_only_behavior_no_check_on_positional():
    """A disallowed prompt passed positionally must raise TypeError.

    The function name is kept for test-ID stability. It and the earlier
    description referred to a kwargs-only check in which positional values
    were not validated; the assertion below expects the opposite, namely that
    a non-whitelisted prompt passed as a positional argument is rejected.
    """
    with pytest.raises(TypeError):
        OperatorWithPositional(OtherPrompt())


@pytest.mark.cpu
@pytest.mark.parametrize(
    "value,should_pass",
    [
        (APrompt(), True),
        (BPrompt(), True),
        (CustomDIY(), True),
        (OtherPrompt(), False),
        (object(), False),
    ],
)
def test_matrix_allowed_and_rejected(value, should_pass):
    """Parametrized sweep over accepted and rejected prompt_template values."""
    if not should_pass:
        with pytest.raises(TypeError):
            OperatorKwOnly(prompt_template=value)
        return
    # Accepted values: construction must simply not raise.
    OperatorKwOnly(prompt_template=value)


# ================== 不同签名形式的被测类(prompt_template 的位置不同) ==================

@prompt_restrict(APrompt, BPrompt)
class OpFirst:
    """Decorated class whose ``prompt_template`` is the first parameter."""

    def __init__(self, prompt_template=None, x=0):
        self.prompt_template, self.x = prompt_template, x

@prompt_restrict(APrompt, BPrompt)
class OpMiddle:
    """Decorated class whose ``prompt_template`` sits between a required ``x`` and an optional ``y``."""

    def __init__(self, x, prompt_template, y=1):
        self.x, self.prompt_template, self.y = x, prompt_template, y

@prompt_restrict(APrompt, BPrompt)
class OpKWOnly:
    """Decorated class whose parameters, including ``prompt_template``, are keyword-only."""

    def __init__(self, *, prompt_template=None, x=0):
        self.prompt_template, self.x = prompt_template, x

@prompt_restrict(APrompt, BPrompt)
class OpDefaultNone:
    """Decorated class whose ``prompt_template`` is the last parameter and defaults to None."""

    def __init__(self, x=0, y=1, prompt_template=None):
        self.x = x
        self.y = y
        self.prompt_template = prompt_template

@prompt_restrict(APrompt, BPrompt)
class OpNoPromptParam:
    """Decorated class whose ``__init__`` has no ``prompt_template`` parameter."""

    def __init__(self, x, y=1):
        # No prompt_template in the signature, so there is no prompt to validate.
        self.x = x
        self.y = y


# ================== 基础通过用例(位置 & 关键字 & None) ==================

@pytest.mark.cpu
@pytest.mark.parametrize("Cls", [OpFirst, OpMiddle, OpKWOnly, OpDefaultNone])
@pytest.mark.parametrize("value", [APrompt(), BPrompt(), CustomDIY(), None])
def test_allowed_values_positional_and_keyword(Cls, value):
    """Allowed values (and None) are accepted both positionally and by keyword.

    Each class is constructed positionally when its signature allows it, and
    then by keyword. In both cases the stored ``prompt_template`` must be the
    very object that was passed in.
    """
    # Positional call shape per class. OpKWOnly is absent on purpose: its
    # parameters are keyword-only, so it has no positional form.
    positional_args = {
        OpFirst: (value, 123),
        OpMiddle: (999, value),
        OpDefaultNone: (1, 2, value),
    }
    if Cls in positional_args:
        obj = Cls(*positional_args[Cls])
        assert obj.prompt_template is value

    # OpMiddle declares a required ``x`` with no default. Without it the call
    # raises TypeError for a missing argument, which is unrelated to the
    # prompt check, so supply it explicitly.
    extra_kwargs = {"x": 999} if Cls is OpMiddle else {}
    obj2 = Cls(prompt_template=value, **extra_kwargs)
    assert obj2.prompt_template is value

# ================== 非允许类型:位置 & 关键字都应抛错 ==================

@pytest.mark.cpu
@pytest.mark.parametrize("Cls, args_builder", [
    (OpFirst,       lambda bad: (bad, 0)),       # prompt_template is first
    (OpMiddle,      lambda bad: (0, bad)),       # prompt_template is in the middle
    (OpDefaultNone, lambda bad: (0, 1, bad)),    # prompt_template is last
])
@pytest.mark.parametrize("bad_value", [OtherPrompt(), object()])
def test_disallowed_values_positional_raise(Cls, args_builder, bad_value):
    """Disallowed values passed positionally raise TypeError with the descriptive message.

    ``args_builder`` places the bad value at the position that
    ``prompt_template`` occupies in ``Cls.__init__``.
    """
    positional = args_builder(bad_value)
    with pytest.raises(TypeError) as exc_info:
        Cls(*positional)
    message = str(exc_info.value)
    assert f"[{Cls.__name__}] Invalid prompt_template type:" in message
    assert "APrompt" in message and "BPrompt" in message

@pytest.mark.cpu
@pytest.mark.parametrize("Cls", [OpFirst, OpMiddle, OpKWOnly, OpDefaultNone])
@pytest.mark.parametrize("bad_value", [OtherPrompt(), object()])
def test_disallowed_values_keyword_raise(Cls, bad_value):
    """Disallowed values passed by keyword raise the prompt-validation TypeError."""
    # OpMiddle has a required ``x``. Supply it so the TypeError cannot come
    # from a missing argument instead of from the prompt check.
    extra_kwargs = {"x": 0} if Cls is OpMiddle else {}
    with pytest.raises(TypeError) as exc_info:
        Cls(prompt_template=bad_value, **extra_kwargs)
    # Pin the reason: it must be the validation error, not an argument-binding error.
    assert "Invalid prompt_template type:" in str(exc_info.value)

# ================== 位置 + 关键字重复传参:Python 自身错误 ==================

@pytest.mark.cpu
def test_multiple_values_for_argument_python_error():
    """Supplying prompt_template both positionally and by keyword raises TypeError."""
    with pytest.raises(TypeError):
        # The same parameter is provided twice.
        OpFirst(APrompt(), prompt_template=BPrompt())

# ================== 注解与 ALLOWED_PROMPTS ==================

@pytest.mark.cpu
@pytest.mark.parametrize("Cls", [OpFirst, OpMiddle, OpKWOnly, OpDefaultNone])
def test_annotations_and_allowed_prompts(Cls):
    """Each decorated class exposes the Union annotation and the ALLOWED_PROMPTS tuple."""
    resolved = get_type_hints(Cls)
    assert "prompt_template" in resolved
    annotation = resolved["prompt_template"]
    assert get_origin(annotation) is Union
    members = set(get_args(annotation))
    for expected in (APrompt, BPrompt, DIYPromptABC, type(None)):
        assert expected in members

    assert hasattr(Cls, "ALLOWED_PROMPTS")
    assert Cls.ALLOWED_PROMPTS == (APrompt, BPrompt)

# ================== 继承类行为:应保持相同(位置也检查) ==================

class SubOpFirst(OpFirst):
    """Undecorated subclass of OpFirst that adds nothing of its own."""
class SubOpMiddle(OpMiddle):
    """Undecorated subclass of OpMiddle that adds nothing of its own."""
class SubOpKWOnly(OpKWOnly):
    """Undecorated subclass of OpKWOnly that adds nothing of its own."""

@pytest.mark.cpu
@pytest.mark.parametrize("Cls, args_builder", [
    (SubOpFirst,  lambda bad: (bad, 0)),
    (SubOpMiddle, lambda bad: (0, bad)),
])
@pytest.mark.parametrize("bad_value", [OtherPrompt(), object()])
def test_subclass_positional_still_checked(Cls, args_builder, bad_value):
    """Subclasses of decorated classes still reject disallowed positional values."""
    positional = args_builder(bad_value)
    with pytest.raises(TypeError):
        Cls(*positional)

@pytest.mark.cpu
@pytest.mark.parametrize("Cls", [SubOpFirst, SubOpMiddle, SubOpKWOnly])
def test_subclass_keyword_still_checked(Cls):
    """Subclasses of decorated classes still reject disallowed keyword values."""
    # SubOpMiddle inherits a required ``x``. Supply it so the TypeError cannot
    # come from a missing argument instead of from the prompt check.
    extra_kwargs = {"x": 0} if issubclass(Cls, OpMiddle) else {}
    with pytest.raises(TypeError) as exc_info:
        Cls(prompt_template=OtherPrompt(), **extra_kwargs)
    # Only the generic fragment is asserted: which class name appears in the
    # brackets for a subclass is not established by this file.
    assert "Invalid prompt_template type:" in str(exc_info.value)

# ================== 边界:没有 prompt_template 参数的类 —— 不检查 ==================

@pytest.mark.cpu
def test_class_without_prompt_param_is_not_checked():
    """A decorated class without a prompt_template parameter constructs normally."""
    # The signature has no prompt_template, so ordinary arguments must pass
    # straight through to __init__.
    obj = OpNoPromptParam(1, y=2)
    assert obj.x == 1 and obj.y == 2