Commit 6d9e6c0a authored by PengGao, committed by GitHub
Browse files

Remove unused ProcessManager and TensorTransporter classes from utils… (#312)

Remove unused ProcessManager and TensorTransporter classes from
utils.py; eliminate request_queue initialization in
TorchrunInferenceWorker.
parent 52f3ffc0
...@@ -201,7 +201,6 @@ class TorchrunInferenceWorker: ...@@ -201,7 +201,6 @@ class TorchrunInferenceWorker:
self.world_size = int(os.environ.get("WORLD_SIZE", 1)) self.world_size = int(os.environ.get("WORLD_SIZE", 1))
self.runner = None self.runner = None
self.dist_manager = DistributedManager() self.dist_manager = DistributedManager()
self.request_queue = asyncio.Queue() if self.rank == 0 else None
self.processing = False # Track if currently processing a request self.processing = False # Track if currently processing a request
def init(self, args) -> bool: def init(self, args) -> bool:
......
import base64
import io
import signal
import sys
import psutil
import torch
from PIL import Image
from loguru import logger
from pydantic import BaseModel
class ProcessManager:
    """Static helpers that shut down the current process tree on Ctrl+C."""

    @staticmethod
    def kill_all_related_processes():
        """Kill every descendant of the current process, then the process itself.

        Each kill is attempted independently: a failure is logged at info
        level and the sweep carries on with the remaining processes.
        """
        this_process = psutil.Process()
        # Collect the whole descendant tree up front, before killing anything.
        descendants = this_process.children(recursive=True)
        for proc in descendants:
            try:
                proc.kill()
            except Exception as err:
                logger.info(f"Failed to kill child process {proc.pid}: {err}")

        # The current process goes last so the loop above can finish.
        try:
            this_process.kill()
        except Exception as err:
            logger.info(f"Failed to kill main process: {err}")

    @staticmethod
    def signal_handler(sig, frame):
        """Handle a signal by killing the process tree and exiting with status 0.

        Args:
            sig: Signal number passed in by the ``signal`` module (unused).
            frame: Interrupted stack frame passed in by the ``signal`` module
                (unused).
        """
        logger.info("\nReceived Ctrl+C, shutting down all related processes...")
        ProcessManager.kill_all_related_processes()
        # Fallback exit for the case where the self-kill above did not end
        # the process.
        sys.exit(0)

    @staticmethod
    def register_signal_handler():
        """Install ``signal_handler`` as the handler for SIGINT."""
        handler = ProcessManager.signal_handler
        signal.signal(signal.SIGINT, handler)
# Pydantic model with a single required string field, ``task_id``.
# NOTE(review): presumably the payload used to refer to a task when asking
# about or reporting its status -- no caller is visible here, confirm.
class TaskStatusMessage(BaseModel):
    # Identifier of the task; its format is not constrained beyond being a str.
    task_id: str
class TensorTransporter:
    """Serialize tensors (or dicts/lists of tensors) to base64 text and back."""

    def __init__(self):
        # Scratch buffer reused by every prepare_tensor() call.
        self.buffer = io.BytesIO()

    def to_device(self, data, device):
        """Return ``data`` with every tensor inside it moved to ``device``.

        Dicts and lists are rebuilt recursively as new dicts/lists, tensors
        are moved with ``Tensor.to``, and any other value (tuples included)
        is returned unchanged.

        Args:
            data: A tensor, a dict/list nesting tensors, or any other object.
            device: Target device, passed straight to ``Tensor.to``.
        """
        if isinstance(data, dict):
            moved = {}
            for name, entry in data.items():
                moved[name] = self.to_device(entry, device)
            return moved
        if isinstance(data, list):
            return [self.to_device(element, device) for element in data]
        if isinstance(data, torch.Tensor):
            return data.to(device)
        return data

    def prepare_tensor(self, data) -> str:
        """Serialize ``data`` with ``torch.save`` and return it as base64 text.

        Tensors are moved to the CPU before being saved.
        """
        scratch = self.buffer
        # Rewind and empty the shared buffer so earlier payloads do not leak
        # into this one.
        scratch.seek(0)
        scratch.truncate()
        cpu_copy = self.to_device(data, "cpu")
        torch.save(cpu_copy, scratch)
        serialized = scratch.getvalue()
        return base64.b64encode(serialized).decode("utf-8")

    def load_tensor(self, tensor_base64: str, device="cuda"):
        """Decode a payload made by ``prepare_tensor`` and move it to ``device``.

        Args:
            tensor_base64: Base64 text wrapping ``torch.save`` output.
            device: Device the restored tensors are moved to.
        """
        raw = base64.b64decode(tensor_base64)
        with io.BytesIO(raw) as stream:
            # NOTE(review): torch.load is built on pickle; only feed it
            # payloads that come from a trusted producer.
            restored = torch.load(stream)
            return self.to_device(restored, device)
class ImageTransporter:
    """Convert PIL images to base64-encoded PNG text and back."""

    def __init__(self):
        # Scratch buffer reused by every prepare_image() call.
        self.buffer = io.BytesIO()

    def prepare_image(self, image: Image.Image):
        """Encode ``image`` as PNG and return the bytes as a base64 string."""
        scratch = self.buffer
        # Rewind and empty the shared buffer so earlier images do not leak
        # into this one.
        scratch.seek(0)
        scratch.truncate()
        image.save(scratch, format="PNG")
        png_bytes = scratch.getvalue()
        return base64.b64encode(png_bytes).decode("utf-8")

    def load_image(self, image_base64: bytes) -> Image.Image:
        """Decode base64 image data and return the image converted to RGB.

        The argument goes straight to ``base64.b64decode``, so the string
        returned by ``prepare_image`` is accepted as well as bytes.
        """
        raw = base64.b64decode(image_base64)
        with io.BytesIO(raw) as stream:
            # The conversion runs while the stream is still open.
            opened = Image.open(stream)
            return opened.convert("RGB")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment