chenzk / bert4torch_pytorch · Commits

Commit 58935387, authored Jan 31, 2024 by yangzhong (parent: 819d90cc)

    修改参数为argparse选项 (Change parameters to argparse options)

Showing 4 changed files with 66 additions and 41 deletions:
examples/sequence_labeling/crf.py           +30 -17
examples/sequence_labeling/crf_ddp.py       +32 -20
examples/sequence_labeling/multi_train.sh    +2  -2
examples/sequence_labeling/single_train.sh   +2  -2
examples/sequence_labeling/crf.py

@@ -14,32 +14,44 @@ from bert4torch.layers import CRF
 from bert4torch.tokenizers import Tokenizer
 from bert4torch.models import build_transformer_model, BaseModel
 from tqdm import tqdm
+import argparse
+
+# add CLI argument switches
+parser = argparse.ArgumentParser(description='bert4torch training')
+parser.add_argument("--use-amp", action="store_true",
+                    help="Run model in AMP (automatic mixed precision) mode.")
+parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N',
+                    help='mini-batch size (default: 256), this is the total '
+                         'batch size of all GPUs on the current node when '
+                         'using Data Parallel or Distributed Data Parallel')
+parser.add_argument("--root-path", default='/root', type=str, help='root path')
+parser.add_argument('--epochs', default=20, type=int, metavar='N',
+                    help='number of total epochs to run')
+args = parser.parse_args()

 maxlen = 256
-batch_size = 64
+batch_size = args.batch_size
 categories = ['O', 'B-LOC', 'I-LOC', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG']
 categories_id2label = {i: k for i, k in enumerate(categories)}
 categories_label2id = {k: i for i, k in enumerate(categories)}

 # BERT base
-config_path = '/bert4torch/datasets/bert-base-chinese/config.json'
-checkpoint_path = '/bert4torch/datasets/bert-base-chinese/pytorch_model.bin'
-dict_path = '/bert4torch/datasets/bert-base-chinese/vocab.txt'
+root_path = args.root_path
+config_path = root_path + '/bert-base-chinese/config.json'
+checkpoint_path = root_path + '/bert-base-chinese/pytorch_model.bin'
+dict_path = root_path + '/bert-base-chinese/vocab.txt'
 device = 'cuda' if torch.cuda.is_available() else 'cpu'

 # fix the random seed
 seed_everything(42)

-# add the AMP argument switch
-parser = argparse.ArgumentParser(description='bert4torch training')
-#parser.add_argument('--use-amp', type=bool, default=True, help='Use automatic mixed precision (AMP)')
-parser.add_argument("--use-amp", action="store_true",
-                    help="Run model in AMP (automatic mixed precision) mode.")
-args = parser.parse_args()
-
 # load the dataset
 class MyDataset(ListDataset):
     @staticmethod
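A note on the options above (context, not part of the diff): argparse converts dashes in long option names to underscores on the parsed namespace, which is why the script defines --use-amp and --batch-size but later reads args.use_amp and args.batch_size; a store_true flag defaults to False when omitted. A minimal standalone sketch of this behavior:

    import argparse

    parser = argparse.ArgumentParser(description='bert4torch training')
    parser.add_argument('--use-amp', action='store_true')            # False unless the flag is passed
    parser.add_argument('-b', '--batch-size', default=256, type=int)
    parser.add_argument('--epochs', default=20, type=int)

    args = parser.parse_args(['--use-amp', '-b', '64'])
    print(args.use_amp, args.batch_size, args.epochs)                # True 64 20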
@@ -87,8 +99,8 @@ def collate_fn(batch):
     return batch_token_ids, batch_labels

 # convert the datasets
-train_dataloader = DataLoader(MyDataset('/bert4torch/datasets/bert-base-chinese/china-people-daily-ner-corpus/example.train'), batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
-valid_dataloader = DataLoader(MyDataset('/bert4torch/datasets/bert-base-chinese/china-people-daily-ner-corpus/example.dev'), batch_size=batch_size, collate_fn=collate_fn)
+train_dataloader = DataLoader(MyDataset(root_path + '/bert-base-chinese/china-people-daily-ner-corpus/example.train'), batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn)  # shuffle=True
+valid_dataloader = DataLoader(MyDataset(root_path + '/bert-base-chinese/china-people-daily-ner-corpus/example.dev'), batch_size=args.batch_size, collate_fn=collate_fn)

 # define the model structure on top of BERT
 class Model(BaseModel):
@@ -190,7 +202,8 @@ if __name__ == '__main__':
     evaluator = Evaluator()
-    model.fit(train_dataloader, epochs=20, steps_per_epoch=None, callbacks=[evaluator])
+    #model.fit(train_dataloader, epochs=20, steps_per_epoch=None, callbacks=[evaluator])
+    model.fit(train_dataloader, epochs=args.epochs, steps_per_epoch=None, callbacks=[evaluator])

 else:
examples/sequence_labeling/crf_ddp.py

@@ -16,18 +16,40 @@ from bert4torch.models import build_transformer_model, BaseModel
 from tqdm import tqdm
 from bert4torch.models import BaseModelDDP
 import os
+import argparse
+
+# add CLI argument switches
+parser = argparse.ArgumentParser(description='bert4torch training')
+parser.add_argument("--use-amp", action="store_true",
+                    help="Run model in AMP (automatic mixed precision) mode.")
+parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N',
+                    help='mini-batch size (default: 256), this is the total '
+                         'batch size of all GPUs on the current node when '
+                         'using Data Parallel or Distributed Data Parallel')
+parser.add_argument("--root-path", default='/root', type=str, help='root path')
+parser.add_argument('--epochs', default=20, type=int, metavar='N',
+                    help='number of total epochs to run')
+args = parser.parse_args()

 maxlen = 256
-batch_size = 64
+batch_size = args.batch_size
 categories = ['O', 'B-LOC', 'I-LOC', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG']
 categories_id2label = {i: k for i, k in enumerate(categories)}
 categories_label2id = {k: i for i, k in enumerate(categories)}

 # BERT base
 #config_path = '/datasets/bert-base-chinese/bert_config.json'
-config_path = '/bert4torch/datasets/bert-base-chinese/config.json'
-checkpoint_path = '/bert4torch/datasets/bert-base-chinese/pytorch_model.bin'
-dict_path = '/bert4torch/datasets/bert-base-chinese/vocab.txt'
+root_path = args.root_path
+config_path = root_path + '/bert-base-chinese/config.json'
+checkpoint_path = root_path + '/bert-base-chinese/pytorch_model.bin'
+dict_path = root_path + '/bert-base-chinese/vocab.txt'
 #device = 'cuda' if torch.cuda.is_available() else 'cpu'
 local_rank = int(os.environ['LOCAL_RANK'])
 print("local_rank ", local_rank)
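For context (not part of the commit): python3 -m torch.distributed.run spawns one worker per GPU and sets LOCAL_RANK, RANK, and WORLD_SIZE in each worker's environment, which is what the LOCAL_RANK read above relies on; the next hunk shows the script's own init_process_group call. A minimal sketch of the usual initialization sequence, assuming that launcher:

    import os
    import torch

    # each worker reads its device index from the launcher-provided env var
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)                     # pin this process to its GPU
    torch.distributed.init_process_group(backend='nccl')  # NCCL backend for multi-GPU collectives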
@@ -38,16 +60,6 @@ torch.distributed.init_process_group(backend='nccl')
 # fix the random seed
 seed_everything(42)

-# add the AMP argument switch
-parser = argparse.ArgumentParser(description='bert4torch training')
-#parser.add_argument('--use-amp', type=bool, default=True, help='Use automatic mixed precision (AMP)')
-parser.add_argument("--use-amp", action="store_true",
-                    help="Run model in AMP (automatic mixed precision) mode.")
-args = parser.parse_args()
-
 # load the dataset
 class MyDataset(ListDataset):
     @staticmethod
@@ -95,11 +107,10 @@ def collate_fn(batch):
     return batch_token_ids, batch_labels

 # convert the datasets
 #train_dataloader = DataLoader(MyDataset('/datasets/bert-base-chinese/china-people-daily-ner-corpus/example.train'), batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
-train_dataset = MyDataset('/bert4torch/datasets/bert-base-chinese/china-people-daily-ner-corpus/example.train')
+train_dataset = MyDataset(root_path + '/bert-base-chinese/china-people-daily-ner-corpus/example.train')
 train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
-train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, collate_fn=collate_fn)
-valid_dataloader = DataLoader(MyDataset('/bert4torch/datasets/bert-base-chinese/china-people-daily-ner-corpus/example.dev'), batch_size=batch_size, collate_fn=collate_fn)
+train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=collate_fn)
+valid_dataloader = DataLoader(MyDataset(root_path + '/bert-base-chinese/china-people-daily-ner-corpus/example.dev'), batch_size=args.batch_size, collate_fn=collate_fn)

 # define the model structure on top of BERT
 class Model(BaseModel):
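Worth noting (context, not part of the commit): with a DistributedSampler each worker iterates over a disjoint 1/world_size shard of the dataset, so batch_size here is per process and the effective global batch is batch_size x world_size; shuffling is done by the sampler (on by default) rather than shuffle=True, and reshuffling across epochs requires set_epoch. A runnable single-process sketch using the gloo backend so it works without GPUs:

    import os
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # single-process stand-in for the DDP setup so the sketch runs anywhere
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')
    torch.distributed.init_process_group('gloo', rank=0, world_size=1)

    dataset = TensorDataset(torch.arange(1000))
    sampler = torch.utils.data.distributed.DistributedSampler(dataset)  # shuffles its shard by default
    loader = DataLoader(dataset, batch_size=64, sampler=sampler)        # per-process batch size

    for epoch in range(2):
        sampler.set_epoch(epoch)  # reseed so each epoch sees a different ordering
        for (batch,) in loader:
            pass                  # training step goes here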
@@ -206,7 +217,8 @@ if __name__ == '__main__':
     evaluator = Evaluator()
-    model.fit(train_dataloader, epochs=20, steps_per_epoch=None, callbacks=[evaluator])
+    #model.fit(train_dataloader, epochs=20, steps_per_epoch=None, callbacks=[evaluator])
+    model.fit(train_dataloader, epochs=args.epochs, steps_per_epoch=None, callbacks=[evaluator])

 else:
examples/sequence_labeling/multi_train.sh

@@ -13,5 +13,5 @@ export HIP_VISIBLE_DEVICES=$(seq -s, ${START} ${LAST})
 export HSA_FORCE_FINE_GRAIN_PCIE=1
 logfile=bert_base_${NUM}dcu_`date +%Y%m%d%H%M%S`.log
-python3 -m torch.distributed.run --nproc_per_node=${NUM} crf_ddp.py 2>&1 | tee $logfile # fp32
-#python3 -m torch.distributed.run --nproc_per_node=${NUM} crf_ddp.py --use-amp 2>&1 | tee $logfile # fp16
+python3 -m torch.distributed.run --nproc_per_node=${NUM} crf_ddp.py --batch-size=64 --root-path=/bert4torch/datasets --epochs=20 2>&1 | tee $logfile # fp32
+#python3 -m torch.distributed.run --nproc_per_node=${NUM} crf_ddp.py --use-amp --batch-size=64 --root-path=/bert4torch/datasets --epochs=20 2>&1 | tee $logfile # fp16
examples/sequence_labeling/single_train.sh

 logfile=bert_base_`date +%Y%m%d%H%M%S`.log
-python3 crf.py 2>&1 | tee $logfile # fp32
-#python3 crf.py --use-amp 2>&1 | tee $logfile # fp16
+#python3 crf.py --use-amp --batch-size=64 --root-path=/bert4torch/datasets --epochs=20 2>&1 | tee $logfile # fp16
+python3 crf.py --batch-size=64 --root-path=/bert4torch/datasets --epochs=20 2>&1 | tee $logfile # fp32
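The --use-amp switch in these scripts only selects fp16 runs; how bert4torch consumes the flag internally is outside this diff. As background, an illustrative plain-PyTorch training step gated by such a flag might look like the following sketch (names and structure are assumptions, not bert4torch internals):

    import torch
    import torch.nn as nn

    use_amp = True  # stands in for args.use_amp
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = nn.Linear(16, 2).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    amp_on = use_amp and device == 'cuda'
    scaler = torch.cuda.amp.GradScaler(enabled=amp_on)

    x = torch.randn(8, 16, device=device)
    y = torch.randint(0, 2, (8,), device=device)

    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=amp_on):  # ops run in fp16 where safe
        loss = nn.functional.cross_entropy(model(x), y)
    scaler.scale(loss).backward()  # scale loss to avoid fp16 gradient underflow
    scaler.step(optimizer)         # unscales gradients, then applies the update
    scaler.update()                # adapts the scale factor for the next step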