Commit 58935387 authored by yangzhong

Convert hard-coded parameters to argparse options

parent 819d90cc
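With this change both training scripts take their runtime configuration from the command line; per the updated run script at the end of the diff, a typical invocation is `python3 crf.py --batch-size=64 --root-path=/bert4torch/datasets --epochs=20`, with `--use-amp` added to switch from fp32 to mixed precision.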
@@ -14,32 +14,44 @@ from bert4torch.layers import CRF
 from bert4torch.tokenizers import Tokenizer
 from bert4torch.models import build_transformer_model, BaseModel
 from tqdm import tqdm
+import argparse
+
+# Add command-line argument switches
+parser = argparse.ArgumentParser(description='bert4torch training')
+parser.add_argument(
+    "--use-amp",
+    action="store_true",
+    help="Run the model in AMP (automatic mixed precision) mode.",
+)
+parser.add_argument('-b', '--batch-size', default=256, type=int,
+                    metavar='N',
+                    help='mini-batch size (default: 256); this is the total '
+                         'batch size of all GPUs on the current node when '
+                         'using Data Parallel or Distributed Data Parallel')
+parser.add_argument("--root-path", default='/root', type=str, help='root path')
+parser.add_argument('--epochs', default=20, type=int, metavar='N',
+                    help='number of total epochs to run')
+args = parser.parse_args()
 
 maxlen = 256
-batch_size = 64
+batch_size = args.batch_size
 categories = ['O', 'B-LOC', 'I-LOC', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG']
 categories_id2label = {i: k for i, k in enumerate(categories)}
 categories_label2id = {k: i for i, k in enumerate(categories)}
 
 # BERT base
-config_path = '/bert4torch/datasets/bert-base-chinese/config.json'
-checkpoint_path = '/bert4torch/datasets/bert-base-chinese/pytorch_model.bin'
-dict_path = '/bert4torch/datasets/bert-base-chinese/vocab.txt'
+root_path = args.root_path
+config_path = root_path + '/bert-base-chinese/config.json'
+checkpoint_path = root_path + '/bert-base-chinese/pytorch_model.bin'
+dict_path = root_path + '/bert-base-chinese/vocab.txt'
 
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 # Fix the random seed
 seed_everything(42)
-
-# Add AMP argument switch
-parser = argparse.ArgumentParser(description='bert4torch training')
-#parser.add_argument('--use-amp', type=bool, default=True, help='Use automatic mixed precision (AMP)')
-parser.add_argument(
-    "--use-amp",
-    action="store_true",
-    help="Run the model in AMP (automatic mixed precision) mode.",
-)
-args = parser.parse_args()
 
 # Load the dataset
 class MyDataset(ListDataset):
     @staticmethod
@@ -87,8 +99,8 @@ def collate_fn(batch):
     return batch_token_ids, batch_labels
 
 # Convert the dataset
-train_dataloader = DataLoader(MyDataset('/bert4torch/datasets/bert-base-chinese/china-people-daily-ner-corpus/example.train'), batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
-valid_dataloader = DataLoader(MyDataset('/bert4torch/datasets/bert-base-chinese/china-people-daily-ner-corpus/example.dev'), batch_size=batch_size, collate_fn=collate_fn)
+train_dataloader = DataLoader(MyDataset(root_path + '/bert-base-chinese/china-people-daily-ner-corpus/example.train'), batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn)
+valid_dataloader = DataLoader(MyDataset(root_path + '/bert-base-chinese/china-people-daily-ner-corpus/example.dev'), batch_size=args.batch_size, collate_fn=collate_fn)
 
 # Define the model structure on top of BERT
 class Model(BaseModel):
@@ -190,7 +202,8 @@ if __name__ == '__main__':
     evaluator = Evaluator()
-    model.fit(train_dataloader, epochs=20, steps_per_epoch=None, callbacks=[evaluator])
+    #model.fit(train_dataloader, epochs=20, steps_per_epoch=None, callbacks=[evaluator])
+    model.fit(train_dataloader, epochs=args.epochs, steps_per_epoch=None, callbacks=[evaluator])
 else:
......
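The collapsed hunks above contain the model definition and the training setup, including wherever args.use_amp is actually consumed; that wiring is not visible in this view. For orientation, a minimal generic sketch of how a --use-amp flag is typically honored in a PyTorch training step (placeholder names throughout; this is not bert4torch's actual API):

# Hypothetical sketch, not code from this commit: model, optimizer,
# loss_fn and dataloader are placeholders.
import torch

def train_one_epoch(model, optimizer, loss_fn, dataloader, use_amp):
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)  # no-op when use_amp is False
    for token_ids, labels in dataloader:
        optimizer.zero_grad()
        # run the forward pass in reduced precision where safe
        with torch.cuda.amp.autocast(enabled=use_amp):
            loss = loss_fn(model(token_ids), labels)
        scaler.scale(loss).backward()  # scale the loss to limit fp16 underflow
        scaler.step(optimizer)         # unscales gradients, then optimizer.step()
        scaler.update()                # adapt the loss scale for the next step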
@@ -16,18 +16,40 @@ from bert4torch.models import build_transformer_model, BaseModel
 from tqdm import tqdm
 from bert4torch.models import BaseModelDDP
 import os
+import argparse
+
+# Add command-line argument switches
+parser = argparse.ArgumentParser(description='bert4torch training')
+parser.add_argument(
+    "--use-amp",
+    action="store_true",
+    help="Run the model in AMP (automatic mixed precision) mode.",
+)
+parser.add_argument('-b', '--batch-size', default=256, type=int,
+                    metavar='N',
+                    help='mini-batch size (default: 256); this is the total '
+                         'batch size of all GPUs on the current node when '
+                         'using Data Parallel or Distributed Data Parallel')
+parser.add_argument("--root-path", default='/root', type=str, help='root path')
+parser.add_argument('--epochs', default=20, type=int, metavar='N',
+                    help='number of total epochs to run')
+args = parser.parse_args()
 
 maxlen = 256
-batch_size = 64
+batch_size = args.batch_size
 categories = ['O', 'B-LOC', 'I-LOC', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG']
 categories_id2label = {i: k for i, k in enumerate(categories)}
 categories_label2id = {k: i for i, k in enumerate(categories)}
 
 # BERT base
-#config_path = '/datasets/bert-base-chinese/bert_config.json'
-config_path = '/bert4torch/datasets/bert-base-chinese/config.json'
-checkpoint_path = '/bert4torch/datasets/bert-base-chinese/pytorch_model.bin'
-dict_path = '/bert4torch/datasets/bert-base-chinese/vocab.txt'
+root_path = args.root_path
+config_path = root_path + '/bert-base-chinese/config.json'
+checkpoint_path = root_path + '/bert-base-chinese/pytorch_model.bin'
+dict_path = root_path + '/bert-base-chinese/vocab.txt'
 
 #device = 'cuda' if torch.cuda.is_available() else 'cpu'
 local_rank = int(os.environ['LOCAL_RANK'])
 print("local_rank ", local_rank)
@@ -38,16 +60,6 @@ torch.distributed.init_process_group(backend='nccl')
 
 # Fix the random seed
 seed_everything(42)
-
-# Add AMP argument switch
-parser = argparse.ArgumentParser(description='bert4torch training')
-#parser.add_argument('--use-amp', type=bool, default=True, help='Use automatic mixed precision (AMP)')
-parser.add_argument(
-    "--use-amp",
-    action="store_true",
-    help="Run the model in AMP (automatic mixed precision) mode.",
-)
-args = parser.parse_args()
 
 # Load the dataset
 class MyDataset(ListDataset):
     @staticmethod
@@ -95,11 +107,10 @@ def collate_fn(batch):
     return batch_token_ids, batch_labels
 
 # Convert the dataset
-#train_dataloader = DataLoader(MyDataset('/datasets/bert-base-chinese/china-people-daily-ner-corpus/example.train'), batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
-train_dataset = MyDataset('/bert4torch/datasets/bert-base-chinese/china-people-daily-ner-corpus/example.train')
+train_dataset = MyDataset(root_path + '/bert-base-chinese/china-people-daily-ner-corpus/example.train')
 train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
-train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, collate_fn=collate_fn)
-valid_dataloader = DataLoader(MyDataset('/bert4torch/datasets/bert-base-chinese/china-people-daily-ner-corpus/example.dev'), batch_size=batch_size, collate_fn=collate_fn)
+train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=collate_fn)
+valid_dataloader = DataLoader(MyDataset(root_path + '/bert-base-chinese/china-people-daily-ner-corpus/example.dev'), batch_size=args.batch_size, collate_fn=collate_fn)
 
 # Define the model structure on top of BERT
 class Model(BaseModel):
@@ -206,7 +217,8 @@ if __name__ == '__main__':
     evaluator = Evaluator()
-    model.fit(train_dataloader, epochs=20, steps_per_epoch=None, callbacks=[evaluator])
+    #model.fit(train_dataloader, epochs=20, steps_per_epoch=None, callbacks=[evaluator])
+    model.fit(train_dataloader, epochs=args.epochs, steps_per_epoch=None, callbacks=[evaluator])
 else:
......
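A detail worth flagging in the DDP script: each process passes args.batch_size straight to its own DataLoader, so the effective global batch is args.batch_size × nproc_per_node, not the "total batch size of all GPUs" the inherited help string describes. If the flag is meant as a global batch size, the division would look roughly like this (an assumption, not part of this commit; names mirror the script above):

# Hypothetical adjustment, not in this commit: treat --batch-size as the
# global batch and derive the per-process size from the world size.
world_size = torch.distributed.get_world_size()  # valid after init_process_group, done above
per_process_batch = max(1, args.batch_size // world_size)
train_dataloader = DataLoader(train_dataset, batch_size=per_process_batch,
                              sampler=train_sampler, collate_fn=collate_fn)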
@@ -13,5 +13,5 @@ export HIP_VISIBLE_DEVICES=$(seq -s, ${START} ${LAST})
 export HSA_FORCE_FINE_GRAIN_PCIE=1
 logfile=bert_base_${NUM}dcu_`date +%Y%m%d%H%M%S`.log
-python3 -m torch.distributed.run --nproc_per_node=${NUM} crf_ddp.py 2>&1 | tee $logfile # fp32
-#python3 -m torch.distributed.run --nproc_per_node=${NUM} crf_ddp.py --use-amp 2>&1 | tee $logfile # fp16
+python3 -m torch.distributed.run --nproc_per_node=${NUM} crf_ddp.py --batch-size=64 --root-path=/bert4torch/datasets --epochs=20 2>&1 | tee $logfile # fp32
+#python3 -m torch.distributed.run --nproc_per_node=${NUM} crf_ddp.py --use-amp --batch-size=64 --root-path=/bert4torch/datasets --epochs=20 2>&1 | tee $logfile # fp16
 logfile=bert_base_`date +%Y%m%d%H%M%S`.log
-python3 crf.py 2>&1 | tee $logfile # fp32
-#python3 crf.py --use-amp 2>&1 | tee $logfile # fp16
+#python3 crf.py --use-amp --batch-size=64 --root-path=/bert4torch/datasets --epochs=20 2>&1 | tee $logfile # fp16
+python3 crf.py --batch-size=64 --root-path=/bert4torch/datasets --epochs=20 2>&1 | tee $logfile # fp32
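For context on the launch script: `seq -s, ${START} ${LAST}` expands to a comma-separated device list (e.g. 0,1,2,3 when START=0 and LAST=3; the actual values are set in the elided top of the script), so HIP_VISIBLE_DEVICES, NUM, and --nproc_per_node must agree on the device count. Switching either launch between fp32 and fp16 is done by swapping which of the paired command lines is commented out; --use-amp is the only difference between them.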