OpenDAS / Uni-Core, commit 2376ef4a (unverified)
Authored Jun 02, 2023 by Guolin Ke; committed by GitHub on Jun 02, 2023
Parent: 854b8890

support Wandb (#29)

* add wandb support
* code clean
Changes: 4 changed files, 101 additions and 72 deletions (+101 -72)

  examples/bert/train_bert_test.sh   +6   -3
  unicore/logging/progress_bar.py    +39  -38
  unicore/options.py                 +20  -9
  unicore_cli/train.py               +36  -22
examples/bert/train_bert_test.sh

@@ -2,15 +2,18 @@
 [ -z "${n_gpu}" ] && n_gpu=$(nvidia-smi -L | wc -l)
 export NCCL_ASYNC_ERROR_HANDLING=1
 export OMP_NUM_THREADS=1
+run_name=bert_example
+save_dir="./save/${run_name}"
+mkdir -p ${save_dir}
 python -m torch.distributed.launch --nproc_per_node=$n_gpu --master_port=$MASTER_PORT $(which unicore-train) ./example_data --user-dir . --valid-subset valid \
        --num-workers 0 --ddp-backend=c10d \
        --task bert --loss masked_lm --arch bert_base \
        --optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-6 --clip-norm 1.0 \
        --lr-scheduler polynomial_decay --lr 1e-4 --warmup-updates 100 --total-num-update 10000 --batch-size 4 \
        --update-freq 1 --seed 1 \
-       --fp16 --fp16-init-scale 4 --fp16-scale-window 256 --tensorboard-logdir ./tsb/ \
+       --fp16 --fp16-init-scale 4 --fp16-scale-window 256 --tensorboard-logdir $save_dir/tsb \
        --max-update 10000 --log-interval 100 --log-format simple \
-       --save-interval-updates 5000 --validate-interval-updates 5000 --keep-interval-updates 30 --no-epoch-checkpoints \
-       --save-dir ./save
+       --save-interval-updates 1000 --validate-interval-updates 1000 --keep-interval-updates 30 --no-epoch-checkpoints \
+       --save-dir $save_dir
unicore/logging/progress_bar.py

@@ -33,7 +33,9 @@ def progress_bar(
     epoch: Optional[int] = None,
     prefix: Optional[str] = None,
     tensorboard_logdir: Optional[str] = None,
+    wandb_project: Optional[str] = None,
     default_log_format: str = "tqdm",
+    args=None,
 ):
     if log_format is None:
         log_format = default_log_format

@@ -52,44 +54,13 @@ def progress_bar(
         raise ValueError("Unknown log format: {}".format(log_format))

     if tensorboard_logdir:
-        try:
-            # [FB only] custom wrapper for TensorBoard
-            import palaas  # noqa
-            from .fb_tbmf_wrapper import FbTbmfWrapper
-
-            bar = FbTbmfWrapper(bar, log_interval)
-        except ImportError:
-            bar = TensorboardProgressBarWrapper(bar, tensorboard_logdir)
+        bar = TensorboardProgressBarWrapper(bar, tensorboard_logdir, wandb_project, args)

     return bar

-def build_progress_bar(
-    args,
-    iterator,
-    epoch: Optional[int] = None,
-    prefix: Optional[str] = None,
-    default: str = "tqdm",
-    no_progress_bar: str = "none",
-):
-    """Legacy wrapper that takes an argparse.Namespace."""
-    if getattr(args, "no_progress_bar", False):
-        default = no_progress_bar
-    if getattr(args, "distributed_rank", 0) == 0:
-        tensorboard_logdir = getattr(args, "tensorboard_logdir", None)
-    else:
-        tensorboard_logdir = None
-    return progress_bar(
-        iterator,
-        log_format=args.log_format,
-        log_interval=args.log_interval,
-        epoch=epoch,
-        prefix=prefix,
-        tensorboard_logdir=tensorboard_logdir,
-        default_log_format=default,
-    )

 def format_stat(stat):
     if isinstance(stat, Number):
         stat = "{:g}".format(stat)

@@ -306,10 +277,23 @@ except ImportError:
 except ImportError:
     SummaryWriter = None

+try:
+    _wandb_inited = False
+    import wandb
+
+    wandb_available = True
+except ImportError:
+    wandb_available = False

 def _close_writers():
     for w in _tensorboard_writers.values():
         w.close()
+    if _wandb_inited:
+        try:
+            wandb.finish()
+        except:
+            pass

 atexit.register(_close_writers)

@@ -318,7 +302,7 @@ atexit.register(_close_writers)
 class TensorboardProgressBarWrapper(BaseProgressBar):
     """Log to tensorboard."""

-    def __init__(self, wrapped_bar, tensorboard_logdir):
+    def __init__(self, wrapped_bar, tensorboard_logdir, wandb_project, args):
         self.wrapped_bar = wrapped_bar
         self.tensorboard_logdir = tensorboard_logdir

@@ -326,6 +310,17 @@ class TensorboardProgressBarWrapper(BaseProgressBar):
             logger.warning(
                 "tensorboard not found, please install with: pip install tensorboard"
             )
+        global _wandb_inited
+        if not _wandb_inited and wandb_project and wandb_available:
+            wandb_name = args.wandb_name or wandb.util.generate_id()
+            wandb.init(
+                project=wandb_project,
+                name=wandb_name,
+                config=vars(args),
+                id=wandb_name,
+                resume="allow",
+            )
+            _wandb_inited = True

     def _writer(self, key):
         if SummaryWriter is None:

@@ -362,9 +357,15 @@ class TensorboardProgressBarWrapper(BaseProgressBar):
         step = stats["num_updates"]
         for key in stats.keys() - {"num_updates"}:
             if isinstance(stats[key], AverageMeter):
-                writer.add_scalar(key, stats[key].val, step)
+                val = stats[key].val
             elif isinstance(stats[key], Number):
-                writer.add_scalar(key, stats[key], step)
+                val = stats[key]
             elif torch.is_tensor(stats[key]) and stats[key].numel() == 1:
-                writer.add_scalar(key, stats[key].item(), step)
+                val = stats[key].item()
+            else:
+                val = None
+            if val:
+                writer.add_scalar(key, val, step)
+                if _wandb_inited:
+                    wandb.log({"{}_{}".format(tag, key): val}, step=step)
         writer.flush()
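For context on the pattern the new wrapper follows: the W&B run is initialized once per process, the run name doubles as the run id so a restarted job resumes the same run (resume="allow"), scalars are logged under "<tag>_<key>" at the trainer's num_updates step, and wandb.finish() is called from an atexit hook. Below is a minimal standalone sketch of those wandb calls, assuming wandb is installed; the project name, config values, and offline mode are illustrative and not part of the patch.

    import wandb

    # when --wandb-name is empty, the patch falls back to a generated id
    run_name = wandb.util.generate_id()
    wandb.init(
        project="bert_example",                # would come from --wandb-project
        name=run_name,
        id=run_name,                           # fixed id + resume="allow" lets a restarted
        resume="allow",                        # job keep logging into the same run
        config={"lr": 1e-4, "batch_size": 4},  # the patch passes vars(args) here
        mode="offline",                        # illustrative, so the sketch runs without a login
    )

    for step in range(1, 4):
        # the wrapper logs each scalar as "<tag>_<key>" at the trainer's update step
        wandb.log({"train_loss": 1.0 / step}, step=step)

    wandb.finish()                             # the patch registers this via atexit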
unicore/options.py

@@ -10,8 +10,15 @@ import torch
 from typing import Callable, List, Optional

 # this import is for backward compatibility
-from unicore.utils import csv_str_list, eval_bool, eval_str_dict, eval_str_list, import_user_module  # noqa
+from unicore.utils import (
+    csv_str_list,
+    eval_bool,
+    eval_str_dict,
+    eval_str_list,
+    import_user_module,
+)  # noqa

 def get_training_parser(default_task="translation"):

@@ -137,7 +144,7 @@ def parse_args_and_arch(
         args.no_seed_provided = True
     else:
         args.no_seed_provided = False
     args.validate_with_ema = getattr(args, "validate_with_ema", False)
     # Apply architecture configuration.
     if hasattr(args, "arch") and args.arch in ARCH_CONFIG_REGISTRY:

@@ -149,11 +156,11 @@
     return args

-def get_parser(desc, default_task='test'):
+def get_parser(desc, default_task="test"):
     # Before creating the true parser, we need to import optional user module
     # in order to eagerly import custom tasks, optimizers, architectures, etc.
     usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
-    usr_parser.add_argument('--user-dir', default=None)
+    usr_parser.add_argument("--user-dir", default=None)
     usr_args, _ = usr_parser.parse_known_args()
     import_user_module(usr_args)

@@ -167,6 +174,10 @@ def get_parser(desc, default_task='test'):
     parser.add_argument('--tensorboard-logdir', metavar='DIR', default='',
                         help='path to save logs for tensorboard, should match --logdir '
                              'of running tensorboard (default: no tensorboard logging)')
+    parser.add_argument('--wandb-project', metavar='DIR', default='',
+                        help='name of wandb project, empty for no wandb logging, for wandb login, use env WANDB_API_KEY')
+    parser.add_argument('--wandb-name', metavar='DIR', default='',
+                        help='wandb run/id name, empty for no wandb logging, for wandb login, use env WANDB_API_KEY')
     parser.add_argument('--seed', default=1, type=int, metavar='N',
                         help='pseudo random number generator seed')
     parser.add_argument('--cpu', action='store_true', help='use CPU instead of CUDA')

@@ -216,7 +227,7 @@ def get_parser(desc, default_task='test'):
 def add_dataset_args(parser, train=False, gen=False):
-    group = parser.add_argument_group('Dataset and data loading')
+    group = parser.add_argument_group("Dataset and data loading")
     # fmt: off
     group.add_argument('--num-workers', default=1, type=int, metavar='N',
                        help='how many subprocesses to use for data loading')

@@ -256,7 +267,7 @@ def add_dataset_args(parser, train=False, gen=False):
 def add_distributed_training_args(parser):
-    group = parser.add_argument_group('Distributed training')
+    group = parser.add_argument_group("Distributed training")
     # fmt: off
     group.add_argument('--distributed-world-size', type=int, metavar='N',
                        default=max(1, torch.cuda.device_count()),

@@ -301,7 +312,7 @@ def add_distributed_training_args(parser):
 def add_optimization_args(parser):
-    group = parser.add_argument_group('Optimization')
+    group = parser.add_argument_group("Optimization")
     # fmt: off
     group.add_argument('--max-epoch', '--me', default=0, type=int, metavar='N',
                        help='force stop training at specified epoch')

@@ -327,7 +338,7 @@ def add_optimization_args(parser):
 def add_checkpoint_args(parser):
-    group = parser.add_argument_group('Checkpointing')
+    group = parser.add_argument_group("Checkpointing")
     # fmt: off
     group.add_argument('--save-dir', metavar='DIR', default='checkpoints',
                        help='path to save checkpoints')

@@ -397,7 +408,7 @@ def add_common_eval_args(group):
 def add_model_args(parser):
-    group = parser.add_argument_group('Model configuration')
+    group = parser.add_argument_group("Model configuration")
     # fmt: off
     # Model definitions can be found under unicore/models/
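Both new options default to the empty string, so W&B logging stays disabled unless --wandb-project is supplied; authentication is expected through the WANDB_API_KEY environment variable rather than a flag. A small sketch of how the flags parse, using a bare argparse parser as a stand-in for get_parser():

    import argparse

    # stand-in for the two options added above; '' means "no wandb logging"
    parser = argparse.ArgumentParser()
    parser.add_argument('--wandb-project', metavar='DIR', default='')
    parser.add_argument('--wandb-name', metavar='DIR', default='')

    args = parser.parse_args(['--wandb-project', 'bert_example'])
    if args.wandb_project:  # progress_bar() only initializes wandb when this is non-empty
        print('wandb project:', args.wandb_project, '| run name:', args.wandb_name or '<generated id>')
    else:
        print('wandb logging disabled')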
unicore_cli/train.py

@@ -40,7 +40,6 @@ logger = logging.getLogger("unicore_cli.train")
 def main(args) -> None:
     utils.import_user_module(args)
     utils.set_jit_fusion_options()

@@ -84,17 +83,17 @@ def main(args) -> None:
     logger.info(
         "num. model params: {:,} (num. trained: {:,})".format(
             sum(getattr(p, "_orig_size", p).numel() for p in model.parameters()),
             sum(
                 getattr(p, "_orig_size", p).numel()
                 for p in model.parameters()
                 if p.requires_grad
             ),
         )
     )

     # Build trainer
     trainer = Trainer(args, task, model, loss)
-    logger.info(
-        "training on {} devices (GPUs)".format(args.distributed_world_size)
-    )
+    logger.info("training on {} devices (GPUs)".format(args.distributed_world_size))
     logger.info(
         "batch size per device = {}".format(
             args.batch_size,

@@ -123,7 +122,9 @@ def main(args) -> None:
             break

         # train for one epoch
-        valid_losses, should_stop = train(args, trainer, task, epoch_itr, ckp_copy_thread)
+        valid_losses, should_stop = train(
+            args, trainer, task, epoch_itr, ckp_copy_thread
+        )

         if should_stop:
             break

@@ -194,11 +195,13 @@ def train(
         log_interval=args.log_interval,
         epoch=epoch_itr.epoch,
         tensorboard_logdir=(
             args.tensorboard_logdir if distributed_utils.is_master(args) else None
         ),
+        wandb_project=(
+            args.wandb_project if distributed_utils.is_master(args) else None
+        ),
         default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
+        args=args,
     )

     trainer.begin_epoch(epoch_itr.epoch)

@@ -267,10 +270,7 @@ def validate_and_save(
     )
     training_time_hours = trainer.cumulative_training_time() / (60 * 60)
-    if (
-        args.stop_time_hours > 0
-        and training_time_hours > args.stop_time_hours
-    ):
+    if args.stop_time_hours > 0 and training_time_hours > args.stop_time_hours:
         should_stop = True
         logger.info(
             f"Stopping training due to "

@@ -279,7 +279,11 @@ def validate_and_save(
     )
     do_save = (
-        (end_of_epoch and epoch_itr.epoch % args.save_interval == 0 and not args.no_epoch_checkpoints)
+        (
+            end_of_epoch
+            and epoch_itr.epoch % args.save_interval == 0
+            and not args.no_epoch_checkpoints
+        )
         or should_stop
         or (
             args.save_interval_updates > 0

@@ -290,7 +294,11 @@ def validate_and_save(
     )
     do_validate = (
         (not end_of_epoch and do_save)  # validate during mid-epoch saves
-        or (end_of_epoch and epoch_itr.epoch % args.validate_interval == 0 and not args.no_epoch_checkpoints)
+        or (
+            end_of_epoch
+            and epoch_itr.epoch % args.validate_interval == 0
+            and not args.no_epoch_checkpoints
+        )
         or should_stop
         or (
             args.validate_interval_updates > 0

@@ -309,7 +317,12 @@ def validate_and_save(
         # Save checkpoint
         checkpoint_utils.save_checkpoint(
-            args, trainer, epoch_itr, valid_losses[0], ckp_copy_thread, do_save=(do_save or should_stop)
+            args,
+            trainer,
+            epoch_itr,
+            valid_losses[0],
+            ckp_copy_thread,
+            do_save=(do_save or should_stop),
         )

     return valid_losses, should_stop

@@ -377,11 +390,12 @@ def validate(
     return valid_losses

-def get_valid_stats(
-    args, trainer: Trainer, stats: Dict[str, Any]
-) -> Dict[str, Any]:
+def get_valid_stats(args, trainer: Trainer, stats: Dict[str, Any]) -> Dict[str, Any]:
     stats["num_updates"] = trainer.get_num_updates()
-    if hasattr(checkpoint_utils.save_checkpoint, "best") and args.best_checkpoint_metric in stats:
+    if (
+        hasattr(checkpoint_utils.save_checkpoint, "best")
+        and args.best_checkpoint_metric in stats
+    ):
         key = "best_{0}".format(args.best_checkpoint_metric)
         best_function = max if args.maximize_best_checkpoint_metric else min
         stats[key] = best_function(