test_op.py 18.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib.util
import logging
from pathlib import Path
from typing import Any

import pytest
import torch
from torch import fx
from torch.fx.experimental.proxy_tensor import make_fx

import vllm.ir.op
from vllm.ir.op import RESERVED_PROVIDERS, IrOp, IrOpImpl

# This should not exist
assert "_custom_add" not in IrOp.registry


class CustomError(Exception):
    pass


@vllm.ir.register_op
def _custom_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y


def test_registration_overloads():
    assert all(
        n not in IrOp.registry for n in ["_custom_sub", "_custom_mul", "_custom_div"]
    )

    # Calling with decorator
    @vllm.ir.register_op()
    def _custom_sub(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x - y

    assert _custom_sub.name == "_custom_sub"
    assert _custom_sub is IrOp.registry["_custom_sub"]

    # Custom name
    @vllm.ir.register_op(name="_custom_mul")
    def custom_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x * y

    assert custom_mul.name == "_custom_mul"
    assert custom_mul is IrOp.registry["_custom_mul"]

    # Direct construction does not register directly
    def _custom_div(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x / y

    custom_div = IrOp("_custom_div", _custom_div)
    assert custom_div.name == "_custom_div"
    assert "_custom_div" not in IrOp.registry

    # Duplicate op registration not allowed
    with pytest.raises(AssertionError):

        @vllm.ir.register_op
        def _custom_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x * y - 100


def test_no_kw_only_args():
    # kw-only args not supported
    with pytest.raises(ValueError, match="keyword-only arguments"):

        @vllm.ir.register_op
        def _custom_kwarg_op(
            x: torch.Tensor, y: torch.Tensor, *, kwarg: int = 0
        ) -> torch.Tensor:
            return x + y + kwarg

    assert "_custom_kwarg_op" not in IrOp.registry


class TestIrOpCustomAdd:
    # Registration invariants
    def test_decorated_object(self):
        """Make sure that referring directly to an op is correct"""
        assert isinstance(_custom_add, IrOp)
        assert "_custom_add" in IrOp.registry
        assert _custom_add is IrOp.registry["_custom_add"]

    def test_torch_op_is_registered(self):
        assert hasattr(torch.ops.vllm_ir, "_custom_add")
        assert callable(torch.ops.vllm_ir._custom_add.default)

    # Semantic correctness
    def test_semantics_match_native(self):
        x = torch.randn(4, 5)
        y = torch.randn(4, 5)

        # Calls native by default
        out = _custom_add(x, y)
        ref = x + y

        torch.testing.assert_close(out, ref)

    # -------------------------
    # Implementation registration
    # -------------------------

    def test_register_impl_is_non_intrusive(self):
        @_custom_add.register_impl("dummy_provider")
        def dummy_impl(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x + y + 123

        assert "dummy_provider" in _custom_add.impls
        assert isinstance(_custom_add.impls["dummy_provider"], IrOpImpl)

        x = torch.ones(2, 2)
        y = torch.ones(2, 2)

        # Native semantics must still hold
        torch.testing.assert_close(_custom_add(x, y), x + y)

    def test_schema_contains_tensor_signature(self):
        schema = _custom_add._schema_str

        assert "Tensor" in schema
        assert "-> Tensor" in schema

    # -------------------------
    # FX visibility
    # -------------------------

    @pytest.mark.parametrize("enable_torch_wrap", [True, False])
    @pytest.mark.parametrize("symbolic_trace", [True, False])
    def test_trace_sees_single_custom_op(
        self, symbolic_trace: bool, enable_torch_wrap: bool
    ):
        def fn(x, y):
            return _custom_add(x, y)

        def find_fn(target: Any, gm: fx.GraphModule):
            return gm.graph.find_nodes(op="call_function", target=target)

        with pytest.raises(CustomError), vllm.ir.enable_torch_wrap(enable_torch_wrap):
            if symbolic_trace:
                gm = torch.fx.symbolic_trace(fn)
            else:
                gm = make_fx(fn)(torch.randn(2, 2), torch.randn(2, 2))

            x1, y1 = torch.rand(5, 4), torch.rand(5, 4)
            out_fx = gm(x1, y1)
            out_eager = fn(x1, y1)

            # raise error to check enable_torch_wrap context restored correctly
            raise CustomError

        # check behavior matches eager in all cases
        torch.testing.assert_close(out_fx, out_eager)

        # check that IR nodes only appear if enable_torch_wrap=True
        ir_nodes = find_fn(torch.ops.vllm_ir._custom_add.default, gm)
        if enable_torch_wrap:
            assert len(ir_nodes) == 1, gm.code
        else:
            assert len(ir_nodes) == 0, gm.code

        # with torch wrapping enabled (default), IR nodes appear
        if symbolic_trace:
            gm = torch.fx.symbolic_trace(fn)
        else:
            gm = make_fx(fn)(torch.randn(2, 2), torch.randn(2, 2))

        ir_nodes = find_fn(torch.ops.vllm_ir._custom_add.default, gm)
        assert len(ir_nodes) == 1, gm.code


@_custom_add.register_impl("impl_a")
def impl_a(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y + 10


@_custom_add.register_impl("impl_b")
def impl_b(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y + 20


@_custom_add.register_impl("impl_even", supports_args=lambda x, y: x.size(1) % 2 == 0)
def impl_even(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y + 50


class TestIrOpImplDispatch:
    def test_register_impl(self):
        assert "impl_a" in _custom_add.impls
        impl = _custom_add.impls["impl_a"]

        assert impl is impl_a
        assert impl.op is _custom_add
        assert impl.provider == "impl_a"
        assert callable(impl.impl_fn)

        # Test duplicate registration rejected
        with pytest.raises(AssertionError):

            @_custom_add.register_impl("impl_a")
            def impl_a_dup(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x + y + 30

        # Check the original impl is still intact
        assert _custom_add.impls["impl_a"] is impl_a

        # Check support all args
        assert impl_a.supports_all_args
        assert impl_b.supports_all_args
        assert not impl_even.supports_all_args

    def test_reserved_provider_rejected(self):
        for provider in RESERVED_PROVIDERS:
            with pytest.raises(AssertionError):

                @_custom_add.register_impl(provider)
                def bad_impl(x, y):
                    return x + y

    def test_set_priority_scoped(self):
        assert _custom_add.get_priority() == []

        with _custom_add.set_priority(["impl_even", "impl_b"]):
            assert _custom_add.get_priority() == ["impl_even", "impl_b"]

            # Check nesting
            with _custom_add.set_priority(["impl_b"]):
                assert _custom_add.get_priority() == ["impl_b"]

            # Restored
            assert _custom_add.get_priority() == ["impl_even", "impl_b"]

            # Check that exception restores priority
            with pytest.raises(CustomError), _custom_add.set_priority(["impl_a"]):
                assert _custom_add.get_priority() == ["impl_a"]
                raise CustomError

            # Restored again
            assert _custom_add.get_priority() == ["impl_even", "impl_b"]

        # Restored to empty
        assert _custom_add.get_priority() == []

    def test_dispatch_priority_order(self):
        x = torch.tensor(1, dtype=torch.int32)
        y = torch.tensor(2, dtype=torch.int32)

        with _custom_add.set_priority(["impl_b", "impl_a"]):
            assert _custom_add.dispatch(x, y) is impl_b
            out1 = _custom_add(x, y)
            out2 = torch.ops.vllm_ir._custom_add(x, y)

            with _custom_add.set_priority(["impl_a"]):
                assert _custom_add.dispatch(x, y) is impl_a
                out3 = _custom_add(x, y)
                out4 = torch.ops.vllm_ir._custom_add(x, y)

        # impl_b
        assert out1.item() == 1 + 2 + 20
        assert out2.item() == 1 + 2 + 20
        # impl_a
        assert out3.item() == 1 + 2 + 10
        assert out4.item() == 1 + 2 + 10

    def test_unsupported_impl_filtered(self):
        @_custom_add.register_impl("unsupported", supported=False)
        def impl_bad(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x + y + 999

        x = torch.tensor(1, dtype=torch.int32)
        y = torch.tensor(2, dtype=torch.int32)

        with _custom_add.set_priority(["unsupported", "impl_a"]):
            assert _custom_add.get_priority() == ["impl_a"]
            out = _custom_add(x, y)

        # impl_bad skipped → impl_a
        assert out.item() == 1 + 2 + 10

    def test_supports_args_runtime_dispatch_and_warning(
        self, caplog_vllm: pytest.LogCaptureFixture
    ):
        x1 = torch.ones((2, 2), dtype=torch.int32)
        y1 = torch.full((2, 2), 2, dtype=torch.int32)

        x2 = torch.ones((2, 3), dtype=torch.int32)
        y2 = torch.full((2, 3), 2, dtype=torch.int32)

        with (
            caplog_vllm.at_level(logging.WARNING),
            _custom_add.set_priority(["impl_even"]),
        ):
            # Test the warning about native fallback is logged (before even dispatching)
            assert len(caplog_vllm.records) == 1
            message = caplog_vllm.records[0].message
            assert "_custom_add" in message
            assert "fallback to native" in message
            assert "priority" in message

            # Check dispatching
            assert _custom_add.get_priority() == ["impl_even", "native"]
            assert _custom_add.dispatch(x1, y1) is impl_even
            assert _custom_add.dispatch(x2, y2) is _custom_add.impls["native"]

            out1 = _custom_add(x1, y1)  # size(1) == 2 → impl_even
            out2 = _custom_add(x2, y2)  # size(1) == 3 → native fallback

        # no other warnings
        assert len(caplog_vllm.records) == 1
        assert torch.all(out1 == 1 + 2 + 50)
        assert torch.all(out2 == 1 + 2)

    def test_default_priority(
        self, caplog_vllm: pytest.LogCaptureFixture, disable_log_dedup
    ):
        # Make sure logs are not deduplicated to properly test the warning
        x = torch.tensor([3], dtype=torch.int32)
        y = torch.tensor([4], dtype=torch.int32)

        # No priority set → falls back to native
        assert _custom_add.get_priority() == []
        with caplog_vllm.at_level(logging.WARNING):
            # Native by default
            assert _custom_add.dispatch(x, y) is _custom_add.impls["native"]
            out = _custom_add(x, y)

        # Check dispatching to native by default
        assert out.item() == 3 + 4

        # Check warning
        assert len(caplog_vllm.records) == 2
        message = caplog_vllm.records[0].message.lower()
        assert "_custom_add" in message
        assert "priority not set" in message


@vllm.ir.register_op
def _custom_mm(
    x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
    tmp = x @ y
    return tmp if bias is None else tmp + bias


def test_default_args():
    # Test that default args are properly applied when dispatching and calling
    @_custom_mm.register_impl("impl_mm", supports_args=lambda x, y, bias=None: True)
    def impl_mm(
        x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
    ) -> torch.Tensor:
        tmp = x @ y
        return tmp + 50 if bias is None else tmp + bias + 100

    x1 = torch.tensor([1, 2], dtype=torch.int32)
    x2 = torch.tensor([3, 4], dtype=torch.int32)

    # Test that supports_args receives the defaulted args
    assert impl_mm.supports_args(x1, x2)
    with _custom_mm.set_priority(["impl_mm", "native"]):
        assert _custom_mm.dispatch(x1, x2) is impl_mm


def test_bad_impl_registrations():
    # Check bad schema
    with pytest.raises(ValueError, match="does not match native schema"):

        @_custom_mm.register_impl("impl_mm_bad_schema")
        def impl_mm_bad_schema(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x @ y - 1

    with pytest.raises(ValueError, match="does not match native schema"):

        @_custom_mm.register_impl("impl_mm_bad_schema_2")
        def impl_mm_bad_schema_2(
            x: torch.Tensor, y: torch.Tensor, b: torch.Tensor | None = None
        ) -> torch.Tensor:
            return x @ y + b - 2

    with pytest.raises(ValueError, match="does not match native schema"):

        @_custom_mm.register_impl("impl_mm_bad_schema_3")
        def impl_mm_bad_schema_3(
            x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor
        ) -> torch.Tensor:
            return x @ y + bias - 5

    # check supports_args with incorrect params
    with pytest.raises(ValueError, match="supports_args must be a callable"):

        @_custom_mm.register_impl("impl_mm_bad_supports_args", supports_args=True)
        def impl_mm_bad_supports_args(
            x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
        ) -> torch.Tensor:
            return x @ y + 10

    with pytest.raises(ValueError, match="number of parameters"):

        @_custom_mm.register_impl(
            "impl_mm_bad_supports_args_2", supports_args=lambda x, y: True
        )
        def impl_mm_bad_supports_args(
            x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
        ) -> torch.Tensor:
            return x @ y + 10

    with pytest.raises(ValueError, match="keyword-only parameters"):

        @_custom_mm.register_impl(
            "impl_mm_bad_supports_args_3", supports_args=lambda x, y, *, b: True
        )
        def impl_mm_bad_supports_args_2(
            x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
        ) -> torch.Tensor:
            return x @ y + 20

    with pytest.raises(ValueError, match="does not match native parameter"):

        @_custom_mm.register_impl(
            "impl_mm_bad_supports_args_4", supports_args=lambda x, y, b: True
        )
        def impl_mm_bad_supports_args_4(
            x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
        ) -> torch.Tensor:
            return x @ y + 30

    with pytest.raises(ValueError, match="does not match native default"):

        @_custom_mm.register_impl(
            "impl_mm_bad_supports_args_5", supports_args=lambda x, y, bias=1: True
        )
        def impl_mm_bad_supports_args_5(
            x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
        ) -> torch.Tensor:
            return x @ y + 40

    assert set(_custom_mm.impls.keys()) == {"impl_mm", "native"}


IMPL_OOT_SRC = """
import torch

@_custom_mm.register_impl("impl_mm_oot")
def impl_mm_oot(
    x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
    return x @ y - 99
"""


def load_custom_mm_module(file_path: Path):
    spec = importlib.util.spec_from_file_location("_custom_mm_oot", file_path)
    assert spec is not None
    module = importlib.util.module_from_spec(spec)

    # Inject the variable into the module's global namespace
    # This allows the @_custom_mm.register_impl decorator to work
    module._custom_mm = _custom_mm  # type: ignore[attr-defined]

    # Execute the file; this triggers the decorator
    assert spec.loader is not None
    spec.loader.exec_module(module)
    return module


def test_uuid_and_oot(tmp_path: Path):
    file_path = tmp_path / "_custom_mm_oot.py"
    file_path.write_text(IMPL_OOT_SRC)

    assert "impl_mm_oot" not in _custom_mm.impls
    _ = load_custom_mm_module(file_path)
    assert "impl_mm_oot" in _custom_mm.impls

    uuid = _custom_mm.impls["impl_mm_oot"].uuid()
    del _custom_mm.impls["impl_mm_oot"]

    # Replace file source
    file_path.write_text(IMPL_OOT_SRC + " # added file source")
    assert "impl_mm_oot" not in _custom_mm.impls
    _ = load_custom_mm_module(file_path)
    assert "impl_mm_oot" in _custom_mm.impls

    uuid1 = _custom_mm.impls["impl_mm_oot"].uuid()
    assert uuid1 != uuid
    del _custom_mm.impls["impl_mm_oot"]

    # Back to original
    file_path.write_text(IMPL_OOT_SRC)
    assert "impl_mm_oot" not in _custom_mm.impls
    _ = load_custom_mm_module(file_path)
    assert "impl_mm_oot" in _custom_mm.impls

    uuid2 = _custom_mm.impls["impl_mm_oot"].uuid()
    assert uuid2 == uuid
    assert uuid2 != uuid1
    del _custom_mm.impls["impl_mm_oot"]
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562


def _test_native(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y


def _make_op_with_generator(name: str = "_ig_test"):
    op = IrOp(name, _test_native)

    @op.register_input_generator
    def _gen(n: int = 4):
        x = torch.randn(n, 3)
        y = torch.randn(n, 3)
        return x, y

    return op


def _test_native_single(x: torch.Tensor) -> torch.Tensor:
    return x


class TestInputGenerator:
    def test_no_input_generator_by_default(self):
        op = IrOp("_ig_test_no_gen", _test_native_single)
        assert not op.has_input_generator

    def test_register_input_generator(self):
        op = _make_op_with_generator("_ig_test_reg")
        assert op.has_input_generator

    def test_generate_inputs_returns_tuple(self):
        op = _make_op_with_generator("_ig_test_tuple")
        result = op.generate_inputs(n=2)
        assert isinstance(result, tuple)
        assert len(result) == 2
        assert result[0].shape == (2, 3)
        assert result[1].shape == (2, 3)

    def test_generate_inputs_default_kwargs(self):
        op = _make_op_with_generator("_ig_test_default")
        result = op.generate_inputs()
        assert result[0].shape == (4, 3)

    def test_generate_inputs_without_registration_raises(self):
        op = IrOp("_ig_test_no_gen_raises", _test_native_single)
        with pytest.raises(RuntimeError, match="No input generator"):
            op.generate_inputs()


class TestTolerance:
    def test_override_and_get_tolerance(self):
        op = IrOp("_tol_test", _test_native)

        tol = op.get_tolerance(torch.float32)
        assert tol == {"atol": 1e-5, "rtol": 1.3e-6}

        op.override_tolerance(torch.float32, atol=0.1, rtol=0.2)
        assert op.get_tolerance(torch.float32) == {"atol": 0.1, "rtol": 0.2}
        assert op.get_tolerance(torch.float16) == {"atol": 1e-3, "rtol": 1e-3}

    def test_get_tolerance_raises_for_unknown_dtype(self):
        op = IrOp("_tol_test_unknown", _test_native)
        with pytest.raises(ValueError, match="No tolerance defined"):
            op.get_tolerance(torch.complex64)