chenzk / bert4torch_pytorch

Commit 13c631c3, authored Jan 17, 2024 by yangzhong
Add a command-line switch to control AMP (添加amp参数开关控制)
parent 4edfa95d
Showing 4 changed files with 36 additions and 4 deletions:

  examples/sequence_labeling/crf.py           +16  -1
  examples/sequence_labeling/crf_ddp.py       +16  -1
  examples/sequence_labeling/multi_train.sh    +2  -1
  examples/sequence_labeling/single_train.sh   +2  -1
examples/sequence_labeling/crf.py

@@ -30,6 +30,16 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
 # fix the random seed
 seed_everything(42)
+
+# add a command-line switch for AMP
+parser = argparse.ArgumentParser(description='bert4torch training')
+#parser.add_argument('--use-amp', type=bool, default=True, help='Use automatic mixed precision (AMP)')
+parser.add_argument(
+    "--use-amp",
+    action="store_true",
+    help="Run model AMP (automatic mixed precision) mode.",
+)
+args = parser.parse_args()
 
 # load the dataset
 class MyDataset(ListDataset):
     @staticmethod
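The commented-out type=bool variant was dropped in favour of action="store_true", presumably because argparse applies bool() to the raw string, so any non-empty value (including "False") parses as True. A minimal standalone sketch (hypothetical snippet, not part of the repo) of the behaviour the new flag gives:

import argparse

parser = argparse.ArgumentParser(description='bert4torch training')
parser.add_argument("--use-amp", action="store_true",
                    help="Run in AMP (automatic mixed precision) mode.")

print(parser.parse_args([]).use_amp)             # False: flag omitted
print(parser.parse_args(["--use-amp"]).use_amp)  # True: flag passed

With store_true the option takes no value and defaults to False, so fp32 remains the default training mode unless --use-amp is given on the command line.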
@@ -107,7 +117,12 @@ class Loss(nn.Module):
     def forward(self, outputs, labels):
         return model.crf(*outputs, labels)
 
-model.compile(loss=Loss(), optimizer=optim.Adam(model.parameters(), lr=2e-5))  # fp32
+if args.use_amp:
+    model.compile(loss=Loss(), optimizer=optim.Adam(model.parameters(), lr=2e-5), use_amp=True)   # train with AMP (fp16)
+else:
+    model.compile(loss=Loss(), optimizer=optim.Adam(model.parameters(), lr=2e-5), use_amp=False)  # train without AMP (fp32)
+# model.compile(loss=Loss(), optimizer=optim.Adam(model.parameters(), lr=2e-5))  # fp32
 # model.compile(loss=Loss(), optimizer=optim.Adam(model.parameters(), lr=2e-5), use_amp=True)  # fp16
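Since the two branches differ only in the value passed to use_amp, the same behaviour could be expressed with a single call that forwards the parsed flag; a sketch against the definitions already present in crf.py, assuming compile() accepts use_amp exactly as used in the diff:

model.compile(
    loss=Loss(),
    optimizer=optim.Adam(model.parameters(), lr=2e-5),
    use_amp=args.use_amp,  # True when launched with --use-amp, otherwise False
)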
examples/sequence_labeling/crf_ddp.py

@@ -38,6 +38,16 @@ torch.distributed.init_process_group(backend='nccl')
 # fix the random seed
 seed_everything(42)
+
+# add a command-line switch for AMP
+parser = argparse.ArgumentParser(description='bert4torch training')
+#parser.add_argument('--use-amp', type=bool, default=True, help='Use automatic mixed precision (AMP)')
+parser.add_argument(
+    "--use-amp",
+    action="store_true",
+    help="Run model AMP (automatic mixed precision) mode.",
+)
+args = parser.parse_args()
 
 # load the dataset
 class MyDataset(ListDataset):
     @staticmethod

@@ -120,7 +130,12 @@ class Loss(nn.Module):
     def forward(self, outputs, labels):
         return model.module.crf(*outputs, labels)
 
-model.compile(loss=Loss(), optimizer=optim.Adam(model.parameters(), lr=2e-5))  # fp32
+if args.use_amp:
+    model.compile(loss=Loss(), optimizer=optim.Adam(model.parameters(), lr=2e-5), use_amp=True)   # train with AMP (fp16)
+else:
+    model.compile(loss=Loss(), optimizer=optim.Adam(model.parameters(), lr=2e-5), use_amp=False)  # train without AMP (fp32)
+# model.compile(loss=Loss(), optimizer=optim.Adam(model.parameters(), lr=2e-5))  # fp32
 # define the loss and optimizer to use; custom implementations are supported
 # model.compile(loss=Loss(), optimizer=optim.Adam(model.parameters(), lr=2e-5), use_amp=True)  # fp16
 #compile(self, loss, optimizer, scheduler=None, max_grad_norm=None, use_amp=False, metrics=None, adversarial_train={'name': ''}):
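The trailing signature comment shows that compile() defaults to use_amp=False. The use_amp switch presumably builds on PyTorch's native mixed-precision utilities; a rough, illustrative sketch of that underlying pattern (placeholder function, not bert4torch's actual implementation):

import torch

def train_one_epoch(model, dataloader, criterion, optimizer, use_amp: bool):
    # Illustrative AMP loop using torch.cuda.amp; the arguments are placeholders.
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)    # no-op scaler when use_amp is False
    for inputs, labels in dataloader:
        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=use_amp):     # run the forward pass in fp16 where safe
            loss = criterion(model(inputs), labels)
        scaler.scale(loss).backward()                      # scale the loss to avoid fp16 gradient underflow
        scaler.step(optimizer)                             # unscale gradients and apply the optimizer step
        scaler.update()                                    # adjust the scale factor for the next iteration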
examples/sequence_labeling/multi_train.sh

@@ -13,4 +13,5 @@ export HIP_VISIBLE_DEVICES=$(seq -s, ${START} ${LAST})
 export HSA_FORCE_FINE_GRAIN_PCIE=1
 logfile=bert_base_${NUM}dcu_`date +%Y%m%d%H%M%S`.log
-python3 -m torch.distributed.run --nproc_per_node=${NUM} crf_ddp.py 2>&1 | tee $logfile
+python3 -m torch.distributed.run --nproc_per_node=${NUM} crf_ddp.py 2>&1 | tee $logfile  # fp32
+#python3 -m torch.distributed.run --nproc_per_node=${NUM} crf_ddp.py --use-amp 2>&1 | tee $logfile  # fp16
examples/sequence_labeling/single_train.sh

 logfile=bert_base_`date +%Y%m%d%H%M%S`.log
-python3 crf.py 2>&1 | tee $logfile
+python3 crf.py 2>&1 | tee $logfile  # fp32
+#python3 crf.py --use-amp 2>&1 | tee $logfile  # fp16