Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenPCDet
Commits
d9e76fe9
Unverified
Commit
d9e76fe9
authored
Dec 05, 2021
by
jihan.yang
Committed by
GitHub
Dec 05, 2021
Browse files
support recording batch time during training (#697)
parent
7b27d8e7
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
228 additions
and
1 deletion
+228
-1
pcdet/utils/common_utils.py
pcdet/utils/common_utils.py
+17
-0
pcdet/utils/commu_utils.py
pcdet/utils/commu_utils.py
+182
-0
tools/train_utils/train_utils.py
tools/train_utils/train_utils.py
+29
-1
No files found.
pcdet/utils/common_utils.py
View file @
d9e76fe9
...
@@ -245,3 +245,20 @@ def sa_create(name, var):
...
@@ -245,3 +245,20 @@ def sa_create(name, var):
x
.
flags
.
writeable
=
False
x
.
flags
.
writeable
=
False
return
x
return
x
class AverageMeter(object):
    """Computes and stores the average and current value.

    Public attributes read by callers:
        val   -- most recently observed value
        avg   -- running mean over all observations so far
        sum   -- weighted sum of observations
        count -- total number of samples seen
    """

    def __init__(self):
        self.reset()

    def reset(self):
        # Drop all accumulated statistics and start from scratch.
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        # Fold in a new observation `val` that stands for `n` samples
        # (e.g. a per-batch mean with batch size n).
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
pcdet/utils/commu_utils.py
0 → 100644
View file @
d9e76fe9
"""
This file contains primitives for multi-gpu communication.
This is useful when doing distributed training.
deeply borrow from maskrcnn-benchmark and ST3D
"""
import
pickle
import
time
import
torch
import
torch.distributed
as
dist
def get_world_size():
    """Return the number of processes in the default process group.

    Falls back to 1 when torch.distributed is unavailable or has not
    been initialized (i.e. single-process training).
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
def get_rank():
    """Return this process's rank in the default process group.

    Falls back to 0 when torch.distributed is unavailable or has not
    been initialized (i.e. single-process training).
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0
def is_main_process():
    """Return True iff this process is rank 0 of the group."""
    return get_rank() == 0
def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training.

    Silently does nothing when distributed training is not active
    (library unavailable, group not initialized, or a single process).
    """
    if not (dist.is_available() and dist.is_initialized()):
        return
    if dist.get_world_size() == 1:
        return
    dist.barrier()
def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        # Single process: nothing to communicate.
        return [data]

    # serialized to a Tensor
    # Non-tensor data is pickled into a byte tensor; tensors are flattened
    # so ranks with differently shaped tensors can still be gathered.
    origin_size = None
    if not isinstance(data, torch.Tensor):
        buffer = pickle.dumps(data)
        storage = torch.ByteStorage.from_buffer(buffer)
        tensor = torch.ByteTensor(storage).to("cuda")
    else:
        origin_size = data.size()
        tensor = data.reshape(-1)

    tensor_type = tensor.dtype

    # obtain Tensor size of each rank
    local_size = torch.LongTensor([tensor.numel()]).to("cuda")
    size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receiving Tensor from all ranks
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        # NOTE(review): receive buffers are allocated on CUDA, so this helper
        # requires a GPU (and a CUDA-aware backend such as NCCL) — confirm.
        tensor_list.append(torch.FloatTensor(size=(max_size,)).cuda().to(tensor_type))
    if local_size != max_size:
        padding = torch.FloatTensor(size=(max_size - local_size,)).cuda().to(tensor_type)
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        if origin_size is None:
            # Pickled path: strip the padding bytes, then deserialize.
            buffer = tensor.cpu().numpy().tobytes()[:size]
            data_list.append(pickle.loads(buffer))
        else:
            # Tensor path: strip padding elements; reshaped below.
            buffer = tensor[:size]
            data_list.append(buffer)

    if origin_size is not None:
        # Restore the trailing dimensions of the original tensor shape.
        new_shape = [-1] + list(origin_size[1:])
        resized_list = []
        for data in data_list:
            # suppose the difference of tensor size exist in first dimension
            data = data.reshape(new_shape)
            resized_list.append(data)
        return resized_list
    else:
        return data_list
def reduce_dict(input_dict, average=True):
    """
    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    Reduce the values in the dictionary from all processes so that process with rank
    0 has the averaged results. Returns a dict with the same fields as
    input_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        # Not distributed: the input is already the "reduced" result.
        return input_dict
    with torch.no_grad():
        # Sort the keys so every rank stacks the values in the same order.
        names = sorted(input_dict.keys())
        values = torch.stack([input_dict[k] for k in names], dim=0)
        dist.reduce(values, dst=0)
        if dist.get_rank() == 0 and average:
            # only main process gets accumulated, so only divide by
            # world_size in this case
            values /= world_size
        reduced_dict = dict(zip(names, values))
    return reduced_dict
def average_reduce_value(data):
    """Gather `data` from every rank and return its mean across ranks."""
    gathered = all_gather(data)
    return sum(gathered) / len(gathered)
def all_reduce(data, op="sum", average=False):
    """All-reduce a tensor across ranks.

    Args:
        data: tensor to reduce (returned unchanged when not distributed)
        op: one of "sum", "max", "min", "product" (case-insensitive)
        average: divide the reduced tensor by world size; only valid
            together with op="sum"
    Returns:
        The reduced tensor (a clone; `data` itself is not modified),
        or `data` itself when world size is 1.
    """
    def op_map(op):
        # Translate the string name to the torch.distributed reduce op.
        op_dict = {
            "SUM": dist.ReduceOp.SUM,
            "MAX": dist.ReduceOp.MAX,
            "MIN": dist.ReduceOp.MIN,
            "PRODUCT": dist.ReduceOp.PRODUCT,
        }
        return op_dict[op]

    world_size = get_world_size()
    if world_size <= 1:
        return data
    # Clone so the caller's tensor is left untouched by the in-place reduce.
    reduced_data = data.clone()
    dist.all_reduce(reduced_data, op=op_map(op.upper()))
    if not average:
        return reduced_data
    assert op.upper() == 'SUM'
    return reduced_data / world_size
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    world_size = torch.distributed.get_world_size()
    # Pre-allocate one receive buffer per rank, then gather synchronously.
    tensors_gather = [torch.ones_like(tensor) for _ in range(world_size)]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
    return torch.cat(tensors_gather, dim=0)
tools/train_utils/train_utils.py
View file @
d9e76fe9
...
@@ -3,7 +3,9 @@ import os
...
@@ -3,7 +3,9 @@ import os
import
torch
import
torch
import
tqdm
import
tqdm
import
time
from
torch.nn.utils
import
clip_grad_norm_
from
torch.nn.utils
import
clip_grad_norm_
from
pcdet.utils
import
common_utils
,
commu_utils
def
train_one_epoch
(
model
,
optimizer
,
train_loader
,
model_func
,
lr_scheduler
,
accumulated_iter
,
optim_cfg
,
def
train_one_epoch
(
model
,
optimizer
,
train_loader
,
model_func
,
lr_scheduler
,
accumulated_iter
,
optim_cfg
,
...
@@ -13,14 +15,21 @@ def train_one_epoch(model, optimizer, train_loader, model_func, lr_scheduler, ac
...
@@ -13,14 +15,21 @@ def train_one_epoch(model, optimizer, train_loader, model_func, lr_scheduler, ac
if
rank
==
0
:
if
rank
==
0
:
pbar
=
tqdm
.
tqdm
(
total
=
total_it_each_epoch
,
leave
=
leave_pbar
,
desc
=
'train'
,
dynamic_ncols
=
True
)
pbar
=
tqdm
.
tqdm
(
total
=
total_it_each_epoch
,
leave
=
leave_pbar
,
desc
=
'train'
,
dynamic_ncols
=
True
)
data_time
=
common_utils
.
AverageMeter
()
batch_time
=
common_utils
.
AverageMeter
()
forward_time
=
common_utils
.
AverageMeter
()
for
cur_it
in
range
(
total_it_each_epoch
):
for
cur_it
in
range
(
total_it_each_epoch
):
end
=
time
.
time
()
try
:
try
:
batch
=
next
(
dataloader_iter
)
batch
=
next
(
dataloader_iter
)
except
StopIteration
:
except
StopIteration
:
dataloader_iter
=
iter
(
train_loader
)
dataloader_iter
=
iter
(
train_loader
)
batch
=
next
(
dataloader_iter
)
batch
=
next
(
dataloader_iter
)
print
(
'new iters'
)
print
(
'new iters'
)
data_timer
=
time
.
time
()
cur_data_time
=
data_timer
-
end
lr_scheduler
.
step
(
accumulated_iter
)
lr_scheduler
.
step
(
accumulated_iter
)
...
@@ -37,12 +46,31 @@ def train_one_epoch(model, optimizer, train_loader, model_func, lr_scheduler, ac
...
@@ -37,12 +46,31 @@ def train_one_epoch(model, optimizer, train_loader, model_func, lr_scheduler, ac
loss
,
tb_dict
,
disp_dict
=
model_func
(
model
,
batch
)
loss
,
tb_dict
,
disp_dict
=
model_func
(
model
,
batch
)
forward_timer
=
time
.
time
()
cur_forward_time
=
forward_timer
-
data_timer
loss
.
backward
()
loss
.
backward
()
clip_grad_norm_
(
model
.
parameters
(),
optim_cfg
.
GRAD_NORM_CLIP
)
clip_grad_norm_
(
model
.
parameters
(),
optim_cfg
.
GRAD_NORM_CLIP
)
optimizer
.
step
()
optimizer
.
step
()
accumulated_iter
+=
1
accumulated_iter
+=
1
disp_dict
.
update
({
'loss'
:
loss
.
item
(),
'lr'
:
cur_lr
})
cur_batch_time
=
time
.
time
()
-
end
# average reduce
avg_data_time
=
commu_utils
.
average_reduce_value
(
cur_data_time
)
avg_forward_time
=
commu_utils
.
average_reduce_value
(
cur_forward_time
)
avg_batch_time
=
commu_utils
.
average_reduce_value
(
cur_batch_time
)
if
rank
==
0
:
data_time
.
update
(
avg_data_time
)
forward_time
.
update
(
avg_forward_time
)
batch_time
.
update
(
avg_batch_time
)
disp_dict
.
update
({
'loss'
:
loss
.
item
(),
'lr'
:
cur_lr
,
'd_time'
:
f
'
{
data_time
.
val
:.
2
f
}
(
{
data_time
.
avg
:.
2
f
}
)'
,
'f_time'
:
f
'
{
forward_time
.
val
:.
2
f
}
(
{
forward_time
.
avg
:.
2
f
}
)'
,
'b_time'
:
f
'
{
batch_time
.
val
:.
2
f
}
(
{
batch_time
.
avg
:.
2
f
}
)'
})
# log to console and tensorboard
# log to console and tensorboard
if
rank
==
0
:
if
rank
==
0
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment