Commit 9f995b99 authored by VictorSanh

minor fixes

parent 3fe5c8e8
@@ -115,7 +115,6 @@ class Distiller:
                                           betas=(0.9, 0.98))
         warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop)
-        logger.info(f'--- Scheduler: {params.scheduler_type}')
         self.scheduler = WarmupLinearSchedule(self.optimizer,
                                               warmup_steps=warmup_steps,
                                               t_total=num_train_optimization_steps)
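For context, here is a minimal, self-contained sketch of the optimizer/scheduler setup this hunk touches. It assumes pytorch-transformers 1.x (which provides `WarmupLinearSchedule`); the step count, warmup proportion, and the stand-in model are illustrative values, not taken from the repo, where they come from the distillation params.

```python
import math
import torch
# WarmupLinearSchedule ships with pytorch-transformers 1.x; newer transformers
# releases expose the same behaviour via get_linear_schedule_with_warmup.
from pytorch_transformers import WarmupLinearSchedule

# Illustrative values -- in the training script these come from params.
num_train_optimization_steps = 10_000
warmup_prop = 0.05

model = torch.nn.Linear(768, 768)  # stand-in for the student model
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, betas=(0.9, 0.98))

# Linear warmup over the first warmup_steps, then linear decay to zero.
warmup_steps = math.ceil(num_train_optimization_steps * warmup_prop)
scheduler = WarmupLinearSchedule(optimizer,
                                 warmup_steps=warmup_steps,
                                 t_total=num_train_optimization_steps)

for step in range(num_train_optimization_steps):
    # ... forward/backward on a batch would go here ...
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()
```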
@@ -204,8 +204,8 @@ def main():
     ## STUDENT ##
     if args.from_pretrained_weights is not None:
-        assert os.path.isfile(os.path.join(args.from_pretrained_weights))
-        assert os.path.isfile(os.path.join(args.from_pretrained_config))
+        assert os.path.isfile(args.from_pretrained_weights)
+        assert os.path.isfile(args.from_pretrained_config)
         logger.info(f'Loading pretrained weights from {args.from_pretrained_weights}')
         logger.info(f'Loading pretrained config from {args.from_pretrained_config}')
         stu_architecture_config = DistilBertConfig.from_json_file(args.from_pretrained_config)
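Why dropping `os.path.join` here is safe: with a single argument, `os.path.join` returns its input unchanged, so the original asserts only wrapped the path in a no-op. A small sketch of the equivalence (the path below is a made-up example, not one from the repo):

```python
import os

weights_path = "serialization_dir/pytorch_model.bin"  # hypothetical example path

# os.path.join with a single argument returns it unchanged ...
assert os.path.join(weights_path) == weights_path

# ... so the simplified check is equivalent to the original
# os.path.isfile(os.path.join(weights_path)), just more direct.
print(os.path.isfile(weights_path))
```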