import asyncio
import datetime
import unittest
from unittest.mock import Mock, patch

from botocore.exceptions import ClientError

# Import the classes we're testing
from olmocr.work_queue import S3WorkQueue, WorkItem


class TestS3WorkQueue(unittest.TestCase):
    def setUp(self):
        """Set up test fixtures before each test method."""
        self.s3_client = Mock()
        self.s3_client.exceptions.ClientError = ClientError
        self.work_queue = S3WorkQueue(self.s3_client, "s3://test-bucket/workspace")
        self.sample_paths = [
            "s3://test-bucket/data/file1.pdf",
            "s3://test-bucket/data/file2.pdf",
            "s3://test-bucket/data/file3.pdf",
        ]

    def tearDown(self):
        """Clean up after each test method."""
        pass

    def test_compute_workgroup_hash(self):
        """Test hash computation is deterministic and correct"""
        paths = [
            "s3://test-bucket/data/file2.pdf",
            "s3://test-bucket/data/file1.pdf",
        ]

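        # The suite asserts 40-character hashes elsewhere, which suggests SHA-1
        # over the sorted paths; a rough sketch of an implementation with these
        # properties (not necessarily olmocr's actual one):
        #
        #   digest = hashlib.sha1()
        #   for p in sorted(paths):
        #       digest.update(p.encode())
        #   return digest.hexdigest()
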
        # Hash should be the same regardless of order
        hash1 = S3WorkQueue._compute_workgroup_hash(paths)
        hash2 = S3WorkQueue._compute_workgroup_hash(reversed(paths))
        self.assertEqual(hash1, hash2)

    def test_init(self):
        """Test initialization of S3WorkQueue"""
        client = Mock()
        queue = S3WorkQueue(client, "s3://test-bucket/workspace/")

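        # The trailing slash should be stripped, with derived paths built from
        # the normalized workspace root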
        self.assertEqual(queue.workspace_path, "s3://test-bucket/workspace")
        self.assertEqual(queue._index_path, "s3://test-bucket/workspace/work_index_list.csv.zstd")
        self.assertEqual(queue._output_glob, "s3://test-bucket/workspace/results/*.jsonl")

    def async_test(f):
        """Decorator that runs an async test method in its own event loop."""

        def wrapper(*args, **kwargs):
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                return loop.run_until_complete(f(*args, **kwargs))
            finally:
                loop.close()

        return wrapper
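
    # NOTE: running each test in a fresh event loop keeps state from leaking
    # between coroutines; on Python 3.8+, unittest.IsolatedAsyncioTestCase
    # offers the same isolation out of the box.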

    @async_test
    async def test_populate_queue_new_items(self):
        """Test populating queue with new items"""
        # Mock empty existing index
        with patch("olmocr.work_queue.download_zstd_csv", return_value=[]):
            with patch("olmocr.work_queue.upload_zstd_csv") as mock_upload:
                await self.work_queue.populate_queue(self.sample_paths, items_per_group=2)

                # Verify upload was called with correct data
                self.assertEqual(mock_upload.call_count, 1)
                _, _, lines = mock_upload.call_args[0]

                # Should create 2 work groups (2 files + 1 file)
                self.assertEqual(len(lines), 2)

                # Verify format of uploaded lines
                for line in lines:
                    parts = line.split(",")
                    self.assertGreaterEqual(len(parts), 2)  # Hash + at least one path
                    self.assertEqual(len(parts[0]), 40)  # SHA1 hash length

    @async_test
    async def test_populate_queue_existing_items(self):
        """Test populating queue with mix of new and existing items"""
        existing_paths = ["s3://test-bucket/data/existing1.pdf"]
        new_paths = ["s3://test-bucket/data/new1.pdf"]

        # Create existing index content
        existing_hash = S3WorkQueue._compute_workgroup_hash(existing_paths)
        existing_line = f"{existing_hash},{existing_paths[0]}"

        with patch("olmocr.work_queue.download_zstd_csv", return_value=[existing_line]):
            with patch("olmocr.work_queue.upload_zstd_csv") as mock_upload:
                await self.work_queue.populate_queue(existing_paths + new_paths, items_per_group=1)

                # Verify upload called with both existing and new items
                _, _, lines = mock_upload.call_args[0]
                self.assertEqual(len(lines), 2)
                self.assertIn(existing_line, lines)

    @async_test
    async def test_initialize_queue(self):
        """Test queue initialization"""
        # Mock work items and completed items
        work_paths = ["s3://test/file1.pdf", "s3://test/file2.pdf"]
        work_hash = S3WorkQueue._compute_workgroup_hash(work_paths)
        work_line = f"{work_hash},{work_paths[0]},{work_paths[1]}"

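        # Completed groups are apparently recognized by output_<hash>.jsonl files
        # under results/ (matched by the glob mocked below)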
        completed_items = [f"s3://test-bucket/workspace/results/output_{work_hash}.jsonl"]

        with patch("olmocr.work_queue.download_zstd_csv", return_value=[work_line]):
            with patch("olmocr.work_queue.expand_s3_glob", return_value=completed_items):
                await self.work_queue.initialize_queue()

                # Queue should be empty since all work is completed
                self.assertTrue(self.work_queue._queue.empty())

    @async_test
    async def test_is_completed(self):
        """Test completed work check"""
        work_hash = "testhash123"

        # Test completed work
        self.s3_client.head_object.return_value = {"LastModified": datetime.datetime.now(datetime.timezone.utc)}
        self.assertTrue(await self.work_queue.is_completed(work_hash))

        # Test incomplete work
        self.s3_client.head_object.side_effect = ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject")
        self.assertFalse(await self.work_queue.is_completed(work_hash))

    @async_test
    async def test_get_work(self):
        """Test getting work items"""
        # Setup test data
        work_item = WorkItem(hash="testhash123", work_paths=["s3://test/file1.pdf"])
        await self.work_queue._queue.put(work_item)

        # Test getting available work
        self.s3_client.head_object.side_effect = ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject")
        result = await self.work_queue.get_work()
        self.assertEqual(result, work_item)

        # Verify a lock file was created; only the key suffix is checked, so the
        # lock object presumably reuses the output file's name under some prefix
        self.s3_client.put_object.assert_called_once()
        key = self.s3_client.put_object.call_args.kwargs["Key"]
        self.assertTrue(key.endswith(f"output_{work_item.hash}.jsonl"))

    @async_test
    async def test_get_work_completed(self):
        """Test getting work that's already completed"""
        work_item = WorkItem(hash="testhash123", work_paths=["s3://test/file1.pdf"])
        await self.work_queue._queue.put(work_item)

        # Simulate completed work
        self.s3_client.head_object.return_value = {"LastModified": datetime.datetime.now(datetime.timezone.utc)}

        result = await self.work_queue.get_work()
        self.assertIsNone(result)  # Should skip completed work

    @async_test
    async def test_get_work_locked(self):
        """Test getting work that's locked by another worker"""
        work_item = WorkItem(hash="testhash123", work_paths=["s3://test/file1.pdf"])
        await self.work_queue._queue.put(work_item)

        # Simulate active lock
        recent_time = datetime.datetime.now(datetime.timezone.utc)
        self.s3_client.head_object.side_effect = [
            ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject"),  # Not completed
            {"LastModified": recent_time},  # Active lock
        ]

        result = await self.work_queue.get_work()
        self.assertIsNone(result)  # Should skip locked work

    @async_test
    async def test_get_work_stale_lock(self):
        """Test getting work with a stale lock"""
        work_item = WorkItem(hash="testhash123", work_paths=["s3://test/file1.pdf"])
        await self.work_queue._queue.put(work_item)

        # Simulate stale lock
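        # (an hour old; the queue's lock-freshness window is presumably shorter,
        # since the test expects this item to be claimable)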
        stale_time = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=1)
        self.s3_client.head_object.side_effect = [
            ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject"),  # Not completed
            {"LastModified": stale_time},  # Stale lock
        ]

        result = await self.work_queue.get_work()
        self.assertEqual(result, work_item)  # Should take work with stale lock

    @async_test
    async def test_mark_done(self):
        """Test marking work as done"""
        work_item = WorkItem(hash="testhash123", work_paths=["s3://test/file1.pdf"])
        await self.work_queue._queue.put(work_item)

        await self.work_queue.mark_done(work_item)

        # Verify lock file was deleted
        self.s3_client.delete_object.assert_called_once()
        key = self.s3_client.delete_object.call_args.kwargs["Key"]
        self.assertTrue(key.endswith(f"output_{work_item.hash}.jsonl"))

    @async_test
    async def test_queue_size(self):
        """Test queue size property"""
        self.assertEqual(self.work_queue.size, 0)

        await self.work_queue._queue.put(WorkItem(hash="test1", work_paths=["path1"]))
        self.assertEqual(self.work_queue.size, 1)

        await self.work_queue._queue.put(WorkItem(hash="test2", work_paths=["path2"]))
        self.assertEqual(self.work_queue.size, 2)


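# Run directly (python test_s3_work_queue.py) or via the unittest CLI,
# e.g. python -m unittest -v test_s3_work_queue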
if __name__ == "__main__":
    unittest.main()