globals.py 185 Bytes
Newer Older
1
import torch
Nicolas Patry's avatar
Nicolas Patry committed
2
import os
3
4

# Opaque token identifying a CUDA graph memory pool (see
# torch.cuda.graph_pool_handle). It is only created when CUDA is available.
# On CPU-only hosts or torch builds without CUDA it is None, so importing
# this module no longer fails there. Callers must treat None as "no CUDA
# graph pool".
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
Nicolas Patry's avatar
Nicolas Patry committed
5
6
# CUDA graph switch, seeded from the ENABLE_CUDA_GRAPHS environment variable.
# Only "1" or "true" (in any letter case) turn it on. An unset variable is
# treated as "false". The CLI overrides this module attribute.
ENABLE_CUDA_GRAPHS = os.environ.get("ENABLE_CUDA_GRAPHS", "false").lower() in ("1", "true")