Commit da1e4e53 authored by VictorSanh's avatar VictorSanh
Browse files

some fixes in `train.py` for loading previous checkpoint

parent 0d8f8848
...@@ -143,6 +143,8 @@ def main(): ...@@ -143,6 +143,8 @@ def main():
with open(os.path.join(args.dump_path, 'parameters.json'), 'w') as f: with open(os.path.join(args.dump_path, 'parameters.json'), 'w') as f:
json.dump(vars(args), f, indent=4) json.dump(vars(args), f, indent=4)
git_log(args.dump_path) git_log(args.dump_path)
assert (args.from_pretrained_weights is None and args.from_pretrained_config is None) or \
(args.from_pretrained_weights is not None and args.from_pretrained_config is not None)
### TOKENIZER ### ### TOKENIZER ###
...@@ -177,31 +179,18 @@ def main(): ...@@ -177,31 +179,18 @@ def main():
## STUDENT ## ## STUDENT ##
assert (args.from_pretrained_weights is None and args.from_pretrained_config is None) or \
(args.from_pretrained_weights is not None and args.from_pretrained_config is not None)
if args.from_pretrained_weights is not None: if args.from_pretrained_weights is not None:
assert os.path.isfile(os.path.join(args.from_pretrained, 'config.json')) assert os.path.isfile(os.path.join(args.from_pretrained_weights))
assert os.path.isfile(os.path.join(args.from_pretrained, 'config.json')) assert os.path.isfile(os.path.join(args.from_pretrained_config))
logger.info(f'Loading pretrained weights from {args.from_pretrained_weights}') logger.info(f'Loading pretrained weights from {args.from_pretrained_weights}')
logger.info(f'Loading pretrained config from {args.from_pretrained_config}') logger.info(f'Loading pretrained config from {args.from_pretrained_config}')
stu_architecture_config = DilBertConfig.from_json_file(args.from_pretrained_config) stu_architecture_config = DilBertConfig.from_json_file(args.from_pretrained_config)
student = DilBertForMaskedLM.from_pretrained(args.from_pretrained_weights, student = DilBertForMaskedLM.from_pretrained(args.from_pretrained_weights,
config=stu_architecture_config) config=stu_architecture_config)
else: else:
args.vocab_size_or_config_json_file = args.vocab_size
stu_architecture_config = DilBertConfig(args) stu_architecture_config = DilBertConfig(**vars(args))
student = DilBertForMaskedLM(stu_architecture_config) student = DilBertForMaskedLM(stu_architecture_config)
# student = Model(vocab_size=args.vocab_size,
# max_position_embeddings=args.max_position_embeddings,
# sinusoidal_pos_embds=args.sinusoidal_pos_embds,
# n_layers=args.n_layers,
# n_heads=args.n_heads,
# dim=args.dim,
# dropout=args.dropout,
# attention_dropout=args.attention_dropout,
# activation=args.activation,
# initializer_range=args.initializer_range,
# tie_weights=args.tie_weights)
if args.n_gpu > 0: if args.n_gpu > 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment