test_shm_buffer.py 8.57 KB
Newer Older
1
2
3
4
5
6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import traceback
import unittest

7
8
import numpy as np

9
from vllm.distributed.device_communicators.shm_object_storage import (
10
11
    SingleWriterShmRingBuffer,
)
12
13
14
15
16
17
18
19
20
21
22
23
24


class TestSingleWriterShmRingBuffer(unittest.TestCase):
    """Test suite for the ring buffer implementation.

    Every test stores the writer-side buffer it creates in
    ``self.ring_buffer`` so that ``tearDown`` can close it.
    """

    def setUp(self):
        """Set up test fixtures."""
        self.buffer_size = 4096
        self.ring_buffer = None

    def tearDown(self):
        """Close the buffer created by the test, if any."""
        if self.ring_buffer:
            self.ring_buffer.close()

    def test_buffer_opening(self):
        """Test opening an existing buffer from its handle."""
        # First create a buffer
        self.ring_buffer = SingleWriterShmRingBuffer(
            data_buffer_size=self.buffer_size, create=True
        )

        # Then open it with another instance
        reader_buffer = SingleWriterShmRingBuffer(*self.ring_buffer.handle())
        try:
            self.assertFalse(reader_buffer.is_writer)
            self.assertEqual(
                reader_buffer.shared_memory.name,
                self.ring_buffer.shared_memory.name,
            )
        finally:
            # Close the second instance explicitly; tearDown only knows
            # about self.ring_buffer.
            reader_buffer.close()

    def test_buffer_access(self):
        """Test writing to and reading back from an allocated buffer."""
        self.ring_buffer = SingleWriterShmRingBuffer(
            data_buffer_size=self.buffer_size, create=True
        )

        size = 100
        address, monotonic_id = self.ring_buffer.allocate_buf(size)

        # Write some test data
        test_data = b"Hello, World!" * 7  # 91 bytes
        with self.ring_buffer.access_buf(address) as (data_buf, _):
            data_buf[0 : len(test_data)] = test_data

        # Read it back
        with self.ring_buffer.access_buf(address) as (data_buf2, metadata2):
            read_data = bytes(data_buf2[0 : len(test_data)])
            read_id = metadata2[0]

        self.assertEqual(read_data, test_data)
        self.assertEqual(read_id, monotonic_id)

    def test_memory_error_on_full_buffer(self):
        """Test that MemoryError is raised when buffer is full."""
        small_buffer_size = 200
        self.ring_buffer = SingleWriterShmRingBuffer(
            data_buffer_size=small_buffer_size, create=True
        )

        # Fill up the buffer
        self.ring_buffer.allocate_buf(100)
        self.ring_buffer.allocate_buf(80)  # Total: 196 bytes used

        # This should fail
        with self.assertRaises(MemoryError):
            self.ring_buffer.allocate_buf(1)  # Would exceed buffer capacity

    def test_allocation_and_free(self):
        """Test repeated allocate/free rounds on a small buffer."""
        small_buffer_size = 200
        self.ring_buffer = SingleWriterShmRingBuffer(
            data_buffer_size=small_buffer_size, create=True
        )

        size = 80
        # Write some data
        test_data = b"Repeated test data"
        for i in range(5):
            address, _ = self.ring_buffer.allocate_buf(size)
            with self.ring_buffer.access_buf(address) as (data_buf, _):
                data_buf[0:4] = (0).to_bytes(4, "little")  # 0 for not in-use
                data_buf[4 : len(test_data) + 4] = test_data

            # The predicate accepts everything, so the buffer that was just
            # allocated is freed again; its ID is expected to equal the
            # round number.
            freed_ids = self.ring_buffer.free_buf(lambda *args: True)
            self.assertEqual(freed_ids[0], i)

    def test_clear_buffer(self):
        """Test clearing the buffer."""
        self.ring_buffer = SingleWriterShmRingBuffer(
            data_buffer_size=self.buffer_size, create=True
        )

        # Allocate some buffers
        for _ in range(3):
            self.ring_buffer.allocate_buf(100)

        # Clear the buffer
        self.ring_buffer.clear()

        # Check that metadata is empty and IDs reset
        self.assertEqual(len(self.ring_buffer.metadata), 0)
        self.assertEqual(self.ring_buffer.monotonic_id_start, 0)
        self.assertEqual(self.ring_buffer.monotonic_id_end, 0)
        self.assertEqual(self.ring_buffer.data_buffer_start, 0)
        self.assertEqual(self.ring_buffer.data_buffer_end, 0)

    def test_allocation_cycles(self):
        """Test that allocations never overlap across allocate/free cycles.

        A bitmap shadowing the data buffer records which bytes the test
        believes are allocated; every allocation must land on bytes marked
        free and every free must release bytes marked allocated.
        """
        buffer_size = 100
        # Keep the buffer on self so tearDown closes it like in the other
        # tests (a purely local instance would never be closed explicitly).
        self.ring_buffer = ring = SingleWriterShmRingBuffer(
            data_buffer_size=buffer_size, create=True
        )

        # tracking allocations for assertions
        allocated_bitmap = np.zeros(
            (buffer_size,), dtype=np.bool_
        )  # addr -> is_allocated
        allocation_map = {}  # monotonic_id -> (addr, size)

        def count_allocated(bitmap) -> int:
            return np.sum(bitmap).item()

        def is_free_fn(a, b) -> bool:
            # Every buffer is considered free-able in this test.
            return True

        def mark_allocated_with_assertion(alloc_id, addr, size):
            # Addresses are reduced modulo the buffer size before indexing
            # the bitmap.
            addr = addr % buffer_size
            self.assertEqual(count_allocated(allocated_bitmap[addr : addr + size]), 0)

            allocated_bitmap[addr : addr + size] = True
            allocation_map[alloc_id] = (addr, size)

        def mark_freed_with_assertion(alloc_id):
            self.assertIn(alloc_id, allocation_map)

            addr, size = allocation_map.pop(alloc_id)
            addr = addr % buffer_size
            self.assertEqual(
                count_allocated(allocated_bitmap[addr : addr + size]), size
            )

            allocated_bitmap[addr : addr + size] = False

        def ring_free(free_size=None):
            freed_ids = ring.free_buf(is_free_fn, free_size)
            for freed_id in freed_ids:
                mark_freed_with_assertion(freed_id)

        def ring_allocate(allocate_size):
            allocate_size_with_md = allocate_size + ring.MD_SIZE
            try:
                addr, monotonic_id = ring.allocate_buf(allocate_size)
            except MemoryError:
                # free 2x size for enough space if wrapping happened
                ring_free(allocate_size_with_md * 2)

                # retry allocating
                addr, monotonic_id = ring.allocate_buf(allocate_size)
            mark_allocated_with_assertion(monotonic_id, addr, allocate_size_with_md)

        # 1. allocation & free cycles
        for _ in range(33):
            # will consume 2 + 8 = 10 bytes per allocation
            ring_allocate(2)

        # 2. free all allocations
        ring_free()

        # 3. try allocate the largest possible buffer
        ring_allocate(buffer_size - ring.MD_SIZE)

181
182
183
184
185
186
187
188
189
190
191
192
193
194
195

def main():
    """Run the unit tests, then a manual writer/reader demonstration.

    The demo creates a writer buffer, opens a second instance from the
    writer's handle, allocates three buffers of increasing size, writes a
    short message into each and reads the messages back through the second
    instance. Any exception raised by the demo is printed with its
    traceback instead of propagating, and both buffers are closed before
    returning.
    """
    print("=== SingleWriterShmRingBuffer Test Suite ===\n")

    # Run unit tests
    print("Running unit tests...")
    # exit=False so that the demo below still runs after the tests finish.
    unittest.main(argv=[""], exit=False, verbosity=2)

    print("\n" + "=" * 50)
    print("=== Manual Demo ===\n")

    # Bound up front so the finally clause can tell which buffers exist.
    writer_buffer = None
    reader_buffer = None

    # Manual demonstration
    try:
        print("Creating ring buffer...")
        writer_buffer = SingleWriterShmRingBuffer(data_buffer_size=2048, create=True)
        reader_buffer = SingleWriterShmRingBuffer(*writer_buffer.handle())

        print(f"Buffer created with name: {writer_buffer.shared_memory.name}")

        # Allocate some buffers
        print("\nAllocating buffers...")
        address_array = []
        for i in range(3):
            size = 100 + i * 50
            try:
                writer_buffer.free_buf(lambda *args: True)
                address, monotonic_id = writer_buffer.allocate_buf(size)
                address_array.append((address, size, monotonic_id))

                # Write some test data
                with writer_buffer.access_buf(address) as (data_buf, _):
                    test_message = f"Test message {i}".encode()
                    data_buf[0 : len(test_message)] = test_message

            except MemoryError as e:
                print(f"  Failed to allocate {size} bytes: {e}")

        print("\nBuffer state:")
        print(f"  Data buffer start: {writer_buffer.data_buffer_start}")
        print(f"  Data buffer end: {writer_buffer.data_buffer_end}")
        print(f"  Monotonic ID start: {writer_buffer.monotonic_id_start}")
        print(f"  Monotonic ID end: {writer_buffer.monotonic_id_end}")
        print(f"  Metadata entries: {len(writer_buffer.metadata)}")

        # Try to read back the data
        print("\nReading back data...")
        for address, size, monotonic_id in address_array:
            with reader_buffer.access_buf(address) as (data_buf, _):
                data_bytes = bytes(data_buf[0:size])
                # Only the message prefix was written; stop at the first NUL
                # byte so the rest of the buffer is not printed (assumes the
                # untouched remainder is zero-filled).
                message = data_bytes.split(b"\x00", 1)[0].decode()
                print(f"  ID {monotonic_id}: '{message}'")

    except Exception as e:
        print(f"Demo error: {e}")
        traceback.print_exc()
    finally:
        # Close the second instance first, then the creating one.
        for buf in (reader_buffer, writer_buffer):
            if buf is not None:
                buf.close()

    print("\n=== Demo Complete ===")


# When executed as a script, run the unit tests followed by the manual demo.
if __name__ == "__main__":
    main()