# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for weight transfer engine backends.

Unit tests for engine classes (parsing, validation, registry).
Integration tests for NCCL and IPC weight transfer between processes using Ray.
"""

9
import pickle
10
11
from unittest.mock import MagicMock

12
import pybase64 as base64
13
14
15
import pytest
import ray
import torch
16
from torch.multiprocessing.reductions import reduce_tensor
17
18
19
20

from vllm.config.parallel import ParallelConfig
from vllm.config.weight_transfer import WeightTransferConfig
from vllm.distributed.weight_transfer import WeightTransferEngineFactory
21
22
23
24
25
from vllm.distributed.weight_transfer.ipc_engine import (
    IPCWeightTransferEngine,
    IPCWeightTransferInitInfo,
    IPCWeightTransferUpdateInfo,
)
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
from vllm.distributed.weight_transfer.nccl_engine import (
    NCCLWeightTransferEngine,
    NCCLWeightTransferInitInfo,
    NCCLWeightTransferUpdateInfo,
)
from vllm.utils.network_utils import get_open_port


def create_mock_parallel_config(
    rank: int = 0,
    world_size: int = 1,
    dp_rank: int = 0,
) -> ParallelConfig:
    """Return a ``MagicMock`` standing in for a ``ParallelConfig``.

    The mock is built with ``spec=ParallelConfig`` and has exactly three
    attributes assigned explicitly.

    Args:
        rank: Stored on ``config.rank``.
        world_size: Stored on ``config.world_size``.
        dp_rank: Stored on ``config.data_parallel_rank``.

    Returns:
        The configured mock, typed as ``ParallelConfig`` for callers.
    """
    overrides = {
        "rank": rank,
        "world_size": world_size,
        "data_parallel_rank": dp_rank,
    }
    mock_config = MagicMock(spec=ParallelConfig)
    for attr_name, attr_value in overrides.items():
        setattr(mock_config, attr_name, attr_value)
    return mock_config


# --- Unit Tests: NCCLWeightTransferUpdateInfo Validation ---


class TestNCCLWeightTransferUpdateInfoValidation:
    """Length-consistency checks performed by ``NCCLWeightTransferUpdateInfo``."""

    def test_valid_update_info(self):
        """Equal-length names / dtype_names / shapes are stored unchanged."""
        info = NCCLWeightTransferUpdateInfo(
            names=["layer.weight", "layer.bias"],
            dtype_names=["float32", "float32"],
            shapes=[[10, 10], [10]],
        )
        assert (info.names, info.dtype_names, info.shapes) == (
            ["layer.weight", "layer.bias"],
            ["float32", "float32"],
            [[10, 10], [10]],
        )

    def test_mismatched_dtype_names_raises(self):
        """A dtype_names list shorter than names is rejected with ValueError."""
        bad_kwargs = {
            "names": ["layer.weight", "layer.bias"],
            "dtype_names": ["float32"],  # one dtype for two names
            "shapes": [[10, 10], [10]],
        }
        with pytest.raises(ValueError, match="dtype_names"):
            NCCLWeightTransferUpdateInfo(**bad_kwargs)

    def test_mismatched_shapes_raises(self):
        """A shapes list shorter than names is rejected with ValueError."""
        bad_kwargs = {
            "names": ["layer.weight", "layer.bias"],
            "dtype_names": ["float32", "float32"],
            "shapes": [[10, 10]],  # one shape for two names
        }
        with pytest.raises(ValueError, match="shapes"):
            NCCLWeightTransferUpdateInfo(**bad_kwargs)

    def test_empty_lists_valid(self):
        """Three empty lists form a valid (empty) update."""
        empty_info = NCCLWeightTransferUpdateInfo(names=[], dtype_names=[], shapes=[])
        assert len(empty_info.names) == 0


# --- Unit Tests: Engine Parsing ---


class TestNCCLEngineParsing:
    """Checks how ``NCCLWeightTransferEngine`` parses raw dicts into info objects."""

    @staticmethod
    def _build_engine() -> NCCLWeightTransferEngine:
        """Create an NCCL engine backed by a default mock parallel config."""
        weight_cfg = WeightTransferConfig(backend="nccl")
        return NCCLWeightTransferEngine(weight_cfg, create_mock_parallel_config())

    def test_parse_init_info_valid(self):
        """A complete init dict becomes an ``NCCLWeightTransferInitInfo``."""
        engine = self._build_engine()
        raw_init = {
            "master_address": "127.0.0.1",
            "master_port": 12345,
            "rank_offset": 1,
            "world_size": 3,
        }

        parsed = engine.parse_init_info(raw_init)

        assert isinstance(parsed, NCCLWeightTransferInitInfo)
        assert parsed.master_address == "127.0.0.1"
        assert parsed.master_port == 12345
        assert parsed.rank_offset == 1
        assert parsed.world_size == 3

    def test_parse_init_info_missing_field_raises(self):
        """An init dict lacking port, rank_offset and world_size is rejected."""
        engine = self._build_engine()
        # Only the address is supplied; the other required keys are absent.
        incomplete_init = {"master_address": "127.0.0.1"}

        with pytest.raises(ValueError, match="Invalid init_info"):
            engine.parse_init_info(incomplete_init)

    def test_parse_update_info_valid(self):
        """A consistent update dict becomes an ``NCCLWeightTransferUpdateInfo``."""
        engine = self._build_engine()
        raw_update = {
            "names": ["w1", "w2"],
            "dtype_names": ["float32", "bfloat16"],
            "shapes": [[100, 100], [50]],
        }

        parsed = engine.parse_update_info(raw_update)

        assert isinstance(parsed, NCCLWeightTransferUpdateInfo)
        assert parsed.names == ["w1", "w2"]
        assert parsed.dtype_names == ["float32", "bfloat16"]
        assert parsed.shapes == [[100, 100], [50]]


# --- Unit Tests: Engine Registry ---


class TestEngineRegistry:
    """Test ``WeightTransferEngineFactory`` backend creation and registration."""

    def test_create_engine_nccl(self):
        """Test factory creates NCCL engine."""
        config = WeightTransferConfig(backend="nccl")
        parallel_config = create_mock_parallel_config()
        engine = WeightTransferEngineFactory.create_engine(config, parallel_config)
        assert isinstance(engine, NCCLWeightTransferEngine)

    def test_create_engine_ipc(self):
        """Test factory creates IPC engine."""
        config = WeightTransferConfig(backend="ipc")
        parallel_config = create_mock_parallel_config()
        engine = WeightTransferEngineFactory.create_engine(config, parallel_config)
        assert isinstance(engine, IPCWeightTransferEngine)

    def test_create_engine_invalid_backend(self):
        """Test factory raises for invalid backend."""
        # Pydantic validates the ``backend`` Literal at construction time, so an
        # invalid backend cannot reach the factory through the constructor.
        # First confirm that rejection happens...
        from pydantic import ValidationError

        with pytest.raises(ValidationError):
            WeightTransferConfig(backend="invalid")

        # ...then build a valid config and overwrite ``backend`` with
        # ``object.__setattr__`` (bypassing Pydantic validation) so the
        # factory's own error path is exercised.
        config = WeightTransferConfig(backend="nccl")
        object.__setattr__(config, "backend", "invalid")

        parallel_config = create_mock_parallel_config()
        with pytest.raises(ValueError, match="Invalid weight transfer backend"):
            WeightTransferEngineFactory.create_engine(config, parallel_config)

    def test_register_duplicate_raises(self):
        """Test registering duplicate engine name raises."""
        with pytest.raises(ValueError, match="already registered"):
            WeightTransferEngineFactory.register_engine(
                "nccl", NCCLWeightTransferEngine
            )


# --- Test receive_weights without init raises ---


def test_nccl_receive_weights_without_init_raises():
    """Test that receive_weights raises if init_transfer_engine wasn't called."""
    if torch.accelerator.device_count() < 1:
        pytest.skip("Need at least 1 GPU for this test")

    config = WeightTransferConfig(backend="nccl")
    parallel_config = create_mock_parallel_config()
    engine = NCCLWeightTransferEngine(config, parallel_config)

    update_info = NCCLWeightTransferUpdateInfo(
        names=["w"],
        dtype_names=["float32"],
        shapes=[[10]],
    )

    # The load_weights callback is a no-op: the call is expected to raise
    # because init_transfer_engine was never invoked on this engine.
    with pytest.raises(RuntimeError, match="not initialized"):
        engine.receive_weights(update_info, lambda x: None)


# --- Integration Test: NCCL Weight Transfer Between Ray Tasks ---


@ray.remote(num_gpus=1)
def trainer_broadcast_tensor(
    master_address: str,
    master_port: int,
    world_size: int,
    tensor_shape: list[int],
    tensor_dtype: str,
) -> bool:
    """Trainer task that broadcasts a tensor of ones via NCCL.

    Joins a ``StatelessProcessGroup`` as rank 0, wraps it in a
    ``PyNcclCommunicator`` and broadcasts a ones-filled tensor from rank 0.

    Args:
        master_address: Host passed to ``StatelessProcessGroup.create``.
        master_port: Port passed to ``StatelessProcessGroup.create``.
        world_size: Total number of ranks (trainer + inference workers).
        tensor_shape: Shape of the tensor to broadcast.
        tensor_dtype: Name of a ``torch`` dtype attribute, e.g. ``"float32"``.

    Returns:
        ``True`` after the broadcast has been issued and the device synchronized.
    """
    import torch

    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.utils import StatelessProcessGroup

    # Create process group as rank 0 (trainer)
    pg = StatelessProcessGroup.create(
        host=master_address,
        port=master_port,
        rank=0,
        world_size=world_size,
    )
    # Ray sets CUDA_VISIBLE_DEVICES, so device 0 is the assigned GPU
    comm = PyNcclCommunicator(pg, device=0)

    # Create and broadcast the tensor
    dtype = getattr(torch, tensor_dtype)
    tensor_to_send = torch.ones(tensor_shape, dtype=dtype, device="cuda:0")
    comm.broadcast(tensor_to_send, src=0, stream=torch.cuda.current_stream())
    # Wait for the broadcast kernel to finish before reporting success.
    torch.accelerator.synchronize()

    return True


@ray.remote(num_gpus=1)
def inference_receive_tensor(
    master_address: str,
    master_port: int,
    world_size: int,
    tensor_shape: list[int],
    tensor_dtype: str,
) -> dict:
    """Inference task that receives a tensor via ``NCCLWeightTransferEngine``.

    Joins the trainer's NCCL group as rank 1, receives one tensor named
    ``"test.weight"`` and checks that it has ``tensor_shape`` and is all ones
    (the value ``trainer_broadcast_tensor`` sends).

    Args:
        master_address: Rendezvous host shared with the trainer.
        master_port: Rendezvous port shared with the trainer.
        world_size: Total number of ranks (trainer + inference workers).
        tensor_shape: Expected shape of the received tensor.
        tensor_dtype: Name of a ``torch`` dtype, e.g. ``"float32"``.

    Returns:
        Dict with ``success`` (bool), ``received_shape`` (list or None) and
        ``received_sum`` (float or None) for diagnostics on failure.
    """
    from unittest.mock import MagicMock

    import torch

    from vllm.config.parallel import ParallelConfig
    from vllm.config.weight_transfer import WeightTransferConfig
    from vllm.distributed.weight_transfer.nccl_engine import (
        NCCLWeightTransferEngine,
        NCCLWeightTransferInitInfo,
        NCCLWeightTransferUpdateInfo,
    )

    # Create engine with mock parallel config
    config = WeightTransferConfig(backend="nccl")
    parallel_config = MagicMock(spec=ParallelConfig)
    parallel_config.rank = 0
    parallel_config.world_size = 1
    parallel_config.data_parallel_rank = 0

    engine = NCCLWeightTransferEngine(config, parallel_config)

    # Initialize the engine (joins as rank 1)
    init_info = NCCLWeightTransferInitInfo(
        master_address=master_address,
        master_port=master_port,
        rank_offset=1,  # Trainer is rank 0, we become rank 1
        world_size=world_size,
    )
    engine.init_transfer_engine(init_info)

    # Once initialized, always shut the engine down -- Ray reuses worker
    # processes, so a failed receive must not leave the engine running.
    try:
        # Receive weights with a no-op load_weights that captures the tensor
        received_tensors: list[tuple[str, torch.Tensor]] = []

        def noop_load_weights(weights: list[tuple[str, torch.Tensor]]):
            for name, tensor in weights:
                # Clone tensor to keep it after engine cleans up
                received_tensors.append((name, tensor.clone()))

        update_info = NCCLWeightTransferUpdateInfo(
            names=["test.weight"],
            dtype_names=[tensor_dtype],
            shapes=[tensor_shape],
        )
        engine.receive_weights(update_info, noop_load_weights)
        torch.accelerator.synchronize()

        # Verify we received the tensor
        success = False
        received_shape = None
        received_sum = None

        if len(received_tensors) == 1:
            _, tensor = received_tensors[0]
            received_shape = list(tensor.shape)
            received_sum = tensor.sum().item()
            # Check shape matches and values are all 1s (trainer sends ones)
            if received_shape == tensor_shape:
                expected_sum = 1.0 * torch.tensor(tensor_shape).prod().item()
                if abs(received_sum - expected_sum) < 0.01:
                    success = True
    finally:
        engine.shutdown()

    return {
        "success": success,
        "received_shape": received_shape,
        "received_sum": received_sum,
    }


@pytest.mark.skipif(
    torch.accelerator.device_count() < 2,
    reason="Need at least 2 GPUs to run NCCL weight transfer test.",
)
def test_nccl_weight_transfer_between_processes():
    """Test NCCL weight transfer from trainer to inference process using Ray.

    This test verifies that the NCCLWeightTransferEngine can receive
    tensors broadcast by a trainer process via NCCL.
    """
    ray.init(ignore_reinit_error=True)

    master_address = "127.0.0.1"
    master_port = get_open_port()
    world_size = 2  # 1 trainer + 1 inference worker

    # Tensor to transfer: 100x100 ones
    tensor_shape = [100, 100]
    tensor_dtype = "float32"

    # Start both tasks concurrently - Ray assigns GPUs automatically. Both
    # must be running at once because each blocks on the NCCL rendezvous.
    inference_future = inference_receive_tensor.remote(
        master_address, master_port, world_size, tensor_shape, tensor_dtype
    )
    trainer_future = trainer_broadcast_tensor.remote(
        master_address, master_port, world_size, tensor_shape, tensor_dtype
    )

    # Wait for both to complete
    trainer_result, result = ray.get([trainer_future, inference_future])

    assert trainer_result, "Trainer should complete successfully"
    assert result["success"], (
        f"Weight transfer failed. "
        f"Received shape: {result['received_shape']}, "
        f"Received sum: {result['received_sum']}"
    )


# --- Unit Tests: IPCWeightTransferUpdateInfo Validation ---


class TestIPCWeightTransferUpdateInfoValidation:
    """Test IPCWeightTransferUpdateInfo dataclass validation.

    Tests that need a real CUDA IPC handle (built with ``reduce_tensor`` on a
    ``cuda:0`` tensor) skip themselves when no accelerator device is present.
    """

    def test_valid_update_info(self):
        """Test creating valid IPCWeightTransferUpdateInfo."""
        if torch.accelerator.device_count() < 1:
            pytest.skip("Need at least 1 GPU for this test")

        # Create a dummy tensor and IPC handle
        dummy_tensor = torch.ones(10, 10, device="cuda:0")
        ipc_handle = reduce_tensor(dummy_tensor)
        gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
        ipc_handles = [{gpu_uuid: ipc_handle}]

        info = IPCWeightTransferUpdateInfo(
            names=["layer.weight"],
            dtype_names=["float32"],
            shapes=[[10, 10]],
            ipc_handles=ipc_handles,
        )
        assert info.names == ["layer.weight"]
        assert info.dtype_names == ["float32"]
        assert info.shapes == [[10, 10]]
        assert len(info.ipc_handles) == 1

    def test_mismatched_dtype_names_raises(self):
        """Test that mismatched dtype_names length raises ValueError."""
        if torch.accelerator.device_count() < 1:
            pytest.skip("Need at least 1 GPU for this test")

        dummy_tensor = torch.ones(10, 10, device="cuda:0")
        ipc_handle = reduce_tensor(dummy_tensor)
        gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
        ipc_handles = [{gpu_uuid: ipc_handle}, {gpu_uuid: ipc_handle}]

        with pytest.raises(ValueError, match="dtype_names"):
            IPCWeightTransferUpdateInfo(
                names=["layer.weight", "layer.bias"],
                dtype_names=["float32"],  # Only one dtype
                shapes=[[10, 10], [10]],
                ipc_handles=ipc_handles,
            )

    def test_mismatched_shapes_raises(self):
        """Test that mismatched shapes length raises ValueError."""
        if torch.accelerator.device_count() < 1:
            pytest.skip("Need at least 1 GPU for this test")

        dummy_tensor = torch.ones(10, 10, device="cuda:0")
        ipc_handle = reduce_tensor(dummy_tensor)
        gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
        ipc_handles = [{gpu_uuid: ipc_handle}, {gpu_uuid: ipc_handle}]

        with pytest.raises(ValueError, match="shapes"):
            IPCWeightTransferUpdateInfo(
                names=["layer.weight", "layer.bias"],
                dtype_names=["float32", "float32"],
                shapes=[[10, 10]],  # Only one shape
                ipc_handles=ipc_handles,
            )

    def test_mismatched_ipc_handles_raises(self):
        """Test that mismatched ipc_handles length raises ValueError."""
        if torch.accelerator.device_count() < 1:
            pytest.skip("Need at least 1 GPU for this test")

        dummy_tensor = torch.ones(10, 10, device="cuda:0")
        ipc_handle = reduce_tensor(dummy_tensor)
        gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
        ipc_handles = [{gpu_uuid: ipc_handle}]  # Only one handle

        with pytest.raises(ValueError, match="ipc_handles"):
            IPCWeightTransferUpdateInfo(
                names=["layer.weight", "layer.bias"],
                dtype_names=["float32", "float32"],
                shapes=[[10, 10], [10]],
                ipc_handles=ipc_handles,
            )

    def test_valid_update_info_from_pickled(self, monkeypatch):
        """Test creating IPCWeightTransferUpdateInfo from pickled handles."""
        if torch.accelerator.device_count() < 1:
            pytest.skip("Need at least 1 GPU for this test")

        # Pickled handles are only accepted with this flag enabled (see
        # test_pickled_requires_insecure_serialization_flag).
        monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

        dummy_tensor = torch.ones(10, 10, device="cuda:0")
        ipc_handle = reduce_tensor(dummy_tensor)
        gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
        ipc_handles = [{gpu_uuid: ipc_handle}]

        # Same encoding the HTTP path uses: pickle, then base64 to a str.
        pickled = base64.b64encode(pickle.dumps(ipc_handles)).decode("utf-8")

        info = IPCWeightTransferUpdateInfo(
            names=["layer.weight"],
            dtype_names=["float32"],
            shapes=[[10, 10]],
            ipc_handles_pickled=pickled,
        )
        # Construction decodes the pickled form into ipc_handles and clears it.
        assert info.ipc_handles == ipc_handles
        assert info.ipc_handles_pickled is None

    def test_pickled_requires_insecure_serialization_flag(self, monkeypatch):
        """Test that pickled handles are rejected unless env flag is enabled."""
        monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "0")

        with pytest.raises(ValueError, match="VLLM_ALLOW_INSECURE_SERIALIZATION=1"):
            IPCWeightTransferUpdateInfo(
                names=[],
                dtype_names=[],
                shapes=[],
                ipc_handles_pickled=base64.b64encode(pickle.dumps([])).decode("utf-8"),
            )

    def test_both_handles_and_pickled_raises(self):
        """Test that providing both ipc_handles and ipc_handles_pickled raises."""
        if torch.accelerator.device_count() < 1:
            pytest.skip("Need at least 1 GPU for this test")

        dummy_tensor = torch.ones(10, 10, device="cuda:0")
        ipc_handle = reduce_tensor(dummy_tensor)
        gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
        ipc_handles = [{gpu_uuid: ipc_handle}]

        pickled = base64.b64encode(pickle.dumps(ipc_handles)).decode("utf-8")

        # NOTE(review): VLLM_ALLOW_INSECURE_SERIALIZATION is not set here, so
        # this relies on the "both provided" check running before the
        # insecure-serialization check -- confirm against __post_init__.
        with pytest.raises(ValueError, match="Cannot specify both"):
            IPCWeightTransferUpdateInfo(
                names=["layer.weight"],
                dtype_names=["float32"],
                shapes=[[10, 10]],
                ipc_handles=ipc_handles,
                ipc_handles_pickled=pickled,
            )

    def test_neither_handles_nor_pickled_raises(self):
        """Test that providing neither ipc_handles nor ipc_handles_pickled raises."""
        with pytest.raises(ValueError, match="must be provided"):
            IPCWeightTransferUpdateInfo(
                names=["layer.weight"],
                dtype_names=["float32"],
                shapes=[[10, 10]],
            )

    def test_empty_lists_valid(self):
        """Test that empty lists are valid."""
        info = IPCWeightTransferUpdateInfo(
            names=[],
            dtype_names=[],
            shapes=[],
            ipc_handles=[],
        )
        assert len(info.names) == 0


# --- Unit Tests: IPC Engine Parsing ---


class TestIPCEngineParsing:
    """Test IPCWeightTransferEngine parsing methods.

    Both tests build real CUDA IPC handles and skip without an accelerator.
    """

    def test_parse_update_info_valid(self):
        """Test parsing valid update info dict."""
        if torch.accelerator.device_count() < 1:
            pytest.skip("Need at least 1 GPU for this test")

        config = WeightTransferConfig(backend="ipc")
        parallel_config = create_mock_parallel_config()
        engine = IPCWeightTransferEngine(config, parallel_config)

        # Create dummy IPC handles
        dummy_tensor1 = torch.ones(100, 100, device="cuda:0")
        dummy_tensor2 = torch.ones(50, device="cuda:0")
        ipc_handle1 = reduce_tensor(dummy_tensor1)
        ipc_handle2 = reduce_tensor(dummy_tensor2)
        gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
        ipc_handles = [{gpu_uuid: ipc_handle1}, {gpu_uuid: ipc_handle2}]

        update_info = engine.parse_update_info(
            {
                "names": ["w1", "w2"],
                "dtype_names": ["float32", "bfloat16"],
                "shapes": [[100, 100], [50]],
                "ipc_handles": ipc_handles,
            }
        )

        assert isinstance(update_info, IPCWeightTransferUpdateInfo)
        assert update_info.names == ["w1", "w2"]
        assert update_info.dtype_names == ["float32", "bfloat16"]
        assert update_info.shapes == [[100, 100], [50]]
        assert len(update_info.ipc_handles) == 2

    def test_parse_update_info_pickled(self, monkeypatch):
        """Test parsing update info with pickled IPC handles (HTTP path)."""
        if torch.accelerator.device_count() < 1:
            pytest.skip("Need at least 1 GPU for this test")

        # Pickled handles are only accepted with this flag enabled.
        monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

        config = WeightTransferConfig(backend="ipc")
        parallel_config = create_mock_parallel_config()
        engine = IPCWeightTransferEngine(config, parallel_config)

        dummy_tensor1 = torch.ones(100, 100, device="cuda:0")
        dummy_tensor2 = torch.ones(50, device="cuda:0")
        ipc_handle1 = reduce_tensor(dummy_tensor1)
        ipc_handle2 = reduce_tensor(dummy_tensor2)
        gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
        ipc_handles = [{gpu_uuid: ipc_handle1}, {gpu_uuid: ipc_handle2}]

        pickled = base64.b64encode(pickle.dumps(ipc_handles)).decode("utf-8")

        update_info = engine.parse_update_info(
            {
                "names": ["w1", "w2"],
                "dtype_names": ["float32", "bfloat16"],
                "shapes": [[100, 100], [50]],
                "ipc_handles_pickled": pickled,
            }
        )

        assert isinstance(update_info, IPCWeightTransferUpdateInfo)
        assert update_info.names == ["w1", "w2"]
        assert len(update_info.ipc_handles) == 2
        assert update_info.ipc_handles_pickled is None
        assert gpu_uuid in update_info.ipc_handles[0]
        assert gpu_uuid in update_info.ipc_handles[1]


# --- Integration Test: IPC Weight Transfer Between Ray Tasks ---


def get_physical_gpu_id(device_index: int = 0) -> str:
    """Return the UUID of CUDA device ``device_index`` as a string.

    IPC handle dicts in this module are keyed by this value.
    """
    return str(torch.cuda.get_device_properties(device_index).uuid)


@ray.remote(num_gpus=0.5)
class TrainerActor:
    """Trainer actor that creates and holds CUDA IPC handles.

    The source tensor is kept on ``self`` so it stays alive for the lifetime of
    the actor; the IPC handle only works while the tensor exists.
    """

    def __init__(self, tensor_shape: list[int], tensor_dtype: str):
        """Allocate a 42-filled tensor on ``cuda:0`` and build its IPC handle.

        Args:
            tensor_shape: Shape of the tensor to share.
            tensor_dtype: Name of a ``torch`` dtype attribute, e.g. ``"float32"``.
        """
        # Create tensor on GPU and keep it alive
        dtype = getattr(torch, tensor_dtype)
        self.tensor = torch.ones(tensor_shape, dtype=dtype, device="cuda:0")
        self.tensor.fill_(42.0)  # Fill with 42 to verify correct transfer

        # Create IPC handle (tensor must stay alive for IPC to work)
        ipc_handle = reduce_tensor(self.tensor)
        gpu_uuid = get_physical_gpu_id(0)

        # Make sure the fill has completed before the handle is handed out.
        torch.accelerator.synchronize()

        self.ipc_handle_dict = {
            "ipc_handle": ipc_handle,
            "gpu_uuid": gpu_uuid,
            "shape": tensor_shape,
            "dtype": tensor_dtype,
        }

    def get_ipc_handle_dict(self) -> dict:
        """Return IPC handle dict. Tensor stays alive in this actor."""
        return self.ipc_handle_dict


@ray.remote(num_gpus=0.5)
def inference_receive_ipc_tensor(
    ipc_handle_dict: dict,
    mode: str = "ray",
) -> dict:
    """Inference task that receives a tensor via ``IPCWeightTransferEngine``.

    Builds an update dict from ``ipc_handle_dict`` (as produced by
    ``TrainerActor.get_ipc_handle_dict``) and runs it through
    ``parse_update_info`` and ``receive_weights``. It then checks that the
    received tensor has the advertised shape and is filled with 42s.

    Args:
        ipc_handle_dict: Dict with ``ipc_handle``, ``gpu_uuid``, ``shape`` and
            ``dtype`` keys.
        mode: ``"ray"`` passes ``ipc_handles`` directly; ``"http"`` sends them
            pickled and base64-encoded via ``ipc_handles_pickled``.

    Returns:
        Dict with ``success`` (bool), ``received_shape`` (list or None) and
        ``received_sum`` (float or None) for diagnostics on failure.

    Raises:
        ValueError: If ``mode`` is neither ``"ray"`` nor ``"http"``.
    """
    import pickle
    from unittest.mock import MagicMock

    import pybase64 as base64
    import torch

    from vllm.config.parallel import ParallelConfig
    from vllm.config.weight_transfer import WeightTransferConfig
    from vllm.distributed.weight_transfer.ipc_engine import (
        IPCWeightTransferEngine,
        IPCWeightTransferInitInfo,
    )

    # Create engine with mock parallel config
    config = WeightTransferConfig(backend="ipc")
    parallel_config = MagicMock(spec=ParallelConfig)
    parallel_config.rank = 0
    parallel_config.world_size = 1
    parallel_config.data_parallel_rank = 0

    engine = IPCWeightTransferEngine(config, parallel_config)

    # Initialize the engine (no-op for IPC)
    init_info = IPCWeightTransferInitInfo()
    engine.init_transfer_engine(init_info)

    # Always shut the engine down, even if parsing or receiving fails, since
    # Ray may reuse this worker process for later tasks.
    try:
        # Receive weights with a no-op load_weights that captures the tensor
        received_tensors: list[tuple[str, torch.Tensor]] = []

        def noop_load_weights(weights: list[tuple[str, torch.Tensor]]):
            for name, tensor in weights:
                # Clone tensor to keep it after engine cleans up
                received_tensors.append((name, tensor.clone()))

        # Build update dict and go through parse_update_info (exercises
        # __post_init__)
        ipc_handles = [{ipc_handle_dict["gpu_uuid"]: ipc_handle_dict["ipc_handle"]}]

        if mode == "ray":
            update_dict: dict = {
                "names": ["test.weight"],
                "dtype_names": [ipc_handle_dict["dtype"]],
                "shapes": [ipc_handle_dict["shape"]],
                "ipc_handles": ipc_handles,
            }
        elif mode == "http":
            pickled = base64.b64encode(pickle.dumps(ipc_handles)).decode("utf-8")
            update_dict = {
                "names": ["test.weight"],
                "dtype_names": [ipc_handle_dict["dtype"]],
                "shapes": [ipc_handle_dict["shape"]],
                "ipc_handles_pickled": pickled,
            }
        else:
            raise ValueError(f"Unknown mode: {mode}")

        update_info = engine.parse_update_info(update_dict)
        engine.receive_weights(update_info, noop_load_weights)
        torch.accelerator.synchronize()

        # Verify we received the tensor
        success = False
        received_shape = None
        received_sum = None

        if len(received_tensors) == 1:
            _, tensor = received_tensors[0]
            received_shape = list(tensor.shape)
            received_sum = tensor.sum().item()
            # Check shape matches and values are all 42s (trainer sends 42s)
            if received_shape == ipc_handle_dict["shape"]:
                expected_sum = (
                    42.0 * torch.tensor(ipc_handle_dict["shape"]).prod().item()
                )
                if abs(received_sum - expected_sum) < 0.01:
                    success = True
    finally:
        engine.shutdown()

    return {
        "success": success,
        "received_shape": received_shape,
        "received_sum": received_sum,
    }


@pytest.mark.skipif(
    torch.accelerator.device_count() < 1,
    reason="Need at least 1 GPU to run IPC weight transfer test.",
)
@pytest.mark.parametrize("mode", ["ray", "http"])
def test_ipc_weight_transfer_between_processes(mode: str):
    """Test IPC weight transfer from trainer to inference process using Ray.

    Parametrized over transport modes:
    - 'ray':  ipc_handles passed directly.
    - 'http': ipc_handles pickled + base64-encoded, unpickled via __post_init__.

    IPC requires same-GPU access, so we use a placement group to co-locate
    the trainer actor and inference task on the same GPU. The placement group
    is removed at the end, so each parametrization can reserve the GPU again.
    """
    from ray.util.placement_group import placement_group, remove_placement_group
    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

    ray.init(ignore_reinit_error=True)

    # Create a placement group to ensure both processes are on the same GPU.
    # The actor and task each request num_gpus=0.5, so both fit in this
    # single 1-GPU bundle.
    pg = placement_group([{"GPU": 1, "CPU": 2}])
    try:
        ray.get(pg.ready())

        scheduling_strategy = PlacementGroupSchedulingStrategy(
            placement_group=pg,
            placement_group_capture_child_tasks=True,
        )

        # Tensor to transfer: 100x100 filled with 42s
        tensor_shape = [100, 100]
        tensor_dtype = "float32"

        # Create trainer actor that holds the tensor and IPC handle (stays alive)
        trainer_actor = TrainerActor.options(  # type: ignore[attr-defined]
            scheduling_strategy=scheduling_strategy
        ).remote(tensor_shape, tensor_dtype)

        # Get IPC handle dict (tensor stays alive in trainer actor)
        ipc_handle_dict = ray.get(trainer_actor.get_ipc_handle_dict.remote())

        # Receive tensor in inference process using IPC handles (on same GPU)
        # Trainer actor stays alive during this operation
        inference_result = ray.get(
            inference_receive_ipc_tensor.options(
                scheduling_strategy=scheduling_strategy
            ).remote(ipc_handle_dict, mode=mode)
        )
    finally:
        # Placement groups are not released when the handle goes out of
        # scope. Without this, a second parametrization on a single-GPU host
        # would wait forever in pg.ready() for the still-reserved GPU.
        remove_placement_group(pg)

    assert inference_result["success"], (
        f"IPC weight transfer failed (mode={mode}). "
        f"Received shape: {inference_result['received_shape']}, "
        f"Received sum: {inference_result['received_sum']}"
    )


def test_ipc_receive_weights_missing_gpu_uuid_raises():
    """Test that receive_weights raises if GPU UUID not found in IPC handles."""
    if torch.accelerator.device_count() < 1:
        pytest.skip("Need at least 1 GPU for this test")

    config = WeightTransferConfig(backend="ipc")
    parallel_config = create_mock_parallel_config()
    engine = IPCWeightTransferEngine(config, parallel_config)

    # Create IPC handle with wrong GPU UUID, so no entry matches this device
    dummy_tensor = torch.ones(10, 10, device="cuda:0")
    ipc_handle = reduce_tensor(dummy_tensor)
    wrong_uuid = "wrong-uuid-12345"
    ipc_handles = [{wrong_uuid: ipc_handle}]

    update_info = IPCWeightTransferUpdateInfo(
        names=["w"],
        dtype_names=["float32"],
        shapes=[[10, 10]],
        ipc_handles=ipc_handles,
    )

    # The load_weights callback is a no-op: the call is expected to raise.
    with pytest.raises(ValueError, match="IPC handle not found"):
        engine.receive_weights(update_info, lambda x: None)