"src/lib/vscode:/vscode.git/clone" did not exist on "9747f1e8415798dbb7be6ed5c215d32664db28cc"
utils_fit.py 2.56 KB
Newer Older
zhenyi's avatar
zhenyi committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import tensorflow as tf
from tqdm import tqdm


def get_train_step_fn():
    """Build and return a newly compiled single-batch training step.

    Every call wraps a fresh Python function in ``tf.function``, so each
    returned callable owns its own trace cache.

    Returns:
        A callable ``train_step(images, multiloss, targets, net, optimizer)``.
        It does the following:
          * runs ``net`` on ``images`` in training mode;
          * scores the prediction with ``multiloss(targets, prediction)``;
          * back-propagates that loss into ``net.trainable_variables``;
          * applies the gradients with ``optimizer``;
          * returns the loss tensor.
    """
    def train_step(images, multiloss, targets, net, optimizer):
        with tf.GradientTape() as tape:
            loss_value = multiloss(targets, net(images, training=True))
        variables = net.trainable_variables
        gradients = tape.gradient(loss_value, variables)
        optimizer.apply_gradients(zip(gradients, variables))
        return loss_value

    return tf.function(train_step)

@tf.function
def val_step(images, multiloss, targets, net):
    """Compute the loss on one batch without updating any weights.

    Args:
        images: model input batch, passed straight to ``net``.
        multiloss: loss callable invoked as ``multiloss(targets, prediction)``.
        targets: ground-truth tensor for the batch.
        net: the model being evaluated.

    Returns:
        Whatever ``multiloss`` returns for this batch (the loss tensor).
    """
    # Pin inference mode explicitly, mirroring ``training=True`` in the train
    # step. Layers whose behaviour depends on ``training`` then never fall back
    # to an implicitly resolved mode.
    prediction = net(images, training=False)
    return multiloss(targets, prediction)

def fit_one_epoch(net, multiloss, loss_history, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch):
    """Train ``net`` for one epoch, evaluate it, then log and checkpoint.

    Args:
        net: model being trained. Must support being called on a batch,
            ``trainable_variables`` and ``save_weights``.
        multiloss: loss callable invoked as ``multiloss(targets, prediction)``.
        loss_history: object whose ``on_epoch_end([], logs)`` is called with
            ``logs = {'loss': mean_train_loss, 'val_loss': mean_val_loss}``.
        optimizer: applies the training gradients. Its
            ``_decayed_lr(tf.float32)`` is shown in the training progress bar.
        epoch: zero-based index of the current epoch.
        epoch_step: maximum number of training batches taken from ``gen``.
        epoch_step_val: maximum number of validation batches taken from
            ``gen_val``.
        gen: iterable of training batches. Element 0 of each batch is the
            model input and element 1 the targets.
        gen_val: iterable of validation batches, in the same layout as ``gen``.
        Epoch: total number of epochs, used only in progress/log messages.

    Side effects:
        Prints progress and updates ``net``'s weights through ``optimizer``.
        Writes ``logs/ep<epoch>-loss<train>-val_loss<val>.h5`` via
        ``net.save_weights``.
    """
    # NOTE(review): presumably a new tf.function is built every epoch on
    # purpose. The traced step reads ``net.trainable_variables`` and triggers
    # optimizer slot-variable creation. Reusing one trace after layers are
    # (un)frozen, or after the optimizer is replaced, could therefore update
    # a stale variable set or fail on variable creation.
    train_step = get_train_step_fn()
    desc = f'Epoch {epoch + 1}/{Epoch}'

    # Losses are summed as Python floats and divided by the number of batches
    # actually processed. The result is a true mean that matches the progress
    # bar, and an empty generator yields 0.0 instead of failing on
    # ``int.numpy()``.
    train_total, train_batches = 0.0, 0
    print('Start Train')
    with tqdm(total=epoch_step, desc=desc, postfix=dict, mininterval=0.3) as pbar:
        for iteration, batch in enumerate(gen):
            if iteration >= epoch_step:
                break
            images, targets = batch[0], batch[1]
            targets = tf.convert_to_tensor(targets)

            loss_value = train_step(images, multiloss, targets, net, optimizer)
            train_total += float(loss_value)
            train_batches += 1

            # NOTE(review): ``_decayed_lr`` is a private optimizer API and may
            # be missing on newer Keras optimizers. Confirm against the
            # installed TF version.
            pbar.set_postfix(**{'loss': train_total / train_batches,
                                'lr': optimizer._decayed_lr(tf.float32).numpy()})
            pbar.update(1)
    print('Finish Train')

    val_total, val_batches = 0.0, 0
    print('Start Validation')
    with tqdm(total=epoch_step_val, desc=desc, postfix=dict, mininterval=0.3) as pbar:
        for iteration, batch in enumerate(gen_val):
            if iteration >= epoch_step_val:
                break
            images, targets = batch[0], batch[1]
            targets = tf.convert_to_tensor(targets)

            loss_value = val_step(images, multiloss, targets, net)
            val_total += float(loss_value)
            val_batches += 1

            pbar.set_postfix(**{'loss': val_total / val_batches})
            pbar.update(1)
    print('Finish Validation')

    # max(..., 1) guards against an empty generator (sum is 0.0 in that case).
    train_loss = train_total / max(train_batches, 1)
    val_loss = val_total / max(val_batches, 1)

    logs = {'loss': train_loss, 'val_loss': val_loss}
    loss_history.on_epoch_end([], logs)
    print('Epoch:' + str(epoch + 1) + '/' + str(Epoch))
    print('Total Loss: %.3f || Val Loss: %.3f ' % (train_loss, val_loss))
    net.save_weights('logs/ep%03d-loss%.3f-val_loss%.3f.h5' % ((epoch + 1), train_loss, val_loss))