# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit tests for Helion kernel registration.

Tests ConfiguredHelionKernel, HelionKernelWrapper, and PresetConfigSearch,
including config picker registration and custom autotuner integration.
"""

from unittest.mock import Mock, patch

import pytest
import torch

from vllm.utils.import_utils import has_helion

# Skip the whole module before importing anything that needs Helion.
if not has_helion():
    pytest.skip(
        "Helion is not installed. Install with: pip install vllm[helion]",
        allow_module_level=True,
    )

import helion
import helion.language as hl

from tests.kernels.helion.helpers import dummy_kernel_registry
from vllm.kernels.helion.config_manager import ConfigManager
from vllm.kernels.helion.register import (
    _HOP_AVAILABLE,
    ConfiguredHelionKernel,
    HelionKernelWrapper,
    get_kernel_by_name,
    get_registered_kernels,
    register_kernel,
    validate_helion_settings,
)

# These names are only importable when the higher-order-op (HOP) integration
# is available; tests relying on them are presumably gated on _HOP_AVAILABLE.
if _HOP_AVAILABLE:
    from helion._compat import supports_torch_compile_fusion
    from helion._compiler._dynamo.higher_order_ops import (
        helion_kernel_wrapper_mutation,
    )
    from torch._inductor.utils import run_and_get_code


def _add_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Minimal Helion kernel shared by the registration tests; the eager-init
    # tests compare its result against ``x + y``.
    out = torch.empty_like(x)
    # Iterate over tiles covering x's shape and add the matching tiles.
    for tile in hl.tile(x.size()):
        out[tile] = x[tile] + y[tile]
    return out


@pytest.fixture
def sample_configs():
    """Return a mapping of config key -> real ``helion.Config`` for tests.

    Three keys follow the ``hiddensize_<H>_batchsize_<B>`` naming scheme and
    one ``"default"`` entry acts as the fallback.
    """
    # (config key, block size, num_warps, num_stages)
    table = [
        ("hiddensize_4096_batchsize_32", 128, 4, 3),
        ("hiddensize_4096_batchsize_64", 256, 8, 4),
        ("hiddensize_4096_batchsize_128", 512, 16, 2),
        ("default", 64, 2, 2),
    ]
    return {
        key: helion.Config(block_sizes=[block], num_warps=warps, num_stages=stages)
        for key, block, warps, stages in table
    }


@pytest.fixture
def sample_kernel():
    """Provide a trivial two-input kernel function to use as ``raw_kernel_func``."""

    def test_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return the elementwise sum of ``x`` and ``y``."""
        total = x + y
        return total

    kernel_fn = test_kernel
    return kernel_fn


@pytest.fixture
def config_manager_with_test_configs(sample_configs):
    """Return a mocked ConfigManager serving ``sample_configs``.

    ``get_platform_configs`` returns the same mapping whatever platform it is
    asked for; callers pin the GPU name (e.g. ``nvidia_h200``) separately.
    """
    platform_configs = Mock(return_value=sample_configs)
    return Mock(spec=ConfigManager, get_platform_configs=platform_configs)


@pytest.fixture
def configured_kernel(sample_kernel, sample_configs, config_manager_with_test_configs):
    """Create a ConfiguredHelionKernel whose picker always selects "default".

    While the kernel is constructed:
    - ConfigManager is patched to return ``config_manager_with_test_configs``;
    - the canonical GPU name is pinned to ``nvidia_h200``;
    - ``helion.kernel`` is mocked so no real kernel is compiled.
    """

    def test_config_picker(args, config_keys):
        """Simple config picker that returns default."""
        return "default"

    with (
        patch(
            "vllm.kernels.helion.config_manager.ConfigManager",
            return_value=config_manager_with_test_configs,
        ),
        patch(
            "vllm.kernels.helion.utils.get_canonical_gpu_name",
            return_value="nvidia_h200",
        ),
        patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel,
    ):
        # Mock just the helion.kernel decorator to avoid actual kernel
        # compilation: helion.kernel(...) returns a decorator that yields
        # mock_decorated.
        mock_decorated = Mock()
        mock_kernel.return_value = Mock(return_value=mock_decorated)

        return ConfiguredHelionKernel(
            op_name="test_kernel",
            config_picker=test_config_picker,
            raw_kernel_func=sample_kernel,
            helion_settings=None,
        )


class TestValidateHelionSettings:
    """Test suite for validate_helion_settings utility function."""

    def test_accepts_none_settings(self):
        """Test that None settings are accepted without error."""
        validate_helion_settings(None, "test_kernel")  # Should not raise

    def test_accepts_valid_settings(self):
        """Test that valid settings without conflicts are accepted."""
        settings = helion.Settings()
        settings.static_shapes = False
        settings.print_output_code = True
        validate_helion_settings(settings, "test_kernel")  # Should not raise

    def test_rejects_autotuner_fn(self):
        """Test that settings with custom autotuner_fn raise ValueError."""
        settings = helion.Settings()
        settings.autotuner_fn = lambda *args: None  # Set custom autotuner function

        with pytest.raises(ValueError, match="uses a custom autotuner"):
            validate_helion_settings(settings, "test_kernel")

    def test_warns_on_static_shapes_true(self):
        """Test that static_shapes=True emits a warning about being overridden."""
        settings = helion.Settings()
        settings.static_shapes = True

        with patch("vllm.kernels.helion.register.logger") as mock_logger:
            validate_helion_settings(settings, "test_kernel")
            mock_logger.warning.assert_called_once()
            # The first positional argument of logger.warning is the message.
            assert "overridden to False" in mock_logger.warning.call_args.args[0]


def create_configured_kernel_with_configs(
    op_name,
    config_picker,
    kernel_func,
    configs,
    platform="nvidia_h200",
    helion_settings=None,
):
    """Helper to create ConfiguredHelionKernel with real config objects.

    ConfigManager is patched so ``get_platform_configs`` returns ``configs``,
    the canonical GPU name is pinned to ``platform``, and ``helion.kernel`` is
    mocked so no real kernel is compiled. The patches are active only while
    the kernel is being constructed.

    Args:
        op_name: Operation name passed to ConfiguredHelionKernel.
        config_picker: Callable ``(args, config_keys) -> key or None``.
        kernel_func: Raw (undecorated) kernel function.
        configs: Mapping of config key -> helion.Config served by the mock.
        platform: GPU name reported while constructing the kernel.
        helion_settings: Optional helion.Settings forwarded to the kernel.

    Returns:
        The constructed ConfiguredHelionKernel.
    """
    mock_config_manager = Mock(spec=ConfigManager)
    mock_config_manager.get_platform_configs = Mock(return_value=configs)

    with (
        patch(
            "vllm.kernels.helion.config_manager.ConfigManager",
            return_value=mock_config_manager,
        ),
        patch(
            "vllm.kernels.helion.utils.get_canonical_gpu_name",
            return_value=platform,
        ),
        patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel,
    ):
        # helion.kernel(...) returns a decorator that yields mock_decorated.
        mock_decorated = Mock()
        mock_kernel.return_value = Mock(return_value=mock_decorated)

        return ConfiguredHelionKernel(
            op_name=op_name,
            config_picker=config_picker,
            raw_kernel_func=kernel_func,
            helion_settings=helion_settings,
        )


class TestConfiguredHelionKernel:
    """Test suite for ConfiguredHelionKernel."""

    def test_init_raises_without_picker(self, sample_kernel, sample_configs):
        """Test that __init__ raises when no picker registered."""
        configs = {"default": sample_configs["default"]}
        mock_config_manager = Mock(spec=ConfigManager)
        mock_config_manager.get_platform_configs = Mock(return_value=configs)

        with (
            patch(
                "vllm.kernels.helion.config_manager.ConfigManager",
                return_value=mock_config_manager,
            ),
            patch(
                "vllm.kernels.helion.utils.get_canonical_gpu_name",
                return_value="nvidia_h200",
            ),
            pytest.raises(RuntimeError, match="No config picker registered"),
        ):
            ConfiguredHelionKernel(
                op_name="test_kernel",
                config_picker=None,  # No picker registered
                raw_kernel_func=sample_kernel,
                helion_settings=None,
            )

    def test_config_selector_validates_picker_result(
        self, sample_kernel, sample_configs
    ):
        """Test that config selector validates picker returns valid key."""

        def invalid_picker(args, config_keys):
            return "invalid_key"

        kernel = create_configured_kernel_with_configs(
            op_name="test_kernel",
            config_picker=invalid_picker,
            kernel_func=sample_kernel,
            configs=sample_configs,
        )

        key_computer = kernel._create_key_computer()
        selector = kernel._create_config_selector(key_computer)

        with pytest.raises(
            ValueError, match="Config picker returned invalid config key"
        ):
            selector((torch.randn(32, 4096),))

    def test_config_selector_handles_none_from_picker(
        self, sample_kernel, sample_configs
    ):
        """Test that config selector falls back to 'default' on None."""

        def none_picker(args, config_keys):
            return None

        kernel = create_configured_kernel_with_configs(
            op_name="test_kernel",
            config_picker=none_picker,
            kernel_func=sample_kernel,
            configs=sample_configs,
        )

        key_computer = kernel._create_key_computer()
        selector = kernel._create_config_selector(key_computer)

        result = selector((torch.randn(32, 4096),))
        assert result is kernel.configs["default"]

    def test_create_decorated_kernel_passes_helion_settings(
        self, sample_kernel, sample_configs
    ):
        """Test that _create_decorated_kernel passes helion_settings."""

        def default_picker(args, config_keys):
            return "default"

        settings = helion.Settings()
        settings.print_output_code = True

        mock_config_manager = Mock(spec=ConfigManager)
        mock_config_manager.get_platform_configs = Mock(return_value=sample_configs)

        with (
            patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel,
            patch(
                "vllm.kernels.helion.config_manager.ConfigManager",
                return_value=mock_config_manager,
            ),
            patch(
                "vllm.kernels.helion.utils.get_canonical_gpu_name",
                return_value="nvidia_h200",
            ),
        ):
            mock_decorated = Mock()
            mock_kernel.return_value = Mock(return_value=mock_decorated)

            ConfiguredHelionKernel(
                op_name="test_kernel",
                config_picker=default_picker,
                raw_kernel_func=sample_kernel,
                helion_settings=settings,
            )

            # Settings are forwarded to helion.kernel as keyword arguments.
            call_kwargs = mock_kernel.call_args.kwargs
            assert "print_output_code" in call_kwargs
            assert call_kwargs["print_output_code"] is True
            # static_shapes is always forced to False by vLLM
            assert call_kwargs["static_shapes"] is False

    def test_key_and_config_selector_use_same_logic(
        self, sample_kernel, sample_configs
    ):
        """Test that key and config_selector produce identical results."""

        def tracking_picker(args, config_keys):
            x = args[0]
            batch_size = x.shape[0]
            if batch_size <= 32:
                return "hiddensize_4096_batchsize_32"
            elif batch_size <= 64:
                return "hiddensize_4096_batchsize_64"
            return "hiddensize_4096_batchsize_128"

        mock_config_manager = Mock(spec=ConfigManager)
        mock_config_manager.get_platform_configs = Mock(return_value=sample_configs)

        with (
            patch("vllm.kernels.helion.register.helion.kernel") as mock_helion_kernel,
            patch(
                "vllm.kernels.helion.config_manager.ConfigManager",
                return_value=mock_config_manager,
            ),
            patch(
                "vllm.kernels.helion.utils.get_canonical_gpu_name",
                return_value="nvidia_h200",
            ),
        ):
            mock_decorated = Mock()
            mock_helion_kernel.return_value = Mock(return_value=mock_decorated)

            kernel = ConfiguredHelionKernel(
                op_name="test_kernel",
                config_picker=tracking_picker,
                raw_kernel_func=sample_kernel,
                helion_settings=None,
            )

            call_kwargs = mock_helion_kernel.call_args.kwargs
            key_fn = call_kwargs["key"]
            autotuner_fn = call_kwargs["autotuner_fn"]

            tensor = torch.randn(50, 4096)  # batch=50, should select batchsize_64

            # key receives unpacked args, autotuner receives args as tuple
            key_result = key_fn(tensor)
            autotuner = autotuner_fn(None, (tensor,))
            config = autotuner.autotune()

            assert key_result == "hiddensize_4096_batchsize_64"
            assert config is kernel.configs["hiddensize_4096_batchsize_64"]


class TestHelionKernelWrapper:
    """Test suite for HelionKernelWrapper.

    Most tests construct the wrapper with ConfigManager,
    get_canonical_gpu_name and helion.kernel patched, so no real GPU
    detection, config loading or kernel compilation takes place.
    """

    def test_init_disables_on_missing_configs(self, sample_kernel):
        """Test __init__ marks wrapper as disabled when configs are missing."""

        def fake_impl(*args, **kwargs):
            return torch.zeros_like(args[0])

        def default_picker(args, config_keys):
            return "default"

        # Empty configs: get_platform_configs returns {} for any platform.
        mock_config_manager = Mock(spec=ConfigManager)
        mock_config_manager.get_platform_configs = Mock(return_value={})

        with (
            patch(
                "vllm.kernels.helion.config_manager.ConfigManager",
                return_value=mock_config_manager,
            ),
            patch(
                "vllm.kernels.helion.utils.get_canonical_gpu_name",
                return_value="nvidia_h200",
            ),
            patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel,
        ):
            mock_kernel.return_value = Mock(return_value=sample_kernel)

            wrapper = HelionKernelWrapper(
                raw_kernel_func=sample_kernel,
                op_name="test_kernel",
                fake_impl=fake_impl,
                config_picker=default_picker,
            )

            assert wrapper._disabled is True
            assert "No configs available" in wrapper._disabled_reason

    def test_disabled_wrapper_raises_on_call(self, sample_kernel):
        """Test __call__ raises RuntimeError on a disabled wrapper."""

        def fake_impl(*args, **kwargs):
            return torch.zeros_like(args[0])

        def default_picker(args, config_keys):
            return "default"

        mock_config_manager = Mock(spec=ConfigManager)
        mock_config_manager.get_platform_configs = Mock(return_value={})

        with (
            patch(
                "vllm.kernels.helion.config_manager.ConfigManager",
                return_value=mock_config_manager,
            ),
            patch(
                "vllm.kernels.helion.utils.get_canonical_gpu_name",
                return_value="nvidia_h200",
            ),
            patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel,
        ):
            mock_kernel.return_value = Mock(return_value=sample_kernel)

            wrapper = HelionKernelWrapper(
                raw_kernel_func=sample_kernel,
                op_name="test_kernel",
                fake_impl=fake_impl,
                config_picker=default_picker,
            )

        with pytest.raises(RuntimeError, match="is disabled"):
            wrapper(torch.randn(4, 4), torch.randn(4, 4))

    def test_disabled_wrapper_get_configured_op_raises(self, sample_kernel):
        """Test get_configured_op raises RuntimeError on a disabled wrapper."""

        def fake_impl(*args, **kwargs):
            return torch.zeros_like(args[0])

        def default_picker(args, config_keys):
            return "default"

        mock_config_manager = Mock(spec=ConfigManager)
        mock_config_manager.get_platform_configs = Mock(return_value={})

        with (
            patch(
                "vllm.kernels.helion.config_manager.ConfigManager",
                return_value=mock_config_manager,
            ),
            patch(
                "vllm.kernels.helion.utils.get_canonical_gpu_name",
                return_value="nvidia_h200",
            ),
            patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel,
        ):
            mock_kernel.return_value = Mock(return_value=sample_kernel)

            wrapper = HelionKernelWrapper(
                raw_kernel_func=sample_kernel,
                op_name="test_kernel",
                fake_impl=fake_impl,
                config_picker=default_picker,
            )

        with pytest.raises(RuntimeError, match="is disabled"):
            wrapper.get_configured_op()

    def test_disabled_wrapper_supports_get_inputs(self, sample_kernel):
        """Test get_inputs works on a disabled wrapper."""

        def fake_impl(*args, **kwargs):
            return torch.zeros_like(args[0])

        def default_picker(args, config_keys):
            return "default"

        expected_inputs = {"key1": (torch.randn(4),)}
        input_gen = Mock(return_value=expected_inputs)

        mock_config_manager = Mock(spec=ConfigManager)
        mock_config_manager.get_platform_configs = Mock(return_value={})

        with (
            patch(
                "vllm.kernels.helion.config_manager.ConfigManager",
                return_value=mock_config_manager,
            ),
            patch(
                "vllm.kernels.helion.utils.get_canonical_gpu_name",
                return_value="nvidia_h200",
            ),
            patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel,
        ):
            mock_kernel.return_value = Mock(return_value=sample_kernel)

            wrapper = HelionKernelWrapper(
                raw_kernel_func=sample_kernel,
                op_name="test_kernel",
                fake_impl=fake_impl,
                config_picker=default_picker,
                input_generator=input_gen,
            )

        assert wrapper._disabled is True
        result = wrapper.get_inputs()
        assert result is expected_inputs

    def test_disabled_wrapper_supports_run_autotune(self, sample_kernel):
        """Test run_autotune works on a disabled wrapper."""

        def fake_impl(*args, **kwargs):
            return torch.zeros_like(args[0])

        def default_picker(args, config_keys):
            return "default"

        mock_config_manager = Mock(spec=ConfigManager)
        mock_config_manager.get_platform_configs = Mock(return_value={})

        mock_config = Mock()

        with (
            patch(
                "vllm.kernels.helion.config_manager.ConfigManager",
                return_value=mock_config_manager,
            ),
            patch(
                "vllm.kernels.helion.utils.get_canonical_gpu_name",
                return_value="nvidia_h200",
            ),
            patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel,
        ):
            mock_kernel.return_value = Mock(return_value=sample_kernel)

            wrapper = HelionKernelWrapper(
                raw_kernel_func=sample_kernel,
                op_name="test_kernel",
                fake_impl=fake_impl,
                config_picker=default_picker,
            )

        assert wrapper._disabled is True

        # run_autotune builds its own decorated kernel; stub that out and
        # check the autotune result is passed straight back.
        with patch(
            "vllm.kernels.helion.register.create_helion_decorated_kernel"
        ) as mock_create:
            mock_autotune_kernel = Mock()
            mock_autotune_kernel.autotune.return_value = mock_config
            mock_create.return_value = mock_autotune_kernel

            inputs = (torch.randn(4, 4),)
            result = wrapper.run_autotune(inputs)
            assert result is mock_config

    def test_init_caches_configured_kernel(self, sample_kernel, sample_configs):
        """Test __init__ eagerly builds and caches ConfiguredHelionKernel."""

        def fake_impl(*args, **kwargs):
            return torch.zeros_like(args[0])

        def default_picker(args, config_keys):
            return "default"

        mock_config_manager = Mock(spec=ConfigManager)
        mock_config_manager.get_platform_configs = Mock(return_value=sample_configs)

        with (
            patch(
                "vllm.kernels.helion.config_manager.ConfigManager",
                return_value=mock_config_manager,
            ),
            patch(
                "vllm.kernels.helion.utils.get_canonical_gpu_name",
                return_value="nvidia_h200",
            ),
            patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel,
        ):
            mock_kernel.return_value = Mock(return_value=sample_kernel)

            wrapper = HelionKernelWrapper(
                raw_kernel_func=sample_kernel,
                op_name="test_kernel",
                fake_impl=fake_impl,
                config_picker=default_picker,
            )

            assert wrapper._configured_kernel is not None

            # Repeated lookups must return the same cached object.
            result1 = wrapper.get_configured_op()
            result2 = wrapper.get_configured_op()
            assert result1 is result2

    @pytest.mark.skipif(
        not _HOP_AVAILABLE, reason="HOP path only used when HOP available"
    )
    def test_init_eagerly_initializes_hop_path(self):
        """Test that register_kernel eagerly builds the configured kernel
        on the HOP path (no custom op registration needed)."""
        from vllm.kernels.helion.utils import get_canonical_gpu_name

        configs = {"default": helion.Config(block_sizes=[4, 4])}
        with (
            dummy_kernel_registry(configs=configs) as register,
            patch(
                "vllm.kernels.helion.utils.get_canonical_gpu_name",
                wraps=get_canonical_gpu_name,
            ) as mock_gpu,
        ):
            wrapper = register(
                config_picker=lambda args, keys: "default",
            )(_add_kernel)

            mock_gpu.assert_called_once()
            assert wrapper._configured_kernel is not None

        # After construction, __call__ must not detect the GPU again.
        with patch(
            "vllm.kernels.helion.utils.get_canonical_gpu_name",
            side_effect=AssertionError("get_canonical_gpu_name called during __call__"),
        ):
            x = torch.randn(4, 4, device="cuda")
            y = torch.randn(4, 4, device="cuda")
            result = wrapper(x, y)
            expected = x + y
            assert torch.allclose(result, expected)

    @pytest.mark.skipif(
        _HOP_AVAILABLE, reason="CustomOp path not used when HOP available"
    )
    def test_init_eagerly_initializes(self):
        """Test that register_kernel eagerly loads configs and detects GPU
        during construction so __call__ needs no further initialization."""
        from vllm.kernels.helion.utils import get_canonical_gpu_name

        with (
            dummy_kernel_registry() as register,
            patch(
                "vllm.kernels.helion.utils.get_canonical_gpu_name",
                wraps=get_canonical_gpu_name,
            ) as mock_gpu,
        ):
            wrapper = register(
                config_picker=lambda args, keys: "default",
            )(_add_kernel)

            # Init must have detected GPU and built the kernel
            mock_gpu.assert_called_once()
            assert wrapper._configured_kernel is not None
            assert hasattr(torch.ops.vllm_helion, wrapper.op_name)

    @pytest.mark.skipif(
        _HOP_AVAILABLE, reason="CustomOp path not used when HOP available"
    )
    def test_get_or_register_custom_op_returns_cached_op(
        self, sample_kernel, sample_configs
    ):
        """Test _get_or_register_custom_op returns the op that already exists
        in torch.ops.vllm_helion."""

        def fake_impl(*args, **kwargs):
            return torch.zeros_like(args[0])

        def default_picker(args, config_keys):
            return "default"

        mock_config_manager = Mock(spec=ConfigManager)
        mock_config_manager.get_platform_configs = Mock(return_value=sample_configs)

        # Namespace already exposes an op named after the wrapper.
        existing_op = Mock()
        mock_namespace = Mock()
        mock_namespace.test_kernel = existing_op

        with (
            patch(
                "vllm.kernels.helion.config_manager.ConfigManager",
                return_value=mock_config_manager,
            ),
            patch(
                "vllm.kernels.helion.utils.get_canonical_gpu_name",
                return_value="nvidia_h200",
            ),
            patch.object(torch.ops, "vllm_helion", mock_namespace),
            patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel,
        ):
            mock_decorated = Mock()
            mock_kernel.return_value = Mock(return_value=mock_decorated)

            wrapper = HelionKernelWrapper(
                raw_kernel_func=sample_kernel,
                op_name="test_kernel",
                fake_impl=fake_impl,
                config_picker=default_picker,
            )
            result = wrapper._get_or_register_custom_op()
            assert result is existing_op

    @pytest.mark.skipif(
        _HOP_AVAILABLE, reason="CustomOp path not used when HOP available"
    )
    def test_get_or_register_custom_op_registers_new_op(
        self, sample_kernel, sample_configs
    ):
        """Test _get_or_register_custom_op registers the decorated kernel via
        direct_register_custom_op exactly once and returns the new op."""

        def fake_impl(*args, **kwargs):
            return torch.zeros_like(args[0])

        def default_picker(args, config_keys):
            return "default"

        mock_config_manager = Mock(spec=ConfigManager)
        mock_config_manager.get_platform_configs = Mock(return_value=sample_configs)

        new_op = Mock()
        registered_ops: dict[str, Mock] = {}

        # Namespace that only exposes ops added through register_side_effect.
        class MockNamespace:
            def __getattr__(self, name):
                if name in registered_ops:
                    return registered_ops[name]
                raise AttributeError(name)

        mock_namespace = MockNamespace()

        def register_side_effect(op_name, op_func, **kwargs):
            registered_ops[op_name] = new_op

        with (
            patch(
                "vllm.kernels.helion.config_manager.ConfigManager",
                return_value=mock_config_manager,
            ),
            patch(
                "vllm.kernels.helion.utils.get_canonical_gpu_name",
                return_value="nvidia_h200",
            ),
            patch.object(torch.ops, "vllm_helion", mock_namespace),
            patch(
                "vllm.kernels.helion.register.direct_register_custom_op",
                side_effect=register_side_effect,
            ) as mock_register,
            patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel,
        ):
            mock_decorated = Mock()
            mock_kernel.return_value = Mock(return_value=mock_decorated)

            wrapper = HelionKernelWrapper(
                raw_kernel_func=sample_kernel,
                op_name="test_kernel",
                fake_impl=fake_impl,
                config_picker=default_picker,
            )
            result = wrapper._get_or_register_custom_op()

            mock_register.assert_called_once()
            assert result is new_op
            assert mock_register.call_args.kwargs["op_func"] is mock_decorated


class TestKernelRegistry:
    """Test suite for kernel registry functionality.

    Every test starts from an empty ``_REGISTERED_KERNELS`` mapping; the
    original contents are snapshotted in ``setup_method`` and restored in
    ``teardown_method`` so registrations made here never leak into other
    tests.
    """

    def setup_method(self):
        """Save and clear the registry before each test."""
        from vllm.kernels.helion.register import _REGISTERED_KERNELS

        self._saved_registry = dict(_REGISTERED_KERNELS)
        _REGISTERED_KERNELS.clear()

    def teardown_method(self):
        """Restore the registry after each test."""
        from vllm.kernels.helion.register import _REGISTERED_KERNELS

        _REGISTERED_KERNELS.clear()
        _REGISTERED_KERNELS.update(self._saved_registry)

    def test_get_registered_kernels_returns_copy(self):
        """Test get_registered_kernels returns copy of registry."""
        with dummy_kernel_registry() as register:
            # Register something so the content comparison below is
            # meaningful (two empty dicts are trivially equal).
            wrapper = register(
                "copy_test_kernel", config_picker=lambda args, keys: "default"
            )(_add_kernel)

            result1 = get_registered_kernels()
            result2 = get_registered_kernels()

            # Should be separate objects
            assert result1 is not result2
            # Should have same content
            assert result1 == result2
            assert result1["copy_test_kernel"] is wrapper

            # Mutating a returned copy must not affect the registry itself.
            result1.clear()
            assert "copy_test_kernel" in get_registered_kernels()

    def test_get_kernel_by_name_returns_kernel(self):
        """Test get_kernel_by_name returns registered kernel."""
        with dummy_kernel_registry() as register:
            wrapper = register(
                "test_kernel", config_picker=lambda args, keys: "default"
            )(_add_kernel)

            # Look the kernel up via the public API only: registration itself
            # must be what makes it discoverable by name.
            result = get_kernel_by_name("test_kernel")
            assert result is wrapper

    def test_get_kernel_by_name_returns_none_for_missing(self):
        """Test get_kernel_by_name returns None for missing kernel."""
        result = get_kernel_by_name("nonexistent")
        assert result is None

    def test_register_kernel_auto_generates_fake_impl(self):
        """Test register_kernel auto-generates fake_impl when not provided."""
        with (
            dummy_kernel_registry() as register,
            patch("vllm.kernels.helion.register.infer_fake_impl") as mock_infer,
        ):
            mock_fake = Mock()
            mock_infer.return_value = mock_fake

            wrapper = register(
                config_picker=lambda args, keys: "default",
            )(_add_kernel)

        mock_infer.assert_called_once_with(_add_kernel, None)
        assert wrapper._fake_impl is mock_fake

    def test_register_kernel_creates_wrapper(self):
        """Test register_kernel creates HelionKernelWrapper."""
        with dummy_kernel_registry() as register:
            result = register("test_name", config_picker=lambda args, keys: "default")(
                _add_kernel
            )

        assert isinstance(result, HelionKernelWrapper)
        assert result.op_name == "test_name"
        assert result.raw_kernel_func is _add_kernel

    def test_register_kernel_auto_detects_name(self):
        """Test register_kernel uses function name when no name provided."""
        with dummy_kernel_registry() as register:
            wrapper = register(config_picker=lambda args, keys: "default")(_add_kernel)

        assert wrapper.op_name == "_add_kernel"

    def test_register_kernel_registers_in_global_registry(self):
        """Test register_kernel adds wrapper to global registry."""
        with dummy_kernel_registry() as register:
            wrapper = register(
                "test_kernel", config_picker=lambda args, keys: "default"
            )(_add_kernel)

        registered_kernels = get_registered_kernels()
        assert "test_kernel" in registered_kernels
        assert registered_kernels["test_kernel"] is wrapper

    def test_register_kernel_passes_helion_settings(self):
        """Test register_kernel passes helion_settings to wrapper."""
        settings = helion.Settings()
        settings.print_output_code = True

        with dummy_kernel_registry() as register:
            result = register(
                "test_name",
                config_picker=lambda args, keys: "default",
                helion_settings=settings,
            )(_add_kernel)

        assert result.helion_settings is settings

    def test_register_kernel_supports_decorator_syntax(self):
        """Test register_kernel works with decorator arguments."""
        mock_fake = Mock()

        with dummy_kernel_registry() as register:
            result = register(
                "custom_name",
                config_picker=lambda args, keys: "default",
                fake_impl=mock_fake,
            )(_add_kernel)

        assert result.op_name == "custom_name"
        assert result._fake_impl is mock_fake

    def test_register_kernel_raises_on_duplicate_registration(self):
        """Test register_kernel raises error on duplicate names."""
        with dummy_kernel_registry() as register:
            register("duplicate_name", config_picker=lambda args, keys: "default")(
                _add_kernel
            )

            with pytest.raises(ValueError, match="already registered"):
                register("duplicate_name", config_picker=lambda args, keys: "default")(
                    _add_kernel
                )

    def test_register_kernel_rejects_autotuner_fn_in_settings(self):
        """Test register_kernel rejects conflicting autotuner_fn."""
        mock_settings = Mock()
        mock_settings.to_dict.return_value = {"autotuner_fn": Mock()}

        with pytest.raises(ValueError, match="uses a custom autotuner"):

            @register_kernel(
                "test",
                config_picker=lambda args, keys: "default",
                helion_settings=mock_settings,
            )
            def test_kernel(x):
                return x

    def test_register_kernel_no_warning_with_static_shapes_false(self):
        """Test register_kernel doesn't warn with static_shapes=False."""
        mock_settings = Mock()
        mock_settings.to_dict.return_value = {"static_shapes": False}

        with (
            dummy_kernel_registry() as register,
            patch("vllm.kernels.helion.register.logger") as mock_logger,
        ):
            register(
                "test",
                config_picker=lambda args, keys: "default",
                helion_settings=mock_settings,
            )(_add_kernel)

        mock_logger.warning.assert_not_called()

    def test_disabled_kernel_appears_in_registry(self):
        """Test that a disabled wrapper is still in the global registry."""

        def fake_impl(*args, **kwargs):
            return torch.zeros_like(args[0])

        # No platform configs -> the wrapper is expected to come up disabled.
        mock_config_manager = Mock(spec=ConfigManager)
        mock_config_manager.get_platform_configs = Mock(return_value={})

        with (
            patch(
                "vllm.kernels.helion.config_manager.ConfigManager",
                return_value=mock_config_manager,
            ),
            patch(
                "vllm.kernels.helion.utils.get_canonical_gpu_name",
                return_value="nvidia_h200",
            ),
            patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel,
        ):
            mock_kernel.return_value = Mock(return_value=_add_kernel)

            wrapper = register_kernel(
                "disabled_kernel",
                config_picker=lambda args, keys: "default",
                fake_impl=fake_impl,
            )(_add_kernel)

        assert wrapper._disabled is True
        registered = get_registered_kernels()
        assert "disabled_kernel" in registered
        assert registered["disabled_kernel"] is wrapper


@pytest.mark.skipif(not _HOP_AVAILABLE, reason="Requires PyTorch >= 2.11 for HOP")
class TestTorchCompileHOP:
    """Test that HelionKernelWrapper emits the correct HOP under torch.compile."""

    def teardown_method(self):
        """Drop Dynamo caches so graphs compiled around per-test kernels
        do not leak into subsequent tests."""
        torch._dynamo.reset()

    def test_compiled_graph_contains_helion_hop(self):
        """Verify torch.compile on a HelionKernelWrapper emits a
        helion_kernel_wrapper_mutation HOP node in the FX graph."""
        configs = {"default": helion.Config(block_sizes=[4, 4])}

        with dummy_kernel_registry(configs=configs) as register:
            add_helion_kernel = register(
                op_name="test_torch_compile_add_kernel",
                config_picker=lambda args, keys: "default",
            )(_add_kernel)

        captured_graph: torch.fx.GraphModule | None = None

        def capturing_backend(gm, example_inputs):
            # Record the single FX graph Dynamo hands us and run it unchanged.
            nonlocal captured_graph
            assert captured_graph is None, "Backend called multiple times"
            captured_graph = gm
            return gm.forward

        def f(x, y):
            return add_helion_kernel(x, y)

        torch._dynamo.reset()
        compiled_f = torch.compile(f, backend=capturing_backend, fullgraph=True)

        x = torch.randn(4, 4, device="cuda")
        y = torch.randn(4, 4, device="cuda")

        # Run compiled version and capture graph
        compiled_result = compiled_f(x, y)

        assert captured_graph is not None
        hop_nodes = [
            node
            for node in captured_graph.graph.nodes
            if node.op == "call_function"
            and node.target is helion_kernel_wrapper_mutation
        ]
        assert len(hop_nodes) > 0, (
            "Expected helion_kernel_wrapper_mutation HOP node in compiled graph, "
            f"but found none. Graph nodes: "
            f"{[(n.op, n.target) for n in captured_graph.graph.nodes]}"
        )

        # Verify compiled result matches eager execution
        eager_result = f(x, y)  # Run in eager mode

        assert torch.allclose(compiled_result, eager_result, atol=1e-5, rtol=1e-5), (
            "Compiled execution result doesn't match eager execution. "
            f"Max difference: {torch.max(torch.abs(compiled_result - eager_result))}"
        )

    # `supports_torch_compile_fusion` is only imported when _HOP_AVAILABLE is
    # true; the `and` short-circuit keeps this decorator safe to evaluate at
    # import time otherwise.
    @pytest.mark.skipif(
        not (_HOP_AVAILABLE and supports_torch_compile_fusion()),
        reason="Requires PyTorch with Helion inductor fusion support",
    )
    def test_inductor_backend_compiles_helion_hop(self):
        """Test torch.compile with inductor backend and Helion fusion enabled."""

        configs = {"default": helion.Config(block_sizes=[4, 4])}

        with dummy_kernel_registry(configs=configs) as register:
            add_helion_kernel = register(
                op_name="test_inductor_add_kernel",
                config_picker=lambda args, keys: "default",
                helion_settings=helion.Settings(
                    torch_compile_fusion=True, static_shapes=False
                ),
            )(_add_kernel)

        def f(x, y):
            # Prologue (mul/add) and epilogue (relu) ops around the Helion
            # kernel give Inductor something to fuse.
            x = x * 2.0
            y = y + 1.0
            out = add_helion_kernel(x, y)
            return out.relu()

        torch._dynamo.reset()
        compiled_f = torch.compile(f, backend="inductor", fullgraph=True)

        x = torch.randn(4, 4, device="cuda")
        y = torch.randn(4, 4, device="cuda")

        compiled_result, source_codes = run_and_get_code(compiled_f, x, y)
        eager_result = f(x, y)

        assert torch.allclose(compiled_result, eager_result, atol=1e-5, rtol=1e-5), (
            "Inductor-compiled result doesn't match eager execution. "
            f"Max difference: {torch.max(torch.abs(compiled_result - eager_result))}"
        )

        # With fusion enabled, prologue/epilogue ops should be fused into
        # a single triton kernel rather than generating separate kernels.
        kernel_count = sum(code.count("@triton.jit") for code in source_codes)
        assert kernel_count == 1, (
            f"Expected 1 fused triton kernel, got {kernel_count}. "
            "Prologue/epilogue ops were not fused into the Helion kernel."
        )