"vscode:/vscode.git/clone" did not exist on "d41a9f12c6300224aa5671c35fc6cd3f786d54b5"
executor.py 673 Bytes
Newer Older
1
2
3
import torch


HELSON's avatar
HELSON committed
4
def run_fwd_bwd(model, data, label, criterion, optimizer=None) -> torch.Tensor:
    """Run one forward and one backward pass and return the loss.

    Args:
        model (torch.nn.Module): a PyTorch model.
        data (torch.Tensor): input data passed to ``model``.
        label (torch.Tensor): target passed to ``criterion`` as
            ``criterion(output, label)``, or to ``model`` as
            ``model(data, label)`` when ``criterion`` is ``None``.
        criterion (Optional[Callable]): loss function. When ``None``, the
            model itself is called with ``(data, label)`` and its return
            value is used as the loss.
        optimizer (optional): if given, backward is run through
            ``optimizer.backward(loss)`` instead of ``loss.backward()``.
            NOTE(review): presumably an optimizer wrapper that owns the
            backward step (e.g. loss scaling); confirm against callers.

    Returns:
        torch.Tensor: the forward loss, cast to float32 via ``.float()``.
    """
    # Explicit ``is not None`` checks: an object defining ``__len__`` (e.g.
    # ``nn.Sequential``) can be falsy even though it was supplied.
    if criterion is not None:
        # Model output is upcast to float32 before the criterion sees it.
        output = model(data).float()
        loss = criterion(output, label)
    else:
        loss = model(data, label)

    # The returned loss is always float32, whichever branch produced it.
    loss = loss.float()

    if optimizer is not None:
        optimizer.backward(loss)
    else:
        loss.backward()
    return loss