"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "a573777901e662ec2e565be312ffaeedef6effec"
Commit 290633b8 authored by VictorSanh's avatar VictorSanh
Browse files

Fix `args.gradient_accumulation_steps` used before assignment.

parent 649e9774
...@@ -404,6 +404,10 @@ def main(): ...@@ -404,6 +404,10 @@ def main():
type=int, type=int,
default=42, default=42,
help="random seed for initialization") help="random seed for initialization")
parser.add_argument('--gradient_accumulation_steps',
type=int,
default=1,
help="Number of updates steps to accumualte before performing a backward/update pass.")
args = parser.parse_args() args = parser.parse_args()
processors = { processors = {
...@@ -469,7 +473,7 @@ def main(): ...@@ -469,7 +473,7 @@ def main():
model = BertForSequenceClassification(bert_config, len(label_list)) model = BertForSequenceClassification(bert_config, len(label_list))
if args.init_checkpoint is not None: if args.init_checkpoint is not None:
model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu')) model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
model.to(device) model.to(device)
if args.local_rank != -1: if args.local_rank != -1:
......
...@@ -739,7 +739,11 @@ def main(): ...@@ -739,7 +739,11 @@ def main():
type=int, type=int,
default=42, default=42,
help="random seed for initialization") help="random seed for initialization")
parser.add_argument('--gradient_accumulation_steps',
type=int,
default=1,
help="Number of updates steps to accumualte before performing a backward/update pass.")
args = parser.parse_args() args = parser.parse_args()
if args.local_rank == -1 or args.no_cuda: if args.local_rank == -1 or args.no_cuda:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment