to_static.py 1.8 KB
Newer Older
Sugon_ldc's avatar
Sugon_ldc committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
import sys

sys.path.append(
    os.path.join(os.path.dirname(os.path.realpath(__file__)), "./.."))

import yaml
import paddle
import paddle.nn as nn
from paddleseg.utils import logger
from inference.models import VNetModel


def main():
    model = VNetModel(
        elu=False,
        in_channels=1,
        num_classes=1,
        pretrained="/ssd2/tangshiyu/Code/EISeg-3D/experiments/3D_interseg/mrispineseg_vnet/174/checkpoints/049.pdparams",  # "pretrained_models/vnet_model.pdparams",
        kernel_size=[[2, 2, 4], [2, 2, 2], [2, 2, 2], [2, 2, 2]],
        stride_size=[[2, 2, 1], [2, 2, 1], [2, 2, 2], [2, 2, 2]],
        with_aux_output=False,
        use_leaky_relu=True,
        use_rgb_conv=False,
        use_disks=True,
        norm_radius=2,
        with_prev_mask=True, )

    model.set_dict(paddle.load("model_checkpoints/039.pdparams"))
    # model = nn.SyncBatchNorm.convert_sync_batchnorm(model) # on or off did not change
    model.eval()
    print("Loaded trained params of model successfully")

    new_net = paddle.jit.to_static(
        model,
        input_spec=[
            paddle.static.InputSpec(
                shape=[None, 1, None, None, None], dtype="float32"),
            paddle.static.InputSpec(
                shape=[None, 3, None, None, None],
                dtype="float32"),  # 16, 48, 4
        ], )

    paddle.jit.save(new_net, "output_cpu/static_Vnet_model")

    yml_file = os.path.join("output_cpu/static_VNet_model", "vnet_deploy.yaml")
    with open(yml_file, "w") as file:
        data = {
            "Deploy": {
                "model": "static_Vnet_model.pdmodel",
                "params": "static_Vnet_model.pdiparams"
            }
        }
        yaml.dump(data, file)

    logger.info("Model is saved in {}".format("output_cpu"))


if __name__ == "__main__":
    main()