Unverified commit 0198399d, authored by Thomas Wolf, committed by GitHub

Merge pull request #570 from MottoX/fix-1

Create optimizer only when args.do_train is True
parents 50fa92c0 74dbba64
@@ -534,6 +534,7 @@ def main():
         model = torch.nn.DataParallel(model)

     # Prepare optimizer
+    if args.do_train:
         param_optimizer = list(model.named_parameters())
         no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
         optimizer_grouped_parameters = [
...
@@ -763,6 +763,7 @@ def main():
         model = torch.nn.DataParallel(model)

     # Prepare optimizer
+    if args.do_train:
         param_optimizer = list(model.named_parameters())
         no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
         optimizer_grouped_parameters = [
...
@@ -183,6 +183,7 @@ def main():
     eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

     # Prepare optimizer
+    if args.do_train:
         param_optimizer = list(model.named_parameters())
         no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
         optimizer_grouped_parameters = [
...
@@ -922,6 +922,7 @@ def main():
         model = torch.nn.DataParallel(model)

     # Prepare optimizer
+    if args.do_train:
         param_optimizer = list(model.named_parameters())
         # hack to remove pooler, which is not used
...
@@ -385,6 +385,7 @@ def main():
         model = torch.nn.DataParallel(model)

     # Prepare optimizer
+    if args.do_train:
         param_optimizer = list(model.named_parameters())
         # hack to remove pooler, which is not used
...
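
All five hunks apply the same guard: the "Prepare optimizer" block now runs only when --do_train is passed, so evaluation-only invocations no longer build an optimizer they never step. Below is a minimal sketch of the pattern; torch.nn.Linear stands in for the real model and torch.optim.Adam for whatever optimizer the example scripts actually construct (both stand-ins are assumptions for illustration, not the repository's classes).

import argparse
import torch

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--do_train", action="store_true",
                        help="Whether to run training.")
    parser.add_argument("--learning_rate", type=float, default=5e-5)
    args = parser.parse_args()

    model = torch.nn.Linear(768, 2)  # stand-in for the real model (assumption)

    # Prepare optimizer -- only when training was requested.
    optimizer = None
    if args.do_train:
        param_optimizer = list(model.named_parameters())
        # Parameters whose names contain any of these fragments get no weight decay.
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer
                        if not any(nd in n for nd in no_decay)],
             'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer
                        if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0},
        ]
        # Stand-in optimizer class (assumption): the guard, not the class, is the fix.
        optimizer = torch.optim.Adam(optimizer_grouped_parameters,
                                     lr=args.learning_rate)

if __name__ == "__main__":
    main()

Running the sketch with --do_train builds the grouped-parameter optimizer; without it, optimizer stays None and an eval-only path proceeds without one.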