"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "3890ea1490206d1ffdbe2b2f0eb62fea204731d5"
Commit 3ddff783 authored by thomwolf's avatar thomwolf
Browse files

clean up + mask is long

parent 88c10379
...@@ -22,10 +22,10 @@ import csv ...@@ -22,10 +22,10 @@ import csv
import os import os
import logging import logging
import argparse import argparse
import random import random
import numpy as np
from tqdm import tqdm, trange from tqdm import tqdm, trange
import numpy as np
import torch import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
...@@ -102,7 +102,7 @@ class MrpcProcessor(DataProcessor): ...@@ -102,7 +102,7 @@ class MrpcProcessor(DataProcessor):
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""See base class.""" """See base class."""
print("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv"))) logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv")))
return self._create_examples( return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
...@@ -420,7 +420,7 @@ def main(): ...@@ -420,7 +420,7 @@ def main():
n_gpu = 1 n_gpu = 1
# Initializes the distributed backend which will take care of synchronizing nodes/GPUs # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
torch.distributed.init_process_group(backend='nccl') torch.distributed.init_process_group(backend='nccl')
print("device", device, "n_gpu", n_gpu, "distributed training", bool(args.local_rank != -1)) logger.info("device", device, "n_gpu", n_gpu, "distributed training", bool(args.local_rank != -1))
if args.accumulate_gradients < 1: if args.accumulate_gradients < 1:
raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format( raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format(
...@@ -516,7 +516,7 @@ def main(): ...@@ -516,7 +516,7 @@ def main():
nb_tr_examples, nb_tr_steps = 0, 0 nb_tr_examples, nb_tr_steps = 0, 0
for step, (input_ids, input_mask, segment_ids, label_ids) in enumerate(tqdm(train_dataloader, desc="Iteration")): for step, (input_ids, input_mask, segment_ids, label_ids) in enumerate(tqdm(train_dataloader, desc="Iteration")):
input_ids = input_ids.to(device) input_ids = input_ids.to(device)
input_mask = input_mask.float().to(device) input_mask = input_mask.to(device)
segment_ids = segment_ids.to(device) segment_ids = segment_ids.to(device)
label_ids = label_ids.to(device) label_ids = label_ids.to(device)
...@@ -559,7 +559,7 @@ def main(): ...@@ -559,7 +559,7 @@ def main():
nb_eval_steps, nb_eval_examples = 0, 0 nb_eval_steps, nb_eval_examples = 0, 0
for input_ids, input_mask, segment_ids, label_ids in eval_dataloader: for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
input_ids = input_ids.to(device) input_ids = input_ids.to(device)
input_mask = input_mask.float().to(device) input_mask = input_mask.to(device)
segment_ids = segment_ids.to(device) segment_ids = segment_ids.to(device)
label_ids = label_ids.to(device) label_ids = label_ids.to(device)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment