ModelZoo / arcface_pytorch · Commits · 9b87ce51

Commit 9b87ce51, authored May 30, 2024 by dongchy920
Commit message: arcface
Showing 4 changed files with 309 additions and 0 deletions (+309 -0):

utils/utils_callbacks.py             +125 -0
utils/utils_config.py                 +17 -0
utils/utils_distributed_sampler.py   +126 -0
utils/utils_logging.py                +41 -0

utils/utils_callbacks.py  0 → 100644

import logging
import os
import time
from typing import List

import torch

from eval import verification
from utils.utils_logging import AverageMeter
from torch.utils.tensorboard import SummaryWriter
from torch import distributed


class CallBackVerification(object):
    # Periodically evaluates the backbone on the verification .bin sets (rank 0 only)
    # and tracks the highest accuracy seen per set.
    def __init__(self, val_targets, rec_prefix, summary_writer=None,
                 image_size=(112, 112), wandb_logger=None):
        self.rank: int = distributed.get_rank()
        self.highest_acc: float = 0.0
        self.highest_acc_list: List[float] = [0.0] * len(val_targets)
        self.ver_list: List[object] = []
        self.ver_name_list: List[str] = []
        if self.rank == 0:
            self.init_dataset(val_targets=val_targets, data_dir=rec_prefix, image_size=image_size)

        self.summary_writer = summary_writer
        self.wandb_logger = wandb_logger

    def ver_test(self, backbone: torch.nn.Module, global_step: int):
        results = []
        for i in range(len(self.ver_list)):
            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
                self.ver_list[i], backbone, 10, 10)
            logging.info('[%s][%d]XNorm: %f' % (self.ver_name_list[i], global_step, xnorm))
            logging.info('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (
                self.ver_name_list[i], global_step, acc2, std2))

            self.summary_writer: SummaryWriter
            self.summary_writer.add_scalar(
                tag=self.ver_name_list[i],
                scalar_value=acc2,
                global_step=global_step,
            )
            if self.wandb_logger:
                import wandb
                self.wandb_logger.log({
                    f'Acc/val-Acc1 {self.ver_name_list[i]}': acc1,
                    f'Acc/val-Acc2 {self.ver_name_list[i]}': acc2,
                    # f'Acc/val-std1 {self.ver_name_list[i]}': std1,
                    # f'Acc/val-std2 {self.ver_name_list[i]}': acc2,
                })

            if acc2 > self.highest_acc_list[i]:
                self.highest_acc_list[i] = acc2
            logging.info('[%s][%d]Accuracy-Highest: %1.5f' % (
                self.ver_name_list[i], global_step, self.highest_acc_list[i]))
            results.append(acc2)

    def init_dataset(self, val_targets, data_dir, image_size):
        for name in val_targets:
            path = os.path.join(data_dir, name + ".bin")
            if os.path.exists(path):
                data_set = verification.load_bin(path, image_size)
                self.ver_list.append(data_set)
                self.ver_name_list.append(name)

    def __call__(self, num_update, backbone: torch.nn.Module):
        if self.rank == 0 and num_update > 0:
            backbone.eval()
            self.ver_test(backbone, num_update)
            backbone.train()


class CallBackLogging(object):
    # Logs training throughput, loss, learning rate and the estimated time to
    # finish every `frequent` steps (rank 0 only).
    def __init__(self, frequent, total_step, batch_size, start_step=0, writer=None):
        self.frequent: int = frequent
        self.rank: int = distributed.get_rank()
        self.world_size: int = distributed.get_world_size()
        self.time_start = time.time()
        self.total_step: int = total_step
        self.start_step: int = start_step
        self.batch_size: int = batch_size
        self.writer = writer

        self.init = False
        self.tic = 0

    def __call__(self,
                 global_step: int,
                 loss: AverageMeter,
                 epoch: int,
                 fp16: bool,
                 learning_rate: float,
                 grad_scaler: torch.cuda.amp.GradScaler):
        if self.rank == 0 and global_step > 0 and global_step % self.frequent == 0:
            if self.init:
                try:
                    speed: float = self.frequent * self.batch_size / (time.time() - self.tic)
                    speed_total = speed * self.world_size
                except ZeroDivisionError:
                    speed_total = float('inf')

                # time_now = (time.time() - self.time_start) / 3600
                # time_total = time_now / ((global_step + 1) / self.total_step)
                # time_for_end = time_total - time_now
                time_now = time.time()
                time_sec = int(time_now - self.time_start)
                time_sec_avg = time_sec / (global_step - self.start_step + 1)
                eta_sec = time_sec_avg * (self.total_step - global_step - 1)
                time_for_end = eta_sec / 3600
                if self.writer is not None:
                    self.writer.add_scalar('time_for_end', time_for_end, global_step)
                    self.writer.add_scalar('learning_rate', learning_rate, global_step)
                    self.writer.add_scalar('loss', loss.avg, global_step)
                if fp16:
                    msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.6f Epoch: %d Global Step: %d " \
                          "Fp16 Grad Scale: %2.f Required: %1.f hours" % (
                              speed_total, loss.avg, learning_rate, epoch, global_step,
                              grad_scaler.get_scale(), time_for_end)
                else:
                    msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.6f Epoch: %d Global Step: %d " \
                          "Required: %1.f hours" % (
                              speed_total, loss.avg, learning_rate, epoch, global_step, time_for_end)
                logging.info(msg)
                loss.reset()
                self.tic = time.time()
            else:
                self.init = True
                self.tic = time.time()
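
The two callbacks above are meant to be driven from the training loop. Below is a minimal usage sketch; the process-group setup, the stand-in backbone, the validation set names, the dataset directory, the placeholder loss and the hyperparameter values are illustrative assumptions, not part of this commit.

import torch
from torch import distributed

from utils.utils_callbacks import CallBackLogging, CallBackVerification
from utils.utils_logging import AverageMeter

# Assumes a torchrun-style launch that sets the usual distributed env vars.
distributed.init_process_group(backend="nccl")

backbone = torch.nn.Linear(512, 512).cuda()       # stand-in for the real ArcFace backbone
total_step = 10000
callback_logging = CallBackLogging(frequent=100, total_step=total_step, batch_size=128)
callback_verification = CallBackVerification(
    val_targets=["lfw", "cfp_fp", "agedb_30"],    # assumed verification .bin names
    rec_prefix="/data/faces_emore",               # assumed dataset directory
)

loss_am = AverageMeter()
amp = torch.cuda.amp.GradScaler()
for global_step in range(total_step):
    loss = torch.tensor(0.5)                      # placeholder for the real training loss
    loss_am.update(loss.item(), 1)
    callback_logging(global_step, loss_am, epoch=0, fp16=True,
                     learning_rate=0.1, grad_scaler=amp)
    if global_step % 2000 == 0:
        callback_verification(global_step, backbone)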

utils/utils_config.py  0 → 100644

import importlib
import os.path as osp


def get_config(config_file):
    assert config_file.startswith('configs/'), 'config file setting must start with configs/'
    temp_config_name = osp.basename(config_file)
    temp_module_name = osp.splitext(temp_config_name)[0]
    config = importlib.import_module("configs.base")
    cfg = config.config
    config = importlib.import_module("configs.%s" % temp_module_name)
    job_cfg = config.config
    cfg.update(job_cfg)
    if cfg.output is None:
        cfg.output = osp.join('work_dirs', temp_module_name)
    return cfg
\ No newline at end of file
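
get_config merges the shared settings from configs/base.py with a job-specific module under configs/ and derives a default output directory. A sketch of what that pair of modules might look like follows; the EasyDict-based config object, the field names and the configs/ms1mv3_r50.py module name are assumptions for illustration, not files introduced by this commit.

# configs/base.py (assumed layout)
from easydict import EasyDict as edict

config = edict()
config.output = None        # filled in by get_config() as work_dirs/<job name>
config.network = "r50"
config.lr = 0.1

# configs/ms1mv3_r50.py (assumed job config; its values override the base ones)
from easydict import EasyDict as edict

config = edict()
config.lr = 0.02

# train.py
from utils.utils_config import get_config

cfg = get_config("configs/ms1mv3_r50.py")
print(cfg.output)           # -> work_dirs/ms1mv3_r50
print(cfg.lr)               # -> 0.02 (job value wins over the base value)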

utils/utils_distributed_sampler.py  0 → 100644

import math
import os
import random

import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import DistributedSampler as _DistributedSampler


def setup_seed(seed, cuda_deterministic=True):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    if cuda_deterministic:
        # slower, more reproducible
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    else:
        # faster, less reproducible
        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.benchmark = True


def worker_init_fn(worker_id, num_workers, rank, seed):
    # The seed of each worker equals to
    # num_worker * rank + worker_id + user_seed
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)
    torch.manual_seed(worker_seed)


def get_dist_info():
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1

    return rank, world_size


def sync_random_seed(seed=None, device="cuda"):
    """Make sure different ranks share the same seed.

    All workers must call this function, otherwise it will deadlock.
    This method is generally used in `DistributedSampler`,
    because the seed should be identical across all processes
    in the distributed group.

    In distributed sampling, different ranks should sample non-overlapped
    data in the dataset. Therefore, this function is used to make sure that
    each rank shuffles the data indices in the same order based
    on the same seed. Then different ranks could use different indices
    to select non-overlapped data from the same data list.

    Args:
        seed (int, Optional): The seed. Default to None.
        device (str): The device where the seed will be put on.
            Default to 'cuda'.

    Returns:
        int: Seed to be used.
    """
    if seed is None:
        seed = np.random.randint(2 ** 31)
    assert isinstance(seed, int)

    rank, world_size = get_dist_info()

    if world_size == 1:
        return seed

    if rank == 0:
        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
    else:
        random_num = torch.tensor(0, dtype=torch.int32, device=device)

    dist.broadcast(random_num, src=0)

    return random_num.item()


class DistributedSampler(_DistributedSampler):
    def __init__(
        self,
        dataset,
        num_replicas=None,  # world_size
        rank=None,  # local_rank
        shuffle=True,
        seed=0,
    ):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)

        # In distributed sampling, different ranks should sample
        # non-overlapped data in the dataset. Therefore, this function
        # is used to make sure that each rank shuffles the data indices
        # in the same order based on the same seed. Then different ranks
        # could use different indices to select non-overlapped data from the
        # same data list.
        self.seed = sync_random_seed(seed)

    def __iter__(self):
        # deterministically shuffle based on epoch
        if self.shuffle:
            g = torch.Generator()
            # When :attr:`shuffle=True`, this ensures all replicas
            # use a different random ordering for each epoch.
            # Otherwise, the next iteration of this sampler will
            # yield the same ordering.
            g.manual_seed(self.epoch + self.seed)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = torch.arange(len(self.dataset)).tolist()

        # add extra samples to make it evenly divisible
        # in case that indices is shorter than half of total_size
        indices = (indices * math.ceil(self.total_size / len(indices)))[:self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)
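
The pieces in this file plug into a standard DataLoader: setup_seed fixes the global seeds, DistributedSampler partitions indices across ranks, and worker_init_fn can be bound with functools.partial so each loader worker gets its own derived seed. A usage sketch under those assumptions follows, with a toy in-memory dataset standing in for the real training data.

from functools import partial

import torch
from torch.utils.data import DataLoader, TensorDataset

from utils.utils_distributed_sampler import (
    DistributedSampler, get_dist_info, setup_seed, worker_init_fn)

seed = 2048
rank, world_size = get_dist_info()           # falls back to (0, 1) without a process group
setup_seed(seed, cuda_deterministic=False)

dataset = TensorDataset(torch.randn(1024, 512))   # stand-in for the real face dataset
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank,
                             shuffle=True, seed=seed)

num_workers = 2
init_fn = partial(worker_init_fn, num_workers=num_workers, rank=rank, seed=seed)
loader = DataLoader(dataset, batch_size=128, sampler=sampler,
                    num_workers=num_workers, worker_init_fn=init_fn)

for epoch in range(2):
    sampler.set_epoch(epoch)                 # re-shuffle deterministically each epoch
    for (batch,) in loader:
        pass                                 # training step would go here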

utils/utils_logging.py  0 → 100644

import logging
import os
import sys


class AverageMeter(object):
    """Computes and stores the average and current value
    """

    def __init__(self):
        self.val = None
        self.avg = None
        self.sum = None
        self.count = None
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def init_logging(rank, models_root):
    if rank == 0:
        log_root = logging.getLogger()
        log_root.setLevel(logging.INFO)
        formatter = logging.Formatter("Training: %(asctime)s-%(message)s")
        handler_file = logging.FileHandler(os.path.join(models_root, "training.log"))
        handler_stream = logging.StreamHandler(sys.stdout)
        handler_file.setFormatter(formatter)
        handler_stream.setFormatter(formatter)
        log_root.addHandler(handler_file)
        log_root.addHandler(handler_stream)
        log_root.info('rank_id: %d' % rank)
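
A short sketch of how init_logging and AverageMeter are typically combined; the output directory and the loss values below are illustrative assumptions.

import logging
import os

from utils.utils_logging import AverageMeter, init_logging

models_root = "work_dirs/demo"                   # assumed output directory
os.makedirs(models_root, exist_ok=True)
init_logging(rank=0, models_root=models_root)    # rank 0 logs to training.log and stdout

loss_am = AverageMeter()
for step, loss_value in enumerate([0.9, 0.7, 0.5], start=1):
    loss_am.update(loss_value, n=1)
    logging.info("step %d running-average loss %.4f", step, loss_am.avg)
loss_am.reset()                                  # e.g. restart the average each log interval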