"git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "ad1a57a2b9f9a53581e0358cf210698877f617e4"
Commit 08fbe9a9 authored by mashun1

vgg16-qat

parent 38e4792a
@@ -55,12 +55,12 @@ VGG uses multiple smaller convolutional filters, which during training reduces the network's…
 # --epochs sets the number of training or calibration epochs
 # --resume resumes training from a checkpoint
-# --calibrate enables calibration (this flag must not be used when training the base model)
+# --qat enables calibration (this flag must not be used when training the base model)
-CUDA_VISIBLE_DEVICES=0,1 torchrun --nnodes=1 --nproc_per_node=2 --rdzv_id=100 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 main.py --epochs=N --resume --calibrate --batch_size=N --lr=X --num_classes=10
+CUDA_VISIBLE_DEVICES=0,1 torchrun --nnodes=1 --nproc_per_node=2 --rdzv_id=100 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 main.py --epochs=N --resume --qat --batch_size=N --lr=X --num_classes=10
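For context on what the `--qat` path does: the main.py diff below wires up NVIDIA's pytorch-quantization toolkit, whose `quant_modules.initialize()` call patches the standard torch.nn layers with quantized equivalents before the model is built. A minimal sketch of that mechanism (not this repo's exact code; the torchvision `vgg16` here is an assumption, the repo may define its own):

```python
import torch
from pytorch_quantization import quant_modules

# Patch nn.Conv2d, nn.Linear, etc. with quantized counterparts; any model
# built after this call carries a TensorQuantizer on each patched layer.
quant_modules.initialize()

from torchvision.models import vgg16  # illustrative; the repo may ship its own vgg16

model = vgg16(num_classes=10)
# Fine-tuning this model now trains through fake-quant nodes (QAT): the
# quantizers clamp and round weights/activations in the forward pass while
# gradients flow through a straight-through estimator.
```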
 ## Inference
-trtexec --onnx=/path/to/onnx --saveEngine=/path/to/save --int8
+trtexec --onnx=/path/to/onnx --saveEngine=./checkpoints/qat/last.trt --int8
 python eval.py --device=0
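eval.py itself is not part of this commit; a rough sketch of running the saved INT8 engine from Python, assuming the batch shape used in the export below (16x3x32x32, 10 classes), might look like:

```python
import tensorrt as trt
import torch

logger = trt.Logger(trt.Logger.WARNING)
with open("./checkpoints/qat/last.trt", "rb") as f:
    engine = trt.Runtime(logger).deserialize_cuda_engine(f.read())
context = engine.create_execution_context()

images = torch.randn(16, 3, 32, 32, device="cuda")  # CIFAR-10-sized batch
logits = torch.empty(16, 10, device="cuda")         # num_classes=10
# execute_v2 takes one device pointer per binding, inputs first.
context.execute_v2([images.data_ptr(), logits.data_ptr()])
preds = logits.argmax(dim=1)
```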
@@ -70,7 +70,7 @@ VGG uses multiple smaller convolutional filters, which during training reduces the network's…
 ### Accuracy
-||Original model|QAT-calibrated model|ONNX model|TensorRT model|
+||Original model|QAT model|ONNX model|TensorRT model|
 |:---|:---|:---|:---|:---|
 |Acc|0.9189|0.9185|0.9181|0.9184|
 |Inference time|5.5764s|13.7603s|4.2848s|2.9893s|

(The QAT checkpoint is slower than the FP32 baseline in eager PyTorch because the fake-quantization nodes add per-layer overhead; the INT8 speedup only materializes after deployment through TensorRT.)
...
@@ -12,7 +12,7 @@ import torch.distributed as dist
 from tqdm import tqdm
 from utils.data import prepare_dataloader
-from utils.calibrate import *
+from utils.qat import *
 from torch.nn.parallel import DistributedDataParallel as DDP
 from pytorch_quantization import nn as quant_nn
@@ -26,10 +26,10 @@ def prepare_training_obj(lr: float = 1e-3,
                          num_classes=10,
                          ckpt_root: str = '',
                          resume: bool = True,
-                         calibrate: bool = True):
+                         qat: bool = True):
     model = vgg16(num_classes=num_classes)
-    if resume or calibrate:
+    if resume or qat:
         model.load_state_dict(torch.load(os.path.join(ckpt_root, "pretrained_model.pth"), map_location="cpu"))
     optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
     lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)
@@ -87,7 +87,7 @@ def pretrain(args):
     rank = dist.get_rank()
-    model, optimizer, lr_scheduler, loss_fc = prepare_training_obj(args.lr, ckpt_root="./checkpoints/pretrained", resume=args.resume, calibrate=args.calibrate)
+    model, optimizer, lr_scheduler, loss_fc = prepare_training_obj(args.lr, ckpt_root="./checkpoints/pretrained", resume=args.resume, qat=args.qat)
     device = torch.device(f"cuda:{rank}")
     model.to(device)
@@ -131,7 +131,7 @@ def pretrain(args):
     cleanup()

-def calibrate(args):
+def qat(args):
     dist.init_process_group('nccl')
     rank = dist.get_rank()
@@ -139,9 +139,9 @@ def calibrate(args):
     quant_modules.initialize()
     if args.resume:
-        model, optimizer, lr_scheduler, loss_fc = prepare_training_obj(args.lr, ckpt_root="./checkpoints/calibrated", resume=args.resume, calibrate=args.calibrate)
+        model, optimizer, lr_scheduler, loss_fc = prepare_training_obj(args.lr, ckpt_root="./checkpoints/qat", resume=args.resume, qat=args.qat)
     else:
-        model, optimizer, lr_scheduler, loss_fc = prepare_training_obj(args.lr, ckpt_root="./checkpoints/pretrained", resume=args.resume, calibrate=args.calibrate)
+        model, optimizer, lr_scheduler, loss_fc = prepare_training_obj(args.lr, ckpt_root="./checkpoints/pretrained", resume=args.resume, qat=args.qat)
     device = torch.device(f"cuda:{rank}")
     model.to(device)
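The calibration pass itself is pulled in through `from utils.qat import *`, so it is not visible in this diff. For reference, the pytorch-quantization toolkit's documented pattern for collecting activation statistics before fine-tuning looks roughly like this (the helper name `collect_stats` is illustrative, not the repo's):

```python
import torch
from pytorch_quantization import calib
from pytorch_quantization import nn as quant_nn

def collect_stats(model, data_loader, num_batches):
    """Feed a few batches through the model to record activation ranges."""
    # Switch every quantizer into calibration mode (record stats, no quantizing).
    for module in model.modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()
    with torch.no_grad():
        for i, (images, _) in enumerate(data_loader):
            model(images.cuda())
            if i + 1 >= num_batches:
                break
    # Compute amax from the recorded stats and re-enable quantization.
    for module in model.modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                if isinstance(module._calibrator, calib.MaxCalibrator):
                    module.load_calib_amax()
                else:
                    module.load_calib_amax("percentile", percentile=99.99)
                module.enable_quant()
                module.disable_calib()
            else:
                module.enable()
```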
@@ -178,7 +178,7 @@ def calibrate(args):
         if (epoch + 1) % 5 == 0:
             # save checkpoints and lr.
-            ckpt_path = "./checkpoints/calibrated"
+            ckpt_path = "./checkpoints/qat"
             if not os.path.exists(ckpt_path):
                 os.makedirs(ckpt_path)
@@ -192,16 +192,16 @@ def calibrate(args):
     model.eval()
     with torch.no_grad():
         jit_model = torch.jit.trace(model, torch.randn((16, 3, 32, 32)).to(device))
-        # torch.jit.save(jit_model, "./checkpoints/calibrated/pretrained_model.jit")
+        # torch.jit.save(jit_model, "./checkpoints/qat/pretrained_model.jit")
         jit_model.eval()
-        torch.onnx.export(jit_model.to(device), torch.randn((16, 3, 32, 32)).to(device), "checkpoints/calibrated/pretrained_qat.onnx")
+        torch.onnx.export(jit_model.to(device), torch.randn((16, 3, 32, 32)).to(device), "checkpoints/qat/pretrained_qat.onnx")
     cleanup()

 def main(args):
-    if args.calibrate:
-        calibrate(args)
+    if args.qat:
+        qat(args)
     else:
         pretrain(args)
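One caveat with the export step above: NVIDIA's toolkit expects the fake-quant nodes to be exported as ONNX Q/DQ ops, which its docs enable via a class attribute before calling torch.onnx.export. A sketch (`model` and `dummy_input` stand in for the objects built above):

```python
import torch
from pytorch_quantization import nn as quant_nn

# Make TensorQuantizer export as QuantizeLinear/DequantizeLinear ONNX ops.
quant_nn.TensorQuantizer.use_fb_fake_quant = True
torch.onnx.export(model, dummy_input, "checkpoints/qat/pretrained_qat.onnx",
                  opset_version=13)  # per-channel Q/DQ needs opset >= 13
```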
@@ -221,7 +221,7 @@ if __name__ == "__main__":
     parser.add_argument("--resume", action="store_true")
-    parser.add_argument("--calibrate", action="store_true")
+    parser.add_argument("--qat", action="store_true")
     args = parser.parse_args()
...