"docker/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "af53a46311df7e879413718910d181e09ded5d2e"
Commit 0540d360 authored by Matthew Carrigan

Fixed logging

parent 976554a4
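The one-line commit message leaves the rationale implicit, so here is a minimal repro of the bug being fixed (my illustration, not part of the commit): a logger with no handler configured silently drops INFO records, because Python's built-in "last resort" handler only emits WARNING and above. Calling logging.basicConfig installs a formatted root handler, after which INFO messages appear:

import logging

# Before the fix (hypothetical repro): no handler is configured, so the
# implicit last-resort handler drops anything below WARNING.
logger = logging.getLogger(__name__)
logger.info("invisible")   # prints nothing

# After the fix: basicConfig attaches a formatted handler to the root logger.
log_format = '%(asctime)-10s: %(message)s'
logging.basicConfig(level=logging.INFO, format=log_format)
logging.info("visible")    # e.g. 2019-03-21 14:02:11,503: visible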
@@ -16,7 +16,9 @@ from pytorch_pretrained_bert.tokenization import BertTokenizer
 from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
 
 InputFeatures = namedtuple("InputFeatures", "input_ids input_mask segment_ids lm_label_ids is_next")
 
 logger = logging.getLogger(__name__)
+log_format = '%(asctime)-10s: %(message)s'
+logging.basicConfig(level=logging.INFO, format=log_format)
 
 def convert_example_to_features(example, tokenizer, max_seq_length):
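A quick note on the added format string, with a runnable check (the printed line is illustrative): %(asctime)-10s left-aligns the timestamp in a field at least 10 characters wide; since the default asctime rendering is about 23 characters, the padding is effectively inert and the format reduces to "timestamp: message".

import logging

logging.basicConfig(level=logging.INFO, format='%(asctime)-10s: %(message)s')
logging.info("Loading training examples for epoch 0")
# -> 2019-03-21 14:02:11,503: Loading training examples for epoch 0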
@@ -68,7 +70,7 @@ class PregeneratedDataset(Dataset):
         segment_ids = np.zeros(shape=(num_samples, seq_len), dtype=np.bool)
         lm_label_ids = np.full(shape=(num_samples, seq_len), dtype=np.int32, fill_value=-1)
         is_nexts = np.zeros(shape=(num_samples,), dtype=np.bool)
-        logger.info(f"Loading training examples for epoch {epoch}")
+        logging.info(f"Loading training examples for epoch {epoch}")
         with data_file.open() as f:
             for i, line in enumerate(tqdm(f, total=num_samples, desc="Training examples")):
                 example = json.loads(line.rstrip())
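The context lines above preallocate one array row per training example, which is what lets the constructor assert an exact sample count later. One portability caveat worth flagging (an observation about newer NumPy, not something this commit touches): np.bool was a deprecated alias for the builtin bool and was removed in NumPy 1.24, so on modern NumPy the same preallocation is spelled with dtype=bool:

import numpy as np

num_samples, seq_len = 4, 8   # toy sizes for illustration
input_ids = np.zeros(shape=(num_samples, seq_len), dtype=np.int32)
lm_label_ids = np.full(shape=(num_samples, seq_len), dtype=np.int32, fill_value=-1)
is_nexts = np.zeros(shape=(num_samples,), dtype=bool)   # np.bool removed in NumPy >= 1.24
print(lm_label_ids[0])   # [-1 -1 -1 -1 -1 -1 -1 -1]; -1 marks positions with no LM label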
@@ -79,7 +81,7 @@ class PregeneratedDataset(Dataset):
                 lm_label_ids[i] = features.lm_label_ids
                 is_nexts[i] = features.is_next
         assert i == num_samples - 1  # Assert that the sample count metric was true
-        logger.info("Loading complete!")
+        logging.info("Loading complete!")
         self.num_samples = num_samples
         self.seq_len = seq_len
         self.input_ids = input_ids
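For readers who want the surrounding structure, here is a minimal sketch of the storage pattern this hunk belongs to, assuming (from the context lines) that PregeneratedDataset keeps everything in preallocated numpy arrays; the class name and reduced fields below are illustrative, not the repo's full implementation:

import numpy as np
import torch
from torch.utils.data import Dataset

class TinyPregeneratedDataset(Dataset):   # hypothetical, trimmed-down analogue
    def __init__(self, input_ids):
        self.input_ids = input_ids        # one preallocated row per example
        self.num_samples = input_ids.shape[0]

    def __len__(self):
        return self.num_samples

    def __getitem__(self, item):
        # Convert a single row to a tensor lazily, keeping bulk storage in numpy.
        return torch.tensor(self.input_ids[item].astype(np.int64))

dataset = TinyPregeneratedDataset(np.zeros((4, 8), dtype=np.int32))
print(len(dataset), dataset[0].shape)   # 4 torch.Size([8])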
@@ -132,8 +134,8 @@ def main():
                         action='store_true',
                         help="Whether to use 16-bit float precision instead of 32-bit")
     parser.add_argument('--loss_scale',
-                        type = float, default = 0,
-                        help = "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
+                        type=float, default=0,
+                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                         "0 (default value): dynamic loss scaling.\n"
                         "Positive power of 2: static loss scaling value.\n")
     parser.add_argument("--warmup_proportion",
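The --loss_scale help text distinguishes dynamic from static scaling. As a conceptual illustration only (the script's real fp16 path lives in its optimizer setup, and this standalone sketch runs in fp32), static loss scaling multiplies the loss by a fixed power of two so small fp16 gradients don't flush to zero, then unscales the gradients before the optimizer step:

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_scale = 128.0   # "positive power of 2: static loss scaling value"

x, y = torch.randn(8, 4), torch.randn(8, 1)
loss = torch.nn.functional.mse_loss(model(x), y)
(loss * loss_scale).backward()        # gradients come out scaled by 128
for p in model.parameters():
    p.grad.data.div_(loss_scale)      # unscale before stepping
optimizer.step()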
@@ -179,7 +181,7 @@ def main():
         n_gpu = 1
         # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
         torch.distributed.init_process_group(backend='nccl')
-    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
+    logging.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
         device, n_gpu, bool(args.local_rank != -1), args.fp16))
 
     if args.gradient_accumulation_steps < 1:
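Since every rank in a distributed run executes the logging call above, each process prints its own configuration line. A common refinement, shown here as a hypothetical sketch rather than anything this commit does, is to keep INFO output on rank 0 only:

import logging

local_rank = 0   # stand-in for args.local_rank; -1 means non-distributed
logging.basicConfig(
    level=logging.INFO if local_rank in (-1, 0) else logging.WARNING,
    format='%(asctime)-10s: %(message)s')
logging.info("device: %s n_gpu: %d, distributed training: %s, 16-bits training: %s",
             "cuda", 1, local_rank != -1, False)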
@@ -195,7 +197,7 @@ def main():
         torch.cuda.manual_seed_all(args.seed)
 
     if args.output_dir.is_dir() and list(args.output_dir.iterdir()):
-        logger.warning(f"Output directory ({args.output_dir}) already exists and is not empty!")
+        logging.warning(f"Output directory ({args.output_dir}) already exists and is not empty!")
     args.output_dir.mkdir(parents=True, exist_ok=True)
 
     tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
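The output-directory check above uses pathlib throughout (args.output_dir is evidently a Path). A self-contained rerun of the same pattern, with a temporary directory standing in for args.output_dir:

import logging
import tempfile
from pathlib import Path

output_dir = Path(tempfile.mkdtemp()) / "out"   # stand-in for args.output_dir
if output_dir.is_dir() and list(output_dir.iterdir()):
    logging.warning(f"Output directory ({output_dir}) already exists and is not empty!")
output_dir.mkdir(parents=True, exist_ok=True)   # idempotent: safe to call on reruns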
@@ -258,10 +260,10 @@ def main():
                              t_total=num_train_optimization_steps)
 
     global_step = 0
-    logger.info("***** Running training *****")
-    logger.info(f" Num examples = {total_train_examples}")
-    logger.info(" Batch size = %d", args.train_batch_size)
-    logger.info(" Num steps = %d", num_train_optimization_steps)
+    logging.info("***** Running training *****")
+    logging.info(f" Num examples = {total_train_examples}")
+    logging.info(" Batch size = %d", args.train_batch_size)
+    logging.info(" Num steps = %d", num_train_optimization_steps)
     model.train()
     for epoch in range(args.epochs):
         epoch_dataset = PregeneratedDataset(epoch=epoch, training_path=args.pregenerated_data, tokenizer=tokenizer,
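The four training-banner lines mix two formatting styles that both work on the root logger: f-strings are evaluated eagerly at the call site, while %-style arguments are only interpolated if the record is actually emitted. A runnable side-by-side (values are stand-ins for the script's variables):

import logging

logging.basicConfig(level=logging.INFO, format='%(asctime)-10s: %(message)s')
total_train_examples, batch_size = 10000, 32   # stand-ins for the real values

logging.info(f" Num examples = {total_train_examples}")   # eager f-string formatting
logging.info(" Batch size = %d", batch_size)              # lazy %-style formatting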
@@ -304,7 +306,7 @@ def main():
                     global_step += 1
 
     # Save a trained model
-    logger.info("** ** * Saving fine-tuned model ** ** * ")
+    logging.info("** ** * Saving fine-tuned model ** ** * ")
     model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
     output_model_file = args.output_dir / "pytorch_model.bin"
     torch.save(model_to_save.state_dict(), str(output_model_file))
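The save block unwraps .module because DataParallel and DistributedDataParallel wrap the underlying model; saving the inner state_dict keeps the checkpoint loadable without the wrapper. A minimal round-trip sketch with a toy module standing in for the BERT model:

import torch

model = torch.nn.DataParallel(torch.nn.Linear(4, 2))   # toy stand-in for the wrapped model
model_to_save = model.module if hasattr(model, 'module') else model
torch.save(model_to_save.state_dict(), "pytorch_model.bin")

fresh = torch.nn.Linear(4, 2)                 # no wrapper needed at load time
fresh.load_state_dict(torch.load("pytorch_model.bin"))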