globals.py
import torch
import os

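# Shared memory pool handle for CUDA graph capture; None when CUDA is unavailable.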
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
# This is overridden by the cli
cuda_graphs = os.getenv("CUDA_GRAPHS")
if torch.cuda.is_available() and cuda_graphs is not None and cuda_graphs != "0":
    try:
        cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
    except Exception as e:
        raise RuntimeError(
            f"Could not parse cuda graphs {cuda_graphs}, expected comma separated list for batch sizes to run on: {e}"
        )
else:
    cuda_graphs = None

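# Parsed list of batch sizes to capture CUDA graphs for, or None when disabled.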
CUDA_GRAPHS = cuda_graphs
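# Illustrative behavior (assuming a CUDA device is available):
#   CUDA_GRAPHS="1,2,4"      -> CUDA_GRAPHS == [1, 2, 4]
#   CUDA_GRAPHS unset or "0" -> CUDA_GRAPHS == None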