OpenDAS / fairscale · Commit 85dea5b2 (unverified)

Authored Apr 26, 2021 by Benjamin Lefaudeux; committed by GitHub on Apr 26, 2021
[chore] SDP - adding the profiler labels (#630)

* adding the labels
* longer labels, following aten::
Parent: 38ce54b7
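The labels added here use torch.autograd.profiler.record_function, which opens a named range in the profiler timeline so SDP-level work stands out next to the low-level aten:: operator events. A minimal sketch of the mechanism (not fairscale code; the function and tensors below are placeholders):

import torch
import torch.autograd.profiler as profiler

def forward_step(x: torch.Tensor) -> torch.Tensor:
    # The label appears as its own row in the profiler table, alongside aten:: ops.
    with profiler.record_function("fairscale::sdp::forward"):
        return x * 2 + 1

with profiler.profile() as prof:
    forward_step(torch.randn(8, 8))

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))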
1 changed file with 117 additions and 111 deletions:

fairscale/nn/data_parallel/sharded_ddp.py (+117, -111)
@@ -18,6 +18,7 @@ from typing import Any, Callable, Deque, Dict, Generator, List, Optional, Union
 import torch
 from torch import nn
 from torch.autograd import Variable
+import torch.autograd.profiler as profiler
 import torch.distributed as dist
 from fairscale.nn.misc import GradBucket
@@ -199,6 +200,7 @@ class ShardedDataParallel(nn.Module):
         backward pass for gradient reduction to the proper ranks.
         """
+        with profiler.record_function("fairscale::sdp::forward"):
             # Deferred initialization, or change detection
             needs_setup = len(self._grad_hooks) == 0 and self.training
@@ -274,6 +276,7 @@ class ShardedDataParallel(nn.Module):
                 "Grads waiting to be reduced. If this is on purpose (grad accumulation), please use a no_sync() context"
             )
+        with profiler.record_function("fairscale::sdp::refresh_trainable"):
             self._trainable_params = list(filter(lambda x: x.requires_grad, self._all_params))
             self._trainable_params.sort(key=lambda x: x.numel())
@@ -320,6 +323,7 @@ class ShardedDataParallel(nn.Module):
             blocking (bool): wait for the operation to conclude.
         """
+        with profiler.record_function("fairscale::sdp::sync_buffers"):
             work_handles = []
             for buffer in self._module.buffers(recurse=True):
@@ -480,7 +484,7 @@ class ShardedDataParallel(nn.Module):
         Attach a reduce function to each grad-requiring parameter.
         This makes the gradient reduction automatic whenever there's a backward pass
         """
+        with profiler.record_function("fairscale::sdp::setup_backward_hooks"):
             # Detach possible pre-existing hooks
             while len(self._grad_hooks) > 0:
                 self._grad_hooks.pop().remove()
@@ -552,6 +556,7 @@ class ShardedDataParallel(nn.Module):
         This method can be a slow for big models, but it it not typically called often (not for every forward for instance)
         """
+        with profiler.record_function("fairscale::sdp::setup_buckets"):
             if not self._use_buckets:
                 return
@@ -628,6 +633,7 @@ class ShardedDataParallel(nn.Module):
         self._consume_work_handles()

     def _detect_train_change(self) -> bool:
+        with profiler.record_function("fairscale::sdp::detect_train_changes"):
             # Optionally check whether the trainable parameters have changed
             trainable_mask = list(map(_trainable, self._all_params))
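With these ranges in place, a profile of a training step can be filtered on the fairscale::sdp:: prefix to separate ShardedDataParallel overhead from the underlying kernels. A hedged, single-process sketch (gloo backend, world size 1, toy model; the port and hyperparameters are arbitrary choices, not from this commit):

import torch
import torch.distributed as dist
import torch.autograd.profiler as profiler
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim.oss import OSS

# Single-process group just to make the sketch self-contained.
dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

model = torch.nn.Linear(8, 8)
optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)
model = ShardedDataParallel(model, optimizer)

with profiler.profile() as prof:
    loss = model(torch.randn(4, 8)).sum()  # enters fairscale::sdp::forward
    loss.backward()                        # gradient reduction hooks fire here
    optimizer.step()

# Keep only the SDP ranges labeled by this commit.
for evt in prof.key_averages():
    if evt.key.startswith("fairscale::sdp::"):
        print(evt.key, evt.cpu_time_total)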