import os
from typing import Dict, Optional

import torch
from loguru import logger

MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
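# A minimal sketch (illustrative, not part of this module) of how a shared
# pool handle is used: every capture that passes the same handle allocates
# from one memory pool, so separate graphs can share storage.
#
#     g = torch.cuda.CUDAGraph()
#     with torch.cuda.graph(g, pool=MEM_POOL):
#         static_output = model(static_input)
#     g.replay()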

# This is overridden by the CLI.
FLASH_DECODING = os.getenv("FLASH_DECODING") in {"1", "true", "True"}
# KV-cache block size: flash decoding expects 256-token blocks, while the
# default paged-attention path uses 16.
BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
if FLASH_DECODING:
    logger.info("Using FLASH_DECODING")
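# Example (hypothetical): starting the server with FLASH_DECODING=1 in the
# environment enables the flash-decoding path and raises BLOCK_SIZE from
# 16 to 256.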


cuda_graphs = os.getenv("CUDA_GRAPHS")
if cuda_graphs is not None:
    try:
        cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
    except Exception as e:
        raise RuntimeError(
            f"Could not parse CUDA_GRAPHS={cuda_graphs}, expected a comma-separated list of batch sizes: {e}"
        ) from e

# Sorting the cuda graph batch sizes in descending order helps reduce
# memory usage: the largest graph is captured first, so subsequent smaller
# captures can typically reuse the shared pool rather than growing it.
if cuda_graphs is not None:
    cuda_graphs.sort(reverse=True)

CUDA_GRAPHS = cuda_graphs
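# Example (hypothetical): CUDA_GRAPHS="1,2,4,8" in the environment yields
# CUDA_GRAPHS == [8, 4, 2, 1] here; when unset it stays None (the CLI may
# still supply a default).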

# This is overridden at model loading.
MODEL_ID: Optional[str] = None


def set_model_id(model_id: str):
    global MODEL_ID
    MODEL_ID = model_id


# NOTE: eventually we should move this into the router and pass back the
# index in all cases.
ADAPTER_TO_INDEX: Optional[Dict[str, int]] = None


def set_adapter_to_index(adapter_to_index: Dict[str, int]):
    global ADAPTER_TO_INDEX
    ADAPTER_TO_INDEX = adapter_to_index


def get_adapter_to_index() -> Optional[Dict[str, int]]:
    return ADAPTER_TO_INDEX
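

# A hypothetical usage sketch (identifiers and values are illustrative):
#
#     set_model_id("my-org/my-model")
#     set_adapter_to_index({"my-adapter": 0})
#     index = get_adapter_to_index()["my-adapter"]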