"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "49a77ac16ff31b5ea938d8796bf0f4b5428774e6"
Commit dee09a40 authored by thomwolf
Browse files

various fixes

parent 2c731fd1
...@@ -412,7 +412,8 @@ class BertForSequenceClassification(nn.Module): ...@@ -412,7 +412,8 @@ class BertForSequenceClassification(nn.Module):
model = modeling.BertModel(config, num_labels) model = modeling.BertModel(config, num_labels)
logits = model(input_ids, token_type_ids, input_mask) logits = model(input_ids, token_type_ids, input_mask)
``` ```
""" def __init__(self, config, num_labels): """
def __init__(self, config, num_labels):
super(BertForSequenceClassification, self).__init__() super(BertForSequenceClassification, self).__init__()
self.bert = BertModel(config) self.bert = BertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
......
...@@ -73,8 +73,8 @@ parser.add_argument("--init_checkpoint", ...@@ -73,8 +73,8 @@ parser.add_argument("--init_checkpoint",
type = str, type = str,
help = "Initial checkpoint (usually from a pre-trained BERT model).") help = "Initial checkpoint (usually from a pre-trained BERT model).")
parser.add_argument("--do_lower_case", parser.add_argument("--do_lower_case",
default = True, default = False,
type = bool, action='store_true',
help = "Whether to lower case the input text. Should be True for uncased models and False for cased models.") help = "Whether to lower case the input text. Should be True for uncased models and False for cased models.")
parser.add_argument("--max_seq_length", parser.add_argument("--max_seq_length",
default = 128, default = 128,
...@@ -84,11 +84,11 @@ parser.add_argument("--max_seq_length", ...@@ -84,11 +84,11 @@ parser.add_argument("--max_seq_length",
"than this will be padded.") "than this will be padded.")
parser.add_argument("--do_train", parser.add_argument("--do_train",
default = False, default = False,
type = bool, action='store_true',
help = "Whether to run training.") help = "Whether to run training.")
parser.add_argument("--do_eval", parser.add_argument("--do_eval",
default = False, default = False,
type = bool, action='store_true',
help = "Whether to run eval on the dev set.") help = "Whether to run eval on the dev set.")
parser.add_argument("--train_batch_size", parser.add_argument("--train_batch_size",
default = 32, default = 32,
...@@ -117,7 +117,7 @@ parser.add_argument("--save_checkpoints_steps", ...@@ -117,7 +117,7 @@ parser.add_argument("--save_checkpoints_steps",
help = "How often to save the model checkpoint.") help = "How often to save the model checkpoint.")
parser.add_argument("--no_cuda", parser.add_argument("--no_cuda",
default = False, default = False,
type = bool, action='store_true',
help = "Whether not to use CUDA when available") help = "Whether not to use CUDA when available")
parser.add_argument("--local_rank", parser.add_argument("--local_rank",
type=int, type=int,
...@@ -490,6 +490,7 @@ def main(): ...@@ -490,6 +490,7 @@ def main():
warmup=args.warmup_proportion, warmup=args.warmup_proportion,
t_total=num_train_steps) t_total=num_train_steps)
global_step = 0
if args.do_train: if args.do_train:
train_features = convert_examples_to_features( train_features = convert_examples_to_features(
train_examples, label_list, args.max_seq_length, tokenizer) train_examples, label_list, args.max_seq_length, tokenizer)
...@@ -511,7 +512,6 @@ def main(): ...@@ -511,7 +512,6 @@ def main():
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
model.train() model.train()
global_step = 0
for epoch in args.num_train_epochs: for epoch in args.num_train_epochs:
for input_ids, input_mask, segment_ids, label_ids in train_dataloader: for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
input_ids = input_ids.to(device) input_ids = input_ids.to(device)
...@@ -552,9 +552,11 @@ def main(): ...@@ -552,9 +552,11 @@ def main():
input_ids = input_ids.to(device) input_ids = input_ids.to(device)
input_mask = input_mask.float().to(device) input_mask = input_mask.float().to(device)
segment_ids = segment_ids.to(device) segment_ids = segment_ids.to(device)
label_ids = label_ids.to(device)
tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids) tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids)
logits = logits.detach().cpu().numpy()
label_ids = label_ids.to('cpu').numpy()
tmp_eval_accuracy = accuracy(logits, label_ids) tmp_eval_accuracy = accuracy(logits, label_ids)
eval_loss += tmp_eval_loss.item() eval_loss += tmp_eval_loss.item()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment