task_manager.py 7.59 KB
Newer Older
gaclove's avatar
gaclove committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
import threading
import uuid
from collections import OrderedDict
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Dict, Optional

from loguru import logger


class TaskStatus(Enum):
    """Lifecycle states of a task; each member's value is its lowercase name."""

    # Registered by TaskManager.create_task but not started yet.
    PENDING = "pending"
    # Set by TaskManager.start_task.
    PROCESSING = "processing"
    # Set by TaskManager.complete_task.
    COMPLETED = "completed"
    # Set by TaskManager.fail_task.
    FAILED = "failed"
    # Set by TaskManager.cancel_task.
    CANCELLED = "cancelled"


@dataclass
class TaskInfo:
    """Bookkeeping record that TaskManager keeps for one task."""

    # Key under which the record is stored in the manager.
    task_id: str
    # Current lifecycle state.
    status: TaskStatus
    # Opaque request payload; TaskManager.create_task reads its optional
    # ``task_id`` and ``save_video_path`` attributes.
    message: Any
    # Defaults to the creation time; TaskManager.start_task overwrites it.
    start_time: datetime = field(default_factory=datetime.now)
    # Set when the task is completed, failed or cancelled.
    end_time: Optional[datetime] = None
    # Error text recorded by fail_task / cancel_task.
    error: Optional[str] = None
    # Output path; initialised from the message and optionally replaced by
    # complete_task.
    save_video_path: Optional[str] = None
    # Set by cancel_task to signal that the task should stop.
    stop_event: threading.Event = field(default_factory=threading.Event)
    # Worker thread, if one has been attached; cancel_task joins it.
    thread: Optional[threading.Thread] = None


class TaskManager:
    def __init__(self, max_queue_size: int = 100):
        self.max_queue_size = max_queue_size

        self._tasks: OrderedDict[str, TaskInfo] = OrderedDict()
        self._lock = threading.RLock()

        self._processing_lock = threading.Lock()
        self._current_processing_task: Optional[str] = None

        self.total_tasks = 0
        self.completed_tasks = 0
        self.failed_tasks = 0

    def create_task(self, message: Any) -> str:
        with self._lock:
            if hasattr(message, "task_id") and message.task_id in self._tasks:
                raise RuntimeError(f"Task ID {message.task_id} already exists")

            active_tasks = sum(1 for t in self._tasks.values() if t.status in [TaskStatus.PENDING, TaskStatus.PROCESSING])
            if active_tasks >= self.max_queue_size:
                raise RuntimeError(f"Task queue is full (max {self.max_queue_size} tasks)")

            task_id = getattr(message, "task_id", str(uuid.uuid4()))
            task_info = TaskInfo(task_id=task_id, status=TaskStatus.PENDING, message=message, save_video_path=getattr(message, "save_video_path", None))

            self._tasks[task_id] = task_info
            self.total_tasks += 1

            self._cleanup_old_tasks()

            return task_id

    def start_task(self, task_id: str) -> TaskInfo:
        with self._lock:
            if task_id not in self._tasks:
                raise KeyError(f"Task {task_id} not found")

            task = self._tasks[task_id]
            task.status = TaskStatus.PROCESSING
            task.start_time = datetime.now()

            self._tasks.move_to_end(task_id)

            return task

    def complete_task(self, task_id: str, save_video_path: Optional[str] = None):
        """Mark a task as completed and bump the completed counter.

        An unknown ``task_id`` only produces a warning log entry.

        Args:
            task_id: ID of the task that finished.
            save_video_path: When truthy, replaces the output path stored on
                the task record; otherwise the stored path is kept.
        """
        with self._lock:
            record = self._tasks.get(task_id)
            if record is None:
                logger.warning(f"Task {task_id} not found for completion")
                return

            record.status = TaskStatus.COMPLETED
            record.end_time = datetime.now()
            if save_video_path:
                record.save_video_path = save_video_path
            self.completed_tasks += 1

    def fail_task(self, task_id: str, error: str):
        """Mark a task as failed, store the error text and bump the failed counter.

        An unknown ``task_id`` only produces a warning log entry.

        Args:
            task_id: ID of the task that failed.
            error: Description of the failure, stored on the task record.
        """
        with self._lock:
            record = self._tasks.get(task_id)
            if record is None:
                logger.warning(f"Task {task_id} not found for failure")
                return

            record.status = TaskStatus.FAILED
            record.end_time = datetime.now()
            record.error = error
            self.failed_tasks += 1

    def cancel_task(self, task_id: str) -> bool:
        with self._lock:
            if task_id not in self._tasks:
                return False

            task = self._tasks[task_id]

            if task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED]:
                return False

            task.stop_event.set()
            task.status = TaskStatus.CANCELLED
            task.end_time = datetime.now()
            task.error = "Task cancelled by user"

            if task.thread and task.thread.is_alive():
                task.thread.join(timeout=5)

            return True

    def cancel_all_tasks(self):
        with self._lock:
            for task_id, task in list(self._tasks.items()):
                if task.status in [TaskStatus.PENDING, TaskStatus.PROCESSING]:
                    self.cancel_task(task_id)

    def get_task(self, task_id: str) -> Optional[TaskInfo]:
        with self._lock:
            return self._tasks.get(task_id)

    def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
        task = self.get_task(task_id)
        if not task:
            return None

        return {"task_id": task.task_id, "status": task.status.value, "start_time": task.start_time, "end_time": task.end_time, "error": task.error, "save_video_path": task.save_video_path}

    def get_all_tasks(self):
        with self._lock:
            return {task_id: self.get_task_status(task_id) for task_id in self._tasks}

    def get_active_task_count(self) -> int:
        with self._lock:
            return sum(1 for t in self._tasks.values() if t.status in [TaskStatus.PENDING, TaskStatus.PROCESSING])

    def get_pending_task_count(self) -> int:
        with self._lock:
            return sum(1 for t in self._tasks.values() if t.status == TaskStatus.PENDING)

    def is_processing(self) -> bool:
        with self._lock:
            return self._current_processing_task is not None

    def acquire_processing_lock(self, task_id: str, timeout: Optional[float] = None) -> bool:
        acquired = self._processing_lock.acquire(timeout=timeout if timeout else False)
        if acquired:
            with self._lock:
                self._current_processing_task = task_id
                logger.info(f"Task {task_id} acquired processing lock")
        return acquired

    def release_processing_lock(self, task_id: str):
        with self._lock:
            if self._current_processing_task == task_id:
                self._current_processing_task = None
                try:
                    self._processing_lock.release()
                    logger.info(f"Task {task_id} released processing lock")
                except RuntimeError as e:
                    logger.warning(f"Task {task_id} tried to release lock but failed: {e}")

    def get_next_pending_task(self) -> Optional[str]:
        with self._lock:
            for task_id, task in self._tasks.items():
                if task.status == TaskStatus.PENDING:
                    return task_id
        return None

    def get_service_status(self) -> Dict[str, Any]:
        with self._lock:
            active_tasks = [task_id for task_id, task in self._tasks.items() if task.status == TaskStatus.PROCESSING]

            pending_count = sum(1 for t in self._tasks.values() if t.status == TaskStatus.PENDING)

            return {
                "service_status": "busy" if self._current_processing_task else "idle",
                "current_task": self._current_processing_task,
                "active_tasks": active_tasks,
                "pending_tasks": pending_count,
                "queue_size": self.max_queue_size,
                "total_tasks": self.total_tasks,
                "completed_tasks": self.completed_tasks,
                "failed_tasks": self.failed_tasks,
            }

    def _cleanup_old_tasks(self, keep_count: int = 1000):
        if len(self._tasks) <= keep_count:
            return

        completed_tasks = [(task_id, task) for task_id, task in self._tasks.items() if task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED]]

        completed_tasks.sort(key=lambda x: x[1].end_time or x[1].start_time)

        remove_count = len(self._tasks) - keep_count
        for task_id, _ in completed_tasks[:remove_count]:
            del self._tasks[task_id]
            logger.debug(f"Cleaned up old task: {task_id}")


# Module-level TaskManager instance, created at import time with the default
# max_queue_size.
task_manager = TaskManager()