# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations from typing import cast import pytest from gpu_memory_service.client.memory_manager import GMSClientMemoryManager from gpu_memory_service.client.torch.module import ( materialize_module_from_gms, register_module_tensors, ) from gpu_memory_service.client.torch.tensor import _tensor_from_pointer from gpu_memory_service.common.types import RequestedLockType from tests.gms.harness.gms import GMSServerProcess torch = pytest.importorskip("torch", reason="torch is required") pytestmark = [ pytest.mark.pre_merge, pytest.mark.unit, pytest.mark.gpu_1, ] if not torch.cuda.is_available(): pytest.skip( "CUDA is required for torch GMS integration tests", allow_module_level=True ) class _TinyModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(8, 4, bias=False, device="cuda") self.register_buffer( "scale", torch.linspace(0.5, 2.0, steps=4, device="cuda", dtype=torch.float32), ) self.extra = torch.arange(1, 5, device="cuda", dtype=torch.float32) def forward(self, x: torch.Tensor) -> torch.Tensor: y = self.linear(x) y = y + self.scale y = y * self.extra return torch.relu(y) @pytest.fixture def running_gms(request): with GMSServerProcess(request, device=0, tag="weights") as server: yield server.socket_path def _make_gms_tensor( manager: GMSClientMemoryManager, tensor: torch.Tensor, *, tag: str, ) -> tuple[str, torch.Tensor]: storage_bytes = tensor.untyped_storage().nbytes() va = manager.create_mapping(size=storage_bytes, tag=tag) allocation_id = manager.mappings[va].allocation_id gms_tensor = _tensor_from_pointer( va, list(tensor.shape), list(tensor.stride()), tensor.dtype, tensor.device.index or 0, ) gms_tensor.copy_(tensor) return allocation_id, gms_tensor def _assert_exact_tensor_equal(expected: torch.Tensor, actual: torch.Tensor) -> None: torch.testing.assert_close(expected, actual, rtol=0, atol=0) def test_gms_tensor_matches_plain_torch_ops(running_gms): socket_path = running_gms baseline = torch.arange(64, device="cuda", dtype=torch.float32).reshape(8, 8) rhs = torch.arange(32, device="cuda", dtype=torch.float32).reshape(8, 4) writer = GMSClientMemoryManager(socket_path, device=0) writer.connect(RequestedLockType.RW) allocation_id, _writer_tensor = _make_gms_tensor(writer, baseline, tag="weights") assert writer.commit() reader = GMSClientMemoryManager(socket_path, device=0) reader.connect(RequestedLockType.RO) va = reader.create_mapping(allocation_id=allocation_id) gms_tensor = _tensor_from_pointer( va, list(baseline.shape), list(baseline.stride()), baseline.dtype, 0, ) _assert_exact_tensor_equal( torch.relu((baseline + 3.0) @ rhs), torch.relu((gms_tensor + 3.0) @ rhs) ) _assert_exact_tensor_equal( baseline.transpose(0, 1).contiguous(), gms_tensor.transpose(0, 1).contiguous() ) _assert_exact_tensor_equal( baseline[:, 2:6].sum(dim=1), gms_tensor[:, 2:6].sum(dim=1) ) _assert_exact_tensor_equal( (baseline * 2.0 - 5.0).square(), (gms_tensor * 2.0 - 5.0).square() ) reader.close() def test_live_gms_tensor_survives_unmap_and_remap(running_gms): socket_path = running_gms baseline = torch.arange(64, device="cuda", dtype=torch.float32).reshape(8, 8) writer = GMSClientMemoryManager(socket_path, device=0) writer.connect(RequestedLockType.RW) allocation_id, _ = _make_gms_tensor(writer, baseline, tag="weights") assert writer.commit() reader = GMSClientMemoryManager(socket_path, device=0) reader.connect(RequestedLockType.RO) va = reader.create_mapping(allocation_id=allocation_id) gms_tensor = _tensor_from_pointer( va, list(baseline.shape), list(baseline.stride()), baseline.dtype, 0, ) pointer_before = gms_tensor.data_ptr() expected = torch.relu((baseline + 7.0).square()) reader.unmap_all_vas() reader.abort() reader.connect(RequestedLockType.RO) reader.remap_all_vas() assert gms_tensor.data_ptr() == pointer_before _assert_exact_tensor_equal(expected, torch.relu((gms_tensor + 7.0).square())) reader.close() def test_materialized_module_from_gms_matches_plain_module_forward(running_gms): socket_path = running_gms torch.manual_seed(7) baseline = _TinyModule().cuda() gms_model = _TinyModule().cuda() gms_model.load_state_dict(baseline.state_dict()) inputs = torch.randn(3, 8, device="cuda", dtype=torch.float32) expected = baseline(inputs).detach().clone() writer = GMSClientMemoryManager(socket_path, device=0) writer.connect(RequestedLockType.RW) baseline_weight = cast(torch.Tensor, baseline.linear.weight) baseline_scale = cast(torch.Tensor, baseline.scale) baseline_extra = cast(torch.Tensor, baseline.extra) _, gms_weight = _make_gms_tensor(writer, baseline_weight, tag="weights") gms_model.linear.weight = torch.nn.Parameter( gms_weight, requires_grad=baseline_weight.requires_grad ) _, gms_scale = _make_gms_tensor(writer, baseline_scale, tag="weights") gms_model._buffers["scale"] = gms_scale _, gms_extra = _make_gms_tensor(writer, baseline_extra, tag="weights") gms_model.extra = gms_extra register_module_tensors(writer, gms_model) _assert_exact_tensor_equal(expected, gms_model(inputs)) assert writer.commit() reader = GMSClientMemoryManager(socket_path, device=0) reader.connect(RequestedLockType.RO) materialized = _TinyModule().cuda() materialize_module_from_gms(reader, materialized, device_index=0) _assert_exact_tensor_equal(expected, materialized(inputs)) _assert_exact_tensor_equal(baseline_scale, cast(torch.Tensor, materialized.scale)) _assert_exact_tensor_equal(baseline_extra, cast(torch.Tensor, materialized.extra)) _assert_exact_tensor_equal( baseline_weight, cast(torch.Tensor, materialized.linear.weight), ) reader.close()