Unverified commit 14034d83 authored by PengGao, committed by GitHub

Gp/fix api parallel (#567)

parent aee1d657
 import os
 import pickle
+from datetime import timedelta
 from typing import Any, Optional
 import torch
@@ -13,6 +14,7 @@ class DistributedManager:
         self.rank = 0
         self.world_size = 1
         self.device = "cpu"
+        self.task_pg = None

     CHUNK_SIZE = 1024 * 1024
@@ -26,6 +28,10 @@ class DistributedManager:
         dist.init_process_group(backend=backend, init_method="env://")
         logger.info(f"Setup backend: {backend}")

+        task_timeout = timedelta(days=30)
+        self.task_pg = dist.new_group(backend="gloo", timeout=task_timeout)
+        logger.info("Created gloo process group for task distribution with 30-day timeout")
+
         if torch.cuda.is_available():
             torch.cuda.set_device(self.rank)
             self.device = f"cuda:{self.rank}"
@@ -51,6 +57,7 @@ class DistributedManager:
             logger.error(f"Rank {self.rank} error occurred while cleaning up distributed environment: {str(e)}")
         finally:
             self.is_initialized = False
+            self.task_pg = None

     def barrier(self):
         if self.is_initialized:
@@ -62,7 +69,7 @@ class DistributedManager:
     def is_rank_zero(self) -> bool:
         return self.rank == 0

-    def _broadcast_byte_chunks(self, data_bytes: bytes, device: torch.device) -> None:
+    def _broadcast_byte_chunks(self, data_bytes: bytes) -> None:
         total_length = len(data_bytes)
         num_full_chunks = total_length // self.CHUNK_SIZE
         remaining = total_length % self.CHUNK_SIZE
@@ -71,15 +78,15 @@ class DistributedManager:
             start_idx = i * self.CHUNK_SIZE
             end_idx = start_idx + self.CHUNK_SIZE
             chunk = data_bytes[start_idx:end_idx]
-            task_tensor = torch.tensor(list(chunk), dtype=torch.uint8).to(device)
-            dist.broadcast(task_tensor, src=0)
+            task_tensor = torch.tensor(list(chunk), dtype=torch.uint8)
+            dist.broadcast(task_tensor, src=0, group=self.task_pg)

         if remaining:
             chunk = data_bytes[-remaining:]
-            task_tensor = torch.tensor(list(chunk), dtype=torch.uint8).to(device)
-            dist.broadcast(task_tensor, src=0)
+            task_tensor = torch.tensor(list(chunk), dtype=torch.uint8)
+            dist.broadcast(task_tensor, src=0, group=self.task_pg)

-    def _receive_byte_chunks(self, total_length: int, device: torch.device) -> bytes:
+    def _receive_byte_chunks(self, total_length: int) -> bytes:
         if total_length <= 0:
             return b""
@@ -88,9 +95,9 @@ class DistributedManager:
         while remaining > 0:
             chunk_length = min(self.CHUNK_SIZE, remaining)
-            task_tensor = torch.empty(chunk_length, dtype=torch.uint8).to(device)
-            dist.broadcast(task_tensor, src=0)
-            received.extend(task_tensor.cpu().numpy())
+            task_tensor = torch.empty(chunk_length, dtype=torch.uint8)
+            dist.broadcast(task_tensor, src=0, group=self.task_pg)
+            received.extend(task_tensor.numpy())
             remaining -= chunk_length

         return bytes(received)
@@ -99,46 +106,36 @@ class DistributedManager:
         if not self.is_initialized:
             return None

-        try:
-            backend = dist.get_backend() if dist.is_initialized() else "gloo"
-        except Exception:
-            backend = "gloo"
-
-        if backend == "gloo":
-            broadcast_device = torch.device("cpu")
-        else:
-            broadcast_device = torch.device(self.device if self.device != "cpu" else "cpu")
-
         if self.is_rank_zero():
             if task_data is None:
-                stop_signal = torch.tensor([1], dtype=torch.int32).to(broadcast_device)
+                stop_signal = torch.tensor([1], dtype=torch.int32)
             else:
-                stop_signal = torch.tensor([0], dtype=torch.int32).to(broadcast_device)
-            dist.broadcast(stop_signal, src=0)
+                stop_signal = torch.tensor([0], dtype=torch.int32)
+            dist.broadcast(stop_signal, src=0, group=self.task_pg)

             if task_data is not None:
                 task_bytes = pickle.dumps(task_data)
-                task_length = torch.tensor([len(task_bytes)], dtype=torch.int32).to(broadcast_device)
-                dist.broadcast(task_length, src=0)
-                self._broadcast_byte_chunks(task_bytes, broadcast_device)
+                task_length = torch.tensor([len(task_bytes)], dtype=torch.int32)
+                dist.broadcast(task_length, src=0, group=self.task_pg)
+                self._broadcast_byte_chunks(task_bytes)
                 return task_data
             else:
                 return None
         else:
-            stop_signal = torch.tensor([0], dtype=torch.int32).to(broadcast_device)
-            dist.broadcast(stop_signal, src=0)
+            stop_signal = torch.tensor([0], dtype=torch.int32)
+            dist.broadcast(stop_signal, src=0, group=self.task_pg)

             if stop_signal.item() == 1:
                 return None
             else:
-                task_length = torch.tensor([0], dtype=torch.int32).to(broadcast_device)
-                dist.broadcast(task_length, src=0)
+                task_length = torch.tensor([0], dtype=torch.int32)
+                dist.broadcast(task_length, src=0, group=self.task_pg)
                 total_length = int(task_length.item())
-                task_bytes = self._receive_byte_chunks(total_length, broadcast_device)
+                task_bytes = self._receive_byte_chunks(total_length)
                 task_data = pickle.loads(task_bytes)
                 return task_data
@@ -50,6 +50,9 @@ class TorchrunInferenceWorker:
             return False

     async def process_request(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
+        has_error = False
+        error_msg = ""
+
         try:
             if self.world_size > 1 and self.rank == 0:
                 task_data = self.dist_manager.broadcast_task_data(task_data)
@@ -71,36 +74,35 @@ class TorchrunInferenceWorker:
             await asyncio.sleep(0)

-            if self.world_size > 1:
-                self.dist_manager.barrier()
-
-            if self.rank == 0:
-                return {
-                    "task_id": task_data["task_id"],
-                    "status": "success",
-                    "save_result_path": task_data.get("video_path", task_data["save_result_path"]),
-                    "message": "Inference completed",
-                }
-            else:
-                return None
         except Exception as e:
-            logger.exception(f"Rank {self.rank} inference failed: {str(e)}")
-            if self.world_size > 1:
-                self.dist_manager.barrier()
-            if self.rank == 0:
+            has_error = True
+            error_msg = str(e)
+            logger.exception(f"Rank {self.rank} inference failed: {error_msg}")
+
+        if self.world_size > 1:
+            self.dist_manager.barrier()
+
+        if self.rank == 0:
+            if has_error:
                 return {
                     "task_id": task_data.get("task_id", "unknown"),
                     "status": "failed",
-                    "error": str(e),
-                    "message": f"Inference failed: {str(e)}",
+                    "error": error_msg,
+                    "message": f"Inference failed: {error_msg}",
                 }
-            else:
-                return None
+            return {
+                "task_id": task_data["task_id"],
+                "status": "success",
+                "save_result_path": task_data.get("video_path", task_data["save_result_path"]),
+                "message": "Inference completed",
+            }
+        else:
+            return None

     async def worker_loop(self):
         while True:
+            task_data = None
             try:
                 task_data = self.dist_manager.broadcast_task_data()
                 if task_data is None:
@@ -111,6 +113,12 @@ class TorchrunInferenceWorker:
             except Exception as e:
                 logger.error(f"Rank {self.rank} worker loop error: {str(e)}")
+                if self.world_size > 1 and task_data is not None:
+                    try:
+                        self.dist_manager.barrier()
+                    except Exception as barrier_error:
+                        logger.error(f"Rank {self.rank} barrier failed after error: {barrier_error}")
+                        break
                 continue

     def cleanup(self):
......
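As a design note on the process_request rework: instead of building the response and calling the barrier separately inside the try block (success path) and the except block (failure path), the new has_error/error_msg flags defer both to a single section after the try/except, so success and failure share one barrier and one rank-0 return path. A condensed sketch of that control flow, with the inference call stubbed out as run_inference (a hypothetical stand-in; only the shape of the logic mirrors the diff):

```python
import torch.distributed as dist


def run_inference(task_data: dict) -> None:
    # Hypothetical stand-in for the real inference pipeline.
    if not task_data.get("prompt"):
        raise ValueError("empty prompt")


def process_request(task_data: dict, rank: int, world_size: int):
    has_error = False
    error_msg = ""
    try:
        run_inference(task_data)
    except Exception as e:  # record the failure, then fall through to the barrier
        has_error = True
        error_msg = str(e)

    # Single synchronization point reached by every rank, success or failure.
    if world_size > 1:
        dist.barrier()

    if rank != 0:
        return None
    if has_error:
        return {"task_id": task_data.get("task_id", "unknown"), "status": "failed", "error": error_msg}
    return {"task_id": task_data["task_id"], "status": "success"}
```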