Commit b5dd15c6 authored by space-nuko

System stats endpoint

parent 1bbd3f7f
...@@ -308,6 +308,33 @@ def pytorch_attention_flash_attention():
        return True
    return False
def get_total_memory(dev=None, torch_total_too=False):
    global xpu_available
    global directml_enabled
    if dev is None:
        dev = get_torch_device()

    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
        # CPU and MPS share system RAM
        mem_total = psutil.virtual_memory().total
        mem_total_torch = mem_total
    else:
        if directml_enabled:
            mem_total = 1024 * 1024 * 1024 #TODO
            mem_total_torch = mem_total
        elif xpu_available:
            mem_total = torch.xpu.get_device_properties(dev).total_memory
            mem_total_torch = mem_total
        else:
            stats = torch.cuda.memory_stats(dev)
            mem_reserved = stats['reserved_bytes.all.current']
            # mem_get_info already reports the full device capacity,
            # so the reserved bytes must not be added on top of it
            _, mem_total_cuda = torch.cuda.mem_get_info(dev)
            mem_total_torch = mem_reserved
            mem_total = mem_total_cuda

    if torch_total_too:
        return (mem_total, mem_total_torch)
    else:
        return mem_total
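A quick usage sketch for the new helper (a hypothetical standalone snippet, not part of the commit; it assumes the comfy package is importable and the signatures shown in this diff):

    import comfy.model_management as mm

    dev = mm.get_torch_device()
    total, torch_total = mm.get_total_memory(dev, torch_total_too=True)
    # Both values are in bytes; torch_total is what torch has reserved on CUDA,
    # or the same figure as total on CPU/MPS/XPU/DirectML.
    print(f"total: {total / 1024**3:.2f} GiB, torch reserved: {torch_total / 1024**3:.2f} GiB")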
def get_free_memory(dev=None, torch_free_too=False):
    global xpu_available
    global directml_enabled
...
...@@ -7,6 +7,7 @@ import execution
import uuid
import json
import glob
import torch
from PIL import Image
from io import BytesIO
...@@ -23,6 +24,7 @@ except ImportError:
import mimetypes
from comfy.cli_args import args
import comfy.utils
import comfy.model_management
@web.middleware
async def cache_control(request: web.Request, handler):
...@@ -280,6 +282,28 @@ class PromptServer():
        return web.Response(status=404)
    return web.json_response(dt["__metadata__"])
@routes.get("/system_stats")
async def get_queue(request):
device_index = comfy.model_management.get_torch_device()
device = torch.device(device_index)
device_name = comfy.model_management.get_torch_device_name(device_index)
vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
system_stats = {
"devices": [
{
"name": device_name,
"type": device.type,
"index": device.index,
"vram_total": vram_total,
"vram_free": vram_free,
"torch_vram_total": torch_vram_total,
"torch_vram_free": torch_vram_free,
}
]
}
return web.json_response(system_stats)
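A minimal client sketch for the new endpoint (hypothetical, not part of the commit; it assumes a ComfyUI server listening on the default 127.0.0.1:8188):

    import json
    import urllib.request

    with urllib.request.urlopen("http://127.0.0.1:8188/system_stats") as resp:
        stats = json.load(resp)

    # Each entry mirrors the dict built above; all memory figures are in bytes.
    for dev in stats["devices"]:
        print(dev["name"], dev["vram_free"], "/", dev["vram_total"], "bytes free")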
@routes.get("/prompt") @routes.get("/prompt")
async def get_prompt(request): async def get_prompt(request):
return web.json_response(self.get_queue_info()) return web.json_response(self.get_queue_info())
......