convert_device.py 805 Bytes
Newer Older
zcxzcx1's avatar
zcxzcx1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from argparse import ArgumentParser

import torch


def main(argv=None):
    """Load a saved PyTorch object, move it to a target device, and re-save it.

    Command-line usage::

        convert_device.py [-t TARGET_DEVICE] [-o OUTPUT_FILE] model_file

    Args:
        argv: Optional list of command-line arguments (without the program
            name). When ``None`` (the default), ``sys.argv[1:]`` is parsed,
            which is what the script entry point relies on.

    The output path defaults to ``<model_file>.<target_device>`` when
    ``--output_file`` is not given.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--target_device",
        "-t",
        help="device to convert to, usually 'cpu' or 'cuda'",
        default="cpu",
    )
    parser.add_argument(
        "--output_file",
        "-o",
        help="name for output model, defaults to model_file.target_device",
    )
    parser.add_argument("model_file", help="input model file path")
    args = parser.parse_args(argv)

    output_file = args.output_file
    if output_file is None:
        output_file = f"{args.model_file}.{args.target_device}"

    # SECURITY NOTE: weights_only=False performs a full unpickle, which can
    # execute arbitrary code. Only run this tool on model files you trust.
    #
    # map_location remaps tensor storages while loading, so a checkpoint
    # saved on a device that does not exist on this machine (e.g. a CUDA
    # model on a CPU-only host) can still be loaded and converted.
    model = torch.load(
        args.model_file,
        map_location=args.target_device,
        weights_only=False,
    )

    # nn.Module.to() works in place and returns the module itself, but
    # Tensor.to() returns a *new* tensor, so the result must be kept.
    # Objects without a .to() method (e.g. a state_dict) have already been
    # moved by map_location above and are saved as loaded.
    mover = getattr(model, "to", None)
    if callable(mover):
        moved = mover(args.target_device)
        # Guard against custom in-place .to() implementations returning None.
        if moved is not None:
            model = moved

    torch.save(model, output_file)


# Run the converter only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()