Commit 74dbba64 authored by MottoX

Prepare optimizer only when args.do_train is True

parent 3ae8c8be
@@ -534,6 +534,7 @@ def main():
         model = torch.nn.DataParallel(model)
     # Prepare optimizer
+    if args.do_train:
         param_optimizer = list(model.named_parameters())
         no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
         optimizer_grouped_parameters = [
......
@@ -763,6 +763,7 @@ def main():
         model = torch.nn.DataParallel(model)
     # Prepare optimizer
+    if args.do_train:
         param_optimizer = list(model.named_parameters())
         no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
         optimizer_grouped_parameters = [
......
@@ -183,6 +183,7 @@ def main():
     eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
     # Prepare optimizer
+    if args.do_train:
         param_optimizer = list(model.named_parameters())
         no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
         optimizer_grouped_parameters = [
......
@@ -922,6 +922,7 @@ def main():
         model = torch.nn.DataParallel(model)
     # Prepare optimizer
+    if args.do_train:
         param_optimizer = list(model.named_parameters())
         # hack to remove pooler, which is not used
......
@@ -385,6 +385,7 @@ def main():
         model = torch.nn.DataParallel(model)
     # Prepare optimizer
+    if args.do_train:
         param_optimizer = list(model.named_parameters())
         # hack to remove pooler, which is not used
......
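The change is the same in every script: optimizer preparation is guarded so it runs only for training runs. Below is a minimal self-contained sketch of the resulting pattern. The prepare_optimizer helper and the use of torch.optim.AdamW are illustrative assumptions, not code from this commit; the actual scripts build their optimizer inline with the repository's own optimizer class.

# Sketch of the guarded optimizer setup this commit introduces. The helper
# name and torch.optim.AdamW are assumptions for illustration only.
import torch

def prepare_optimizer(model, do_train, learning_rate=5e-5, weight_decay=0.01):
    optimizer = None
    if do_train:  # the line the commit adds: eval-only runs skip all of this
        param_optimizer = list(model.named_parameters())
        # Parameters whose names contain these substrings get no weight decay.
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer
                        if not any(nd in n for nd in no_decay)],
             'weight_decay': weight_decay},
            {'params': [p for n, p in param_optimizer
                        if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0},
        ]
        optimizer = torch.optim.AdamW(optimizer_grouped_parameters,
                                      lr=learning_rate)
    return optimizer

# An eval-only run (do_train=False) never builds the optimizer, so a script
# invoked without --do_train no longer does unnecessary optimizer work.
model = torch.nn.Linear(4, 2)
assert prepare_optimizer(model, do_train=False) is None
assert prepare_optimizer(model, do_train=True) is not None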