"tests/test_legacy/test_engine/test_engine.py" did not exist on "62f4e2eb0760ac8bfe28834b061dbc2bda93ade9"
executor.py 418 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch


def run_fwd_bwd(model, data, label, criterion, enable_autocast=False, use_init_ctx=False):
    """Run a single forward pass followed by a backward pass.

    The forward pass (including the ``.float()`` conversion of the loss) is
    executed inside ``torch.cuda.amp.autocast``; the backward pass runs
    after that context has been exited.

    Args:
        model: Callable model. Invoked as ``model(data)`` when ``criterion``
            is truthy, otherwise as ``model(data, label)``.
        data: Input forwarded to ``model``.
        label: Target handed to ``criterion`` or, when no criterion is
            given, directly to ``model``.
        criterion: Loss callable invoked as ``criterion(output, label)``.
            When falsy (e.g. ``None``), the value returned by
            ``model(data, label)`` is used as the loss.
        enable_autocast: Passed as ``enabled`` to ``torch.cuda.amp.autocast``.
        use_init_ctx: When true, the backward pass is delegated to
            ``model.backward(loss)`` instead of calling ``loss.backward()``.

    Returns:
        None. Gradients are produced as a side effect of the backward call.
    """
    autocast_ctx = torch.cuda.amp.autocast(enabled=enable_autocast)
    with autocast_ctx:
        if not criterion:
            # No separate loss function: the model's own return value is the loss.
            raw_loss = model(data, label)
        else:
            prediction = model(data)
            raw_loss = criterion(prediction, label)
        final_loss = raw_loss.float()

    if not use_init_ctx:
        final_loss.backward()
        return

    # The model supplies its own backward entry point in this mode.
    model.backward(final_loss)