"container/deps/vllm/vllm_v0.7.2-dynamo-kv-disagg-patch.patch" did not exist on "f38aa46949d8b016b5c4ae559daeb694152d0e40"
precision.py 383 Bytes
Newer Older
luopl's avatar
luopl committed
1
2
3
4
5
6
7
8
9
10
11
12
import functools
from contextlib import suppress

import torch


def get_autocast(precision):
    if precision == 'amp':
        return torch.cuda.amp.autocast
    elif precision == 'amp_bfloat16' or precision == 'amp_bf16':
        # amp_bfloat16 is more stable than amp float16 for clip training
        return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
    else:
        return suppress