chenzk / bert4torch_pytorch · Commits

Commit 5f82b770, authored Jan 17, 2024 by yangzhong

Add AMP parameter switch control

Parent: 4edfa95d
Showing 1 changed file with 16 additions and 1 deletion.

examples/sequence_labeling/crf.py (+16, -1)
@@ -30,6 +30,16 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
 # Fix the random seed
 seed_everything(42)
 
+# Add an AMP on/off switch
+parser = argparse.ArgumentParser(description='bert4torch training')
+# parser.add_argument('--use-amp', type=bool, default=True, help='Use automatic mixed precision (AMP)')
+parser.add_argument(
+    "--use-amp",
+    action="store_true",
+    help="Run the model in AMP (automatic mixed precision) mode.",
+)
+args = parser.parse_args()
+
 # Load the dataset
 class MyDataset(ListDataset):
     @staticmethod
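Two details of this flag are worth noting, both standard argparse behavior: the dash in --use-amp becomes an underscore, so the value is read back as args.use_amp, and action="store_true" defaults to False, so AMP stays off unless the flag is passed. The commented-out type=bool variant above was worth replacing, because argparse applies bool() to the raw command-line string and any non-empty string, including "False", is truthy. A minimal stand-alone sketch (not part of the commit; the --broken-amp flag is hypothetical) demonstrating both points:

import argparse

parser = argparse.ArgumentParser()
# Broken pattern: bool('False') is True, so this flag can never be switched off.
parser.add_argument('--broken-amp', type=bool, default=True)
# Pattern used by the commit: flag absent -> False, flag present -> True.
parser.add_argument('--use-amp', action='store_true')

args = parser.parse_args(['--broken-amp', 'False'])
print(args.broken_amp)  # True, despite passing the string 'False'
print(args.use_amp)     # False, because --use-amp was not passed

With this wiring, python examples/sequence_labeling/crf.py trains in fp32 as before, and python examples/sequence_labeling/crf.py --use-amp switches training to mixed precision.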
@@ -107,7 +117,12 @@ class Loss(nn.Module):
     def forward(self, outputs, labels):
         return model.crf(*outputs, labels)
 
-model.compile(loss=Loss(), optimizer=optim.Adam(model.parameters(), lr=2e-5))  # fp32
+if args.use_amp:
+    model.compile(loss=Loss(), optimizer=optim.Adam(model.parameters(), lr=2e-5), use_amp=True)  # train with AMP (fp16)
+else:
+    model.compile(loss=Loss(), optimizer=optim.Adam(model.parameters(), lr=2e-5), use_amp=False)  # train without AMP (fp32)
+# model.compile(loss=Loss(), optimizer=optim.Adam(model.parameters(), lr=2e-5))  # fp32
+# model.compile(loss=Loss(), optimizer=optim.Adam(model.parameters(), lr=2e-5), use_amp=True)  # fp16
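Since args.use_amp is already a boolean, the two branches differ only in the literal passed to use_amp, so the whole if/else could collapse to a single call. A minimal sketch of that equivalent wiring, assuming model.compile accepts use_amp exactly as shown in the diff:

# Pass the parsed flag straight through instead of branching on it.
model.compile(loss=Loss(), optimizer=optim.Adam(model.parameters(), lr=2e-5), use_amp=args.use_amp)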
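For context, model.compile's internals are not part of this diff, so the following is an assumption about what use_amp=True typically enables rather than a description of bert4torch's actual implementation: a torch.cuda.amp training step that runs the forward pass under autocast and scales the loss so fp16 gradients do not underflow. A self-contained sketch in plain PyTorch (train_step and its arguments are hypothetical):

import torch

scaler = torch.cuda.amp.GradScaler()

def train_step(model, inputs, labels, optimizer, loss_fn):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():    # forward pass in mixed fp16/fp32
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
    scaler.scale(loss).backward()      # scale the loss before backward to avoid fp16 underflow
    scaler.step(optimizer)             # unscales gradients, then calls optimizer.step()
    scaler.update()                    # adapt the scale factor for the next iteration
    return loss.item()

With use_amp=False the same step reduces to the ordinary loss.backward(); optimizer.step() path in full fp32, which is why the commit's else branch matches the deleted original line.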