Commit 936eb4c3 authored by VictorSanh's avatar VictorSanh
Browse files

FIX small bugs in `run_classifier_pytorch.py`

parent cc228089
...@@ -410,8 +410,8 @@ def input_fn_builder(features, seq_length, train_batch_size): ...@@ -410,8 +410,8 @@ def input_fn_builder(features, seq_length, train_batch_size):
input_ids_tensor = torch.tensor(all_input_ids, dtype=torch.Long) input_ids_tensor = torch.tensor(all_input_ids, dtype=torch.Long)
input_mask_tensor = torch.tensor(all_input_mask, dtype=torch.Long) input_mask_tensor = torch.tensor(all_input_mask, dtype=torch.Long)
segment_tensor = torch.tensor(all_segment, dtype=torch.Long) segment_tensor = torch.tensor(all_segment_ids, dtype=torch.Long)
label_tensor = torch.tensor(all_label, dtype=torch.Long) label_tensor = torch.tensor(all_label_ids, dtype=torch.Long)
train_data = TensorDataset(input_ids_tensor, input_mask_tensor, train_data = TensorDataset(input_ids_tensor, input_mask_tensor,
segment_tensor, label_tensor) segment_tensor, label_tensor)
...@@ -512,7 +512,7 @@ def main(): ...@@ -512,7 +512,7 @@ def main():
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
model.train() model.train()
for epoch in args.num_train_epochs: for epoch in range(args.num_train_epochs):
for input_ids, input_mask, segment_ids, label_ids in train_dataloader: for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
input_ids = input_ids.to(device) input_ids = input_ids.to(device)
input_mask = input_mask.float().to(device) input_mask = input_mask.float().to(device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment