Commit dee09a40 authored by thomwolf's avatar thomwolf
Browse files

various fixes

parent 2c731fd1
......@@ -412,7 +412,8 @@ class BertForSequenceClassification(nn.Module):
model = modeling.BertModel(config, num_labels)
logits = model(input_ids, token_type_ids, input_mask)
```
""" def __init__(self, config, num_labels):
"""
def __init__(self, config, num_labels):
super(BertForSequenceClassification, self).__init__()
self.bert = BertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
......
......@@ -73,8 +73,8 @@ parser.add_argument("--init_checkpoint",
type = str,
help = "Initial checkpoint (usually from a pre-trained BERT model).")
parser.add_argument("--do_lower_case",
default = True,
type = bool,
default = False,
action='store_true',
help = "Whether to lower case the input text. Should be True for uncased models and False for cased models.")
parser.add_argument("--max_seq_length",
default = 128,
......@@ -84,11 +84,11 @@ parser.add_argument("--max_seq_length",
"than this will be padded.")
parser.add_argument("--do_train",
default = False,
type = bool,
action='store_true',
help = "Whether to run training.")
parser.add_argument("--do_eval",
default = False,
type = bool,
action='store_true',
help = "Whether to run eval on the dev set.")
parser.add_argument("--train_batch_size",
default = 32,
......@@ -117,7 +117,7 @@ parser.add_argument("--save_checkpoints_steps",
help = "How often to save the model checkpoint.")
parser.add_argument("--no_cuda",
default = False,
type = bool,
action='store_true',
help = "Whether not to use CUDA when available")
parser.add_argument("--local_rank",
type=int,
......@@ -490,6 +490,7 @@ def main():
warmup=args.warmup_proportion,
t_total=num_train_steps)
global_step = 0
if args.do_train:
train_features = convert_examples_to_features(
train_examples, label_list, args.max_seq_length, tokenizer)
......@@ -511,7 +512,6 @@ def main():
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
model.train()
global_step = 0
for epoch in args.num_train_epochs:
for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
input_ids = input_ids.to(device)
......@@ -552,9 +552,11 @@ def main():
input_ids = input_ids.to(device)
input_mask = input_mask.float().to(device)
segment_ids = segment_ids.to(device)
label_ids = label_ids.to(device)
tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids)
logits = logits.detach().cpu().numpy()
label_ids = label_ids.to('cpu').numpy()
tmp_eval_accuracy = accuracy(logits, label_ids)
eval_loss += tmp_eval_loss.item()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment