"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "15e553b15aa42775ad3058fb8fa81275c6f9ee1d"
Unverified Commit f9a2b26a authored by littletomatodonkey's avatar littletomatodonkey Committed by GitHub
Browse files

fix quant logic (#5806)

* fix quant logic

* fix undef

* fix doc
parent 3d692957
......@@ -118,6 +118,11 @@ def main(config, device, logger, vdl_writer):
config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture'])
pre_best_model_dict = dict()
# load fp32 model to begin quantization
if config["Global"]["pretrained_model"] is not None:
pre_best_model_dict = load_model(config, model)
quanter = QAT(config=quant_config, act_preprocess=PACT)
quanter.quantize(model)
......@@ -134,10 +139,12 @@ def main(config, device, logger, vdl_writer):
step_each_epoch=len(train_dataloader),
parameters=model.parameters())
# resume PACT training process
if config["Global"]["checkpoints"] is not None:
pre_best_model_dict = load_model(config, model, optimizer)
# build metric
eval_class = build_metric(config['Metric'])
# load pretrain model
pre_best_model_dict = load_model(config, model, optimizer)
logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
format(len(train_dataloader), len(valid_dataloader)))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment