"examples/research_projects/vscode:/vscode.git/clone" did not exist on "5d848ec07c2011d600ce5e5c1aa02a03152aea9b"
Unverified Commit 85aeae71 authored by Double_V's avatar Double_V Committed by GitHub
Browse files

Merge pull request #3002 from littletomatodonkey/dyg/add_distillation

add distillation
parents d93a445d 95d07675
...@@ -49,7 +49,7 @@ def main(): ...@@ -49,7 +49,7 @@ def main():
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
use_srn = config['Architecture']['algorithm'] == "SRN" use_srn = config['Architecture']['algorithm'] == "SRN"
best_model_dict = init_model(config, model, logger) best_model_dict = init_model(config, model)
if len(best_model_dict): if len(best_model_dict):
logger.info('metric in ckpt ***************') logger.info('metric in ckpt ***************')
for k, v in best_model_dict.items(): for k, v in best_model_dict.items():
......
...@@ -17,7 +17,7 @@ import sys ...@@ -17,7 +17,7 @@ import sys
__dir__ = os.path.dirname(os.path.abspath(__file__)) __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__) sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) sys.path.append(os.path.abspath(os.path.join(__dir__, "..")))
import argparse import argparse
...@@ -31,32 +31,12 @@ from ppocr.utils.logging import get_logger ...@@ -31,32 +31,12 @@ from ppocr.utils.logging import get_logger
from tools.program import load_config, merge_config, ArgsParser from tools.program import load_config, merge_config, ArgsParser
def main(): def export_single_model(model, arch_config, save_path, logger):
FLAGS = ArgsParser().parse_args() if arch_config["algorithm"] == "SRN":
config = load_config(FLAGS.config) max_text_length = arch_config["Head"]["max_text_length"]
merge_config(FLAGS.opt)
logger = get_logger()
# build post process
post_process_class = build_post_process(config['PostProcess'],
config['Global'])
# build model
# for rec algorithm
if hasattr(post_process_class, 'character'):
char_num = len(getattr(post_process_class, 'character'))
config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture'])
init_model(config, model, logger)
model.eval()
save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
if config['Architecture']['algorithm'] == "SRN":
max_text_length = config['Architecture']['Head']['max_text_length']
other_shape = [ other_shape = [
paddle.static.InputSpec( paddle.static.InputSpec(
shape=[None, 1, 64, 256], dtype='float32'), [ shape=[None, 1, 64, 256], dtype="float32"), [
paddle.static.InputSpec( paddle.static.InputSpec(
shape=[None, 256, 1], shape=[None, 256, 1],
dtype="int64"), paddle.static.InputSpec( dtype="int64"), paddle.static.InputSpec(
...@@ -71,24 +51,66 @@ def main(): ...@@ -71,24 +51,66 @@ def main():
model = to_static(model, input_spec=other_shape) model = to_static(model, input_spec=other_shape)
else: else:
infer_shape = [3, -1, -1] infer_shape = [3, -1, -1]
if config['Architecture']['model_type'] == "rec": if arch_config["model_type"] == "rec":
infer_shape = [3, 32, -1] # for rec model, H must be 32 infer_shape = [3, 32, -1] # for rec model, H must be 32
if 'Transform' in config['Architecture'] and config['Architecture'][ if "Transform" in arch_config and arch_config[
'Transform'] is not None and config['Architecture'][ "Transform"] is not None and arch_config["Transform"][
'Transform']['name'] == 'TPS': "name"] == "TPS":
logger.info( logger.info(
'When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training' "When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
) )
infer_shape[-1] = 100 infer_shape[-1] = 100
model = to_static( model = to_static(
model, model,
input_spec=[ input_spec=[
paddle.static.InputSpec( paddle.static.InputSpec(
shape=[None] + infer_shape, dtype='float32') shape=[None] + infer_shape, dtype="float32")
]) ])
paddle.jit.save(model, save_path) paddle.jit.save(model, save_path)
logger.info('inference model is saved to {}'.format(save_path)) logger.info("inference model is saved to {}".format(save_path))
return
def main():
FLAGS = ArgsParser().parse_args()
config = load_config(FLAGS.config)
merge_config(FLAGS.opt)
logger = get_logger()
# build post process
post_process_class = build_post_process(config["PostProcess"],
config["Global"])
# build model
# for rec algorithm
if hasattr(post_process_class, "character"):
char_num = len(getattr(post_process_class, "character"))
if config["Architecture"]["algorithm"] in ["Distillation",
]: # distillation model
for key in config["Architecture"]["Models"]:
config["Architecture"]["Models"][key]["Head"][
"out_channels"] = char_num
else: # base rec model
config["Architecture"]["Head"]["out_channels"] = char_num
model = build_model(config["Architecture"])
init_model(config, model)
model.eval()
save_path = config["Global"]["save_inference_dir"]
arch_config = config["Architecture"]
if arch_config["algorithm"] in ["Distillation", ]: # distillation model
archs = list(arch_config["Models"].values())
for idx, name in enumerate(model.model_name_list):
sub_model_save_path = os.path.join(save_path, name, "inference")
export_single_model(model.model_list[idx], archs[idx],
sub_model_save_path, logger)
else:
save_path = os.path.join(save_path, "inference")
export_single_model(model, arch_config, save_path, logger)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -47,7 +47,7 @@ def main(): ...@@ -47,7 +47,7 @@ def main():
# build model # build model
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
init_model(config, model, logger) init_model(config, model)
# create data ops # create data ops
transforms = [] transforms = []
......
...@@ -61,7 +61,7 @@ def main(): ...@@ -61,7 +61,7 @@ def main():
# build model # build model
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
init_model(config, model, logger) init_model(config, model)
# build post process # build post process
post_process_class = build_post_process(config['PostProcess']) post_process_class = build_post_process(config['PostProcess'])
......
...@@ -68,7 +68,7 @@ def main(): ...@@ -68,7 +68,7 @@ def main():
# build model # build model
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
init_model(config, model, logger) init_model(config, model)
# build post process # build post process
post_process_class = build_post_process(config['PostProcess'], post_process_class = build_post_process(config['PostProcess'],
......
...@@ -20,6 +20,7 @@ import numpy as np ...@@ -20,6 +20,7 @@ import numpy as np
import os import os
import sys import sys
import json
__dir__ = os.path.dirname(os.path.abspath(__file__)) __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__) sys.path.append(__dir__)
...@@ -46,12 +47,18 @@ def main(): ...@@ -46,12 +47,18 @@ def main():
# build model # build model
if hasattr(post_process_class, 'character'): if hasattr(post_process_class, 'character'):
config['Architecture']["Head"]['out_channels'] = len( char_num = len(getattr(post_process_class, 'character'))
getattr(post_process_class, 'character')) if config['Architecture']["algorithm"] in ["Distillation",
]: # distillation model
for key in config['Architecture']["Models"]:
config['Architecture']["Models"][key]["Head"][
'out_channels'] = char_num
else: # base rec model
config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
init_model(config, model, logger) init_model(config, model)
# create data ops # create data ops
transforms = [] transforms = []
...@@ -107,11 +114,23 @@ def main(): ...@@ -107,11 +114,23 @@ def main():
else: else:
preds = model(images) preds = model(images)
post_result = post_process_class(preds) post_result = post_process_class(preds)
for rec_reuslt in post_result: info = None
logger.info('\t result: {}'.format(rec_reuslt)) if isinstance(post_result, dict):
if len(rec_reuslt) >= 2: rec_info = dict()
fout.write(file + "\t" + rec_reuslt[0] + "\t" + str( for key in post_result:
rec_reuslt[1]) + "\n") if len(post_result[key][0]) >= 2:
rec_info[key] = {
"label": post_result[key][0][0],
"score": post_result[key][0][1],
}
info = json.dumps(rec_info)
else:
if len(post_result[0]) >= 2:
info = post_result[0][0] + "\t" + str(post_result[0][1])
if info is not None:
logger.info("\t result: {}".format(info))
fout.write(file + "\t" + info)
logger.info("success!") logger.info("success!")
......
...@@ -386,7 +386,7 @@ def preprocess(is_train=False): ...@@ -386,7 +386,7 @@ def preprocess(is_train=False):
alg = config['Architecture']['algorithm'] alg = config['Architecture']['algorithm']
assert alg in [ assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet' 'CLS', 'PGNet', 'Distillation'
] ]
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu' device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
......
...@@ -72,7 +72,14 @@ def main(config, device, logger, vdl_writer): ...@@ -72,7 +72,14 @@ def main(config, device, logger, vdl_writer):
# for rec algorithm # for rec algorithm
if hasattr(post_process_class, 'character'): if hasattr(post_process_class, 'character'):
char_num = len(getattr(post_process_class, 'character')) char_num = len(getattr(post_process_class, 'character'))
config['Architecture']["Head"]['out_channels'] = char_num if config['Architecture']["algorithm"] in ["Distillation",
]: # distillation model
for key in config['Architecture']["Models"]:
config['Architecture']["Models"][key]["Head"][
'out_channels'] = char_num
else: # base rec model
config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
if config['Global']['distributed']: if config['Global']['distributed']:
model = paddle.DataParallel(model) model = paddle.DataParallel(model)
...@@ -90,7 +97,7 @@ def main(config, device, logger, vdl_writer): ...@@ -90,7 +97,7 @@ def main(config, device, logger, vdl_writer):
# build metric # build metric
eval_class = build_metric(config['Metric']) eval_class = build_metric(config['Metric'])
# load pretrain model # load pretrain model
pre_best_model_dict = init_model(config, model, logger, optimizer) pre_best_model_dict = init_model(config, model, optimizer)
logger.info('train dataloader has {} iters'.format(len(train_dataloader))) logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
if valid_dataloader is not None: if valid_dataloader is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment