OpenDAS / fairscale · Commit 85dea5b2 (unverified)
Authored Apr 26, 2021 by Benjamin Lefaudeux; committed via GitHub on Apr 26, 2021

[chore] SDP - adding the profiler labels (#630)

* adding the labels
* longer labels, following aten::

Parent: 38ce54b7
Changes: showing 1 changed file with 117 additions and 111 deletions.

fairscale/nn/data_parallel/sharded_ddp.py (+117, -111)
fairscale/nn/data_parallel/sharded_ddp.py @ 85dea5b2

@@ -18,6 +18,7 @@ from typing import Any, Callable, Deque, Dict, Generator, List, Optional, Union
 import torch
 from torch import nn
 from torch.autograd import Variable
+import torch.autograd.profiler as profiler
 import torch.distributed as dist

 from fairscale.nn.misc import GradBucket

@@ -199,6 +200,7 @@ class ShardedDataParallel(nn.Module):
         backward pass for gradient reduction to the proper ranks.
         """
+        with profiler.record_function("fairscale::sdp::forward"):
             # Deferred initialization, or change detection
             needs_setup = len(self._grad_hooks) == 0 and self.training

@@ -274,6 +276,7 @@ class ShardedDataParallel(nn.Module):
                 "Grads waiting to be reduced. If this is on purpose (grad accumulation), please use a no_sync() context"
             )
+        with profiler.record_function("fairscale::sdp::refresh_trainable"):
             self._trainable_params = list(filter(lambda x: x.requires_grad, self._all_params))
             self._trainable_params.sort(key=lambda x: x.numel())

@@ -320,6 +323,7 @@ class ShardedDataParallel(nn.Module):
             blocking (bool): wait for the operation to conclude.
         """
+        with profiler.record_function("fairscale::sdp::sync_buffers"):
             work_handles = []
             for buffer in self._module.buffers(recurse=True):

@@ -480,7 +484,7 @@ class ShardedDataParallel(nn.Module):
         Attach a reduce function to each grad-requiring parameter.
         This makes the gradient reduction automatic whenever there's a backward pass
         """
+        with profiler.record_function("fairscale::sdp::setup_backward_hooks"):
             # Detach possible pre-existing hooks
             while len(self._grad_hooks) > 0:
                 self._grad_hooks.pop().remove()

@@ -552,6 +556,7 @@ class ShardedDataParallel(nn.Module):
         This method can be a slow for big models, but it it not typically called often (not for every forward for instance)
         """
+        with profiler.record_function("fairscale::sdp::setup_buckets"):
             if not self._use_buckets:
                 return

@@ -628,6 +633,7 @@ class ShardedDataParallel(nn.Module):
         self._consume_work_handles()

     def _detect_train_change(self) -> bool:
+        with profiler.record_function("fairscale::sdp::detect_train_changes"):
             # Optionally check whether the trainable parameters have changed
             trainable_mask = list(map(_trainable, self._all_params))
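For reference, torch.autograd.profiler.record_function marks a code region so it appears as a named range in profiler output; the fairscale::sdp::* labels added here follow the aten:: naming style mentioned in the commit message. Below is a minimal sketch of how these labels could be surfaced in a trace, assuming a single-process gloo process group, a stand-in nn.Linear model, and illustrative tensor shapes; none of this setup is part of the commit itself.

    # Minimal sketch: surfacing the fairscale::sdp::* profiler labels.
    # The single-process "gloo" setup, the nn.Linear stand-in model, and the
    # tensor shapes are illustrative assumptions, not part of this commit.
    import os

    import torch
    import torch.autograd.profiler as profiler
    import torch.distributed as dist
    from torch import nn

    from fairscale.nn.data_parallel import ShardedDataParallel
    from fairscale.optim.oss import OSS

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)

    model = nn.Linear(32, 32)  # stand-in model
    optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)
    ddp = ShardedDataParallel(model, optimizer)

    with profiler.profile() as prof:
        loss = ddp(torch.randn(8, 32)).sum()  # forward runs under "fairscale::sdp::forward"
        loss.backward()                       # backward triggers the reduce hooks set up by SDP
        optimizer.step()

    # The fairscale::sdp::* ranges show up alongside the aten:: operator labels.
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))

    dist.destroy_process_group()

Prefixing the labels with fairscale::sdp:: keeps them grouped and easy to filter next to PyTorch's own aten:: entries in the profiler table or trace viewer.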