Unverified Commit 14034d83 authored by PengGao, committed by GitHub

Gp/fix api parallel (#567)

parent aee1d657
 import os
 import pickle
+from datetime import timedelta
 from typing import Any, Optional

 import torch

@@ -13,6 +14,7 @@ class DistributedManager:
         self.rank = 0
         self.world_size = 1
         self.device = "cpu"
+        self.task_pg = None

     CHUNK_SIZE = 1024 * 1024

@@ -26,6 +28,10 @@ class DistributedManager:
         dist.init_process_group(backend=backend, init_method="env://")
         logger.info(f"Setup backend: {backend}")

+        task_timeout = timedelta(days=30)
+        self.task_pg = dist.new_group(backend="gloo", timeout=task_timeout)
+        logger.info("Created gloo process group for task distribution with 30-day timeout")
+
         if torch.cuda.is_available():
             torch.cuda.set_device(self.rank)
             self.device = f"cuda:{self.rank}"
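Note: this hunk creates a second, gloo-backed process group alongside the default one (NCCL when CUDA is available), used only for task distribution, with a 30-day timeout, presumably so idle worker ranks waiting for the next API request do not hit the usual collective timeout. A standalone sketch of that setup, assuming the usual torchrun environment variables; the helper name init_groups and the script shape are illustrative only:

# Sketch only: default group for model collectives plus a dedicated gloo group
# for control/task messages. Launch with torchrun so RANK/WORLD_SIZE are set.
from datetime import timedelta

import torch
import torch.distributed as dist


def init_groups():
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend=backend, init_method="env://")

    # The task group only ever carries small CPU tensors; the very long timeout
    # keeps idle ranks from timing out while rank 0 waits for incoming requests.
    task_pg = dist.new_group(backend="gloo", timeout=timedelta(days=30))
    return task_pg
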
@@ -51,6 +57,7 @@ class DistributedManager:
             logger.error(f"Rank {self.rank} error occurred while cleaning up distributed environment: {str(e)}")
         finally:
             self.is_initialized = False
+            self.task_pg = None

     def barrier(self):
         if self.is_initialized:

@@ -62,7 +69,7 @@ class DistributedManager:
     def is_rank_zero(self) -> bool:
         return self.rank == 0

-    def _broadcast_byte_chunks(self, data_bytes: bytes, device: torch.device) -> None:
+    def _broadcast_byte_chunks(self, data_bytes: bytes) -> None:
         total_length = len(data_bytes)
         num_full_chunks = total_length // self.CHUNK_SIZE
         remaining = total_length % self.CHUNK_SIZE

@@ -71,15 +78,15 @@ class DistributedManager:
             start_idx = i * self.CHUNK_SIZE
             end_idx = start_idx + self.CHUNK_SIZE
             chunk = data_bytes[start_idx:end_idx]
-            task_tensor = torch.tensor(list(chunk), dtype=torch.uint8).to(device)
-            dist.broadcast(task_tensor, src=0)
+            task_tensor = torch.tensor(list(chunk), dtype=torch.uint8)
+            dist.broadcast(task_tensor, src=0, group=self.task_pg)

         if remaining:
             chunk = data_bytes[-remaining:]
-            task_tensor = torch.tensor(list(chunk), dtype=torch.uint8).to(device)
-            dist.broadcast(task_tensor, src=0)
+            task_tensor = torch.tensor(list(chunk), dtype=torch.uint8)
+            dist.broadcast(task_tensor, src=0, group=self.task_pg)

-    def _receive_byte_chunks(self, total_length: int, device: torch.device) -> bytes:
+    def _receive_byte_chunks(self, total_length: int) -> bytes:
         if total_length <= 0:
             return b""

@@ -88,9 +95,9 @@ class DistributedManager:
         while remaining > 0:
             chunk_length = min(self.CHUNK_SIZE, remaining)
-            task_tensor = torch.empty(chunk_length, dtype=torch.uint8).to(device)
-            dist.broadcast(task_tensor, src=0)
-            received.extend(task_tensor.cpu().numpy())
+            task_tensor = torch.empty(chunk_length, dtype=torch.uint8)
+            dist.broadcast(task_tensor, src=0, group=self.task_pg)
+            received.extend(task_tensor.numpy())
             remaining -= chunk_length

         return bytes(received)
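Note: both helpers now build plain CPU tensors and issue every broadcast on the task group instead of moving data to the compute device, which is what lets the gloo group carry arbitrarily large pickled payloads. A self-contained sketch of the same chunking scheme, with the 1 MiB chunk size from the diff; broadcast_bytes and receive_bytes are illustrative names, not the project's API:

import torch
import torch.distributed as dist

CHUNK_SIZE = 1024 * 1024  # 1 MiB per broadcast, as in the diff


def broadcast_bytes(data: bytes, group) -> None:
    # Rank 0: slice the payload and broadcast each slice as a CPU uint8 tensor.
    for start in range(0, len(data), CHUNK_SIZE):
        chunk = data[start:start + CHUNK_SIZE]
        tensor = torch.tensor(list(chunk), dtype=torch.uint8)
        dist.broadcast(tensor, src=0, group=group)


def receive_bytes(total_length: int, group) -> bytes:
    # Other ranks: allocate buffers of the matching sizes and reassemble the bytes.
    received = bytearray()
    remaining = total_length
    while remaining > 0:
        chunk_length = min(CHUNK_SIZE, remaining)
        tensor = torch.empty(chunk_length, dtype=torch.uint8)
        dist.broadcast(tensor, src=0, group=group)
        received.extend(tensor.numpy().tobytes())
        remaining -= chunk_length
    return bytes(received)
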
@@ -99,46 +106,36 @@ class DistributedManager:
         if not self.is_initialized:
             return None

-        try:
-            backend = dist.get_backend() if dist.is_initialized() else "gloo"
-        except Exception:
-            backend = "gloo"
-
-        if backend == "gloo":
-            broadcast_device = torch.device("cpu")
-        else:
-            broadcast_device = torch.device(self.device if self.device != "cpu" else "cpu")
-
         if self.is_rank_zero():
             if task_data is None:
-                stop_signal = torch.tensor([1], dtype=torch.int32).to(broadcast_device)
+                stop_signal = torch.tensor([1], dtype=torch.int32)
             else:
-                stop_signal = torch.tensor([0], dtype=torch.int32).to(broadcast_device)
+                stop_signal = torch.tensor([0], dtype=torch.int32)

-            dist.broadcast(stop_signal, src=0)
+            dist.broadcast(stop_signal, src=0, group=self.task_pg)

             if task_data is not None:
                 task_bytes = pickle.dumps(task_data)
-                task_length = torch.tensor([len(task_bytes)], dtype=torch.int32).to(broadcast_device)
-                dist.broadcast(task_length, src=0)
-                self._broadcast_byte_chunks(task_bytes, broadcast_device)
+                task_length = torch.tensor([len(task_bytes)], dtype=torch.int32)
+                dist.broadcast(task_length, src=0, group=self.task_pg)
+                self._broadcast_byte_chunks(task_bytes)
                 return task_data
             else:
                 return None
         else:
-            stop_signal = torch.tensor([0], dtype=torch.int32).to(broadcast_device)
-            dist.broadcast(stop_signal, src=0)
+            stop_signal = torch.tensor([0], dtype=torch.int32)
+            dist.broadcast(stop_signal, src=0, group=self.task_pg)

             if stop_signal.item() == 1:
                 return None
             else:
-                task_length = torch.tensor([0], dtype=torch.int32).to(broadcast_device)
-                dist.broadcast(task_length, src=0)
+                task_length = torch.tensor([0], dtype=torch.int32)
+                dist.broadcast(task_length, src=0, group=self.task_pg)
                 total_length = int(task_length.item())
-                task_bytes = self._receive_byte_chunks(total_length, broadcast_device)
+                task_bytes = self._receive_byte_chunks(total_length)
                 task_data = pickle.loads(task_bytes)
                 return task_data
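Note: with the per-call backend and device detection removed, broadcast_task_data reduces to a small control protocol on the gloo group: rank 0 broadcasts a stop flag, then, if there is work, the length of the pickled payload and the payload itself in chunks; every other rank issues the same sequence of broadcasts with empty tensors and unpickles the result. A sketch of the two sides of that handshake, reusing the broadcast_bytes/receive_bytes names from the previous sketch (send_task and recv_task are illustrative names):

import pickle

import torch
import torch.distributed as dist


def send_task(task_data, group):
    # Rank 0: stop flag -> payload length -> payload chunks.
    stop = 1 if task_data is None else 0
    dist.broadcast(torch.tensor([stop], dtype=torch.int32), src=0, group=group)
    if task_data is None:
        return None
    payload = pickle.dumps(task_data)
    # int32 length mirrors the diff, so payloads are assumed to stay well under 2 GiB.
    dist.broadcast(torch.tensor([len(payload)], dtype=torch.int32), src=0, group=group)
    broadcast_bytes(payload, group)
    return task_data


def recv_task(group):
    # Non-zero ranks: participate in the same broadcasts to receive the values.
    stop = torch.zeros(1, dtype=torch.int32)
    dist.broadcast(stop, src=0, group=group)
    if stop.item() == 1:
        return None
    length = torch.zeros(1, dtype=torch.int32)
    dist.broadcast(length, src=0, group=group)
    return pickle.loads(receive_bytes(int(length.item()), group))
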
@@ -50,6 +50,9 @@ class TorchrunInferenceWorker:
         return False

     async def process_request(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
+        has_error = False
+        error_msg = ""
+
         try:
             if self.world_size > 1 and self.rank == 0:
                 task_data = self.dist_manager.broadcast_task_data(task_data)
@@ -71,36 +74,35 @@ class TorchrunInferenceWorker:
                 await asyncio.sleep(0)

-            if self.world_size > 1:
-                self.dist_manager.barrier()
-
-            if self.rank == 0:
-                return {
-                    "task_id": task_data["task_id"],
-                    "status": "success",
-                    "save_result_path": task_data.get("video_path", task_data["save_result_path"]),
-                    "message": "Inference completed",
-                }
-            else:
-                return None
         except Exception as e:
-            logger.exception(f"Rank {self.rank} inference failed: {str(e)}")
-            if self.world_size > 1:
-                self.dist_manager.barrier()
-            if self.rank == 0:
-                return {
-                    "task_id": task_data.get("task_id", "unknown"),
-                    "status": "failed",
-                    "error": str(e),
-                    "message": f"Inference failed: {str(e)}",
-                }
-            else:
-                return None
+            has_error = True
+            error_msg = str(e)
+            logger.exception(f"Rank {self.rank} inference failed: {error_msg}")
+
+        if self.world_size > 1:
+            self.dist_manager.barrier()
+
+        if self.rank == 0:
+            if has_error:
+                return {
+                    "task_id": task_data.get("task_id", "unknown"),
+                    "status": "failed",
+                    "error": error_msg,
+                    "message": f"Inference failed: {error_msg}",
+                }
+            else:
+                return {
+                    "task_id": task_data["task_id"],
+                    "status": "success",
+                    "save_result_path": task_data.get("video_path", task_data["save_result_path"]),
+                    "message": "Inference completed",
+                }
+        else:
+            return None
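Note: the restructured process_request only records a failure inside the except block and decides the response afterwards, so every rank, successful or not, reaches the same barrier exactly once before rank 0 composes the API reply, and the success and failure responses are built in one place. The shape of that pattern, reduced to its essentials; worker and run_inference are stand-ins for the real instance and its inference step, not names from the project:

# Sketch of the "record the error, synchronize, then answer" flow.
async def process_request(worker, task_data):
    has_error, error_msg = False, ""
    try:
        await worker.run_inference(task_data)   # placeholder for the per-rank work
    except Exception as exc:
        has_error, error_msg = True, str(exc)

    if worker.world_size > 1:
        worker.dist_manager.barrier()           # reached on every rank, success or failure

    if worker.rank != 0:
        return None                             # only rank 0 answers the API caller
    if has_error:
        return {"task_id": task_data.get("task_id", "unknown"),
                "status": "failed", "error": error_msg}
    return {"task_id": task_data["task_id"], "status": "success"}
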
     async def worker_loop(self):
         while True:
+            task_data = None
             try:
                 task_data = self.dist_manager.broadcast_task_data()

                 if task_data is None:

@@ -111,6 +113,12 @@ class TorchrunInferenceWorker:
             except Exception as e:
                 logger.error(f"Rank {self.rank} worker loop error: {str(e)}")
+                if self.world_size > 1 and task_data is not None:
+                    try:
+                        self.dist_manager.barrier()
+                    except Exception as barrier_error:
+                        logger.error(f"Rank {self.rank} barrier failed after error: {barrier_error}")
+                        break
                 continue
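Note: on the non-zero ranks, task_data is reset to None at the top of each iteration, so the added error handling only attempts the shared barrier when a task was actually received, and it exits the loop if the barrier itself fails rather than continuing with the ranks out of step. One way to read the added block is as a small helper; resync_after_error is a hypothetical name, the diff simply inlines the logic:

import logging

logger = logging.getLogger(__name__)


def resync_after_error(dist_manager, rank, world_size, task_data) -> bool:
    # Returns True if the worker loop may continue, False if it should stop.
    if world_size <= 1 or task_data is None:
        # No task was received, so there is no matching barrier to join.
        return True
    try:
        dist_manager.barrier()          # line up with the ranks that finished the task
        return True
    except Exception as barrier_error:
        logger.error(f"Rank {rank} barrier failed after error: {barrier_error}")
        return False                    # group is out of step; leave the worker loop
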
     def cleanup(self):
...