globals.py 877 Bytes
Newer Older
1
import torch
Nicolas Patry's avatar
Nicolas Patry committed
2
import os
3
from loguru import logger
4

5
# Handle to a CUDA graph memory pool (from torch.cuda.graph_pool_handle());
# left as None when no CUDA device is available.
if torch.cuda.is_available():
    MEM_POOL = torch.cuda.graph_pool_handle()
else:
    MEM_POOL = None
Nicolas Patry's avatar
Nicolas Patry committed
6
# This is overridden by the cli
7
cuda_graphs = os.getenv("CUDA_GRAPHS")
Nicolas Patry's avatar
Nicolas Patry committed
8
if cuda_graphs is not None:
9
10
11
12
13
14
    try:
        cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
    except Exception as e:
        raise RuntimeError(
            f"Could not parse cuda graphs {cuda_graphs}, expected comma separated list for batch sizes to run on: {e}"
        )
15
16
17
else:
    cuda_graphs = None

18
19
20
21
22
23
24

# sorting the cuda graphs in descending order helps reduce the
# memory impact and results in less memory usage
if cuda_graphs is not None:
    cuda_graphs.sort(reverse=True)


25
CUDA_GRAPHS = cuda_graphs
fxmarty's avatar
fxmarty committed
26
27
28
29
30
31
32
33
34

# This is overridden at model loading.
# Holds the value last passed to set_model_id(); None until it is called.
MODEL_ID = None


def set_model_id(model_id: str):
    """Store *model_id* in the module-level ``MODEL_ID`` global.

    Args:
        model_id: Identifier of the model to record.
    """
    # `global` is needed here (inside the function) so the assignment
    # rebinds the module attribute instead of creating a local.
    global MODEL_ID
    MODEL_ID = model_id