OpenDAS / fairscale / Commits

Commit 9faad392 (unverified), authored Jan 08, 2021 by Joshua Meier, committed by GitHub Jan 07, 2021
Parent: 53a912c3

[feat] Support model parallelism in OSS (#287)

* add additional unit test
* support model parallelism in oss
Showing 3 changed files with 34 additions and 10 deletions (+34 / -10):

fairscale/optim/oss.py     +18 / -9
pyproject.toml              +1 / -1
tests/optim/test_oss.py    +15 / -0
fairscale/optim/oss.py
@@ -216,7 +216,12 @@ class OSS(Optimizer):
         return loss

-    def clip_grad_norm(self, max_norm: Union[float, int], norm_type: Union[float, int] = 2.0) -> torch.Tensor:
+    def clip_grad_norm(
+        self,
+        max_norm: Union[float, int],
+        norm_type: Union[float, int] = 2.0,
+        filter_params_fn: Callable[[Any], Any] = None,
+    ) -> torch.Tensor:
         """
         Clip all gradients at this point in time. The norm is computed over all gradients together, as if they were
         concatenated into a single vector. Gradients are modified in-place.
@@ -237,9 +242,6 @@ class OSS(Optimizer):
         .. warning: Model paralelism -groups other than world- are not yet supported
         """
-        if self.group != dist.group.WORLD:
-            raise NotImplementedError("Clip norm not yet supported for model parallelism (coming soon!)")
-
         # Compute the max norm for this shards's worth of gradients
         max_norm = float(max_norm)
         norm_type = float(norm_type)
@@ -252,11 +254,19 @@ class OSS(Optimizer):
             ]
         )

+        # Option to filter parameters from the grad_norm calculation. This is useful for model parallelism.
+        # To avoid double counting, only consider parameters on rank zero + anything marked 'model_parallel'
+        # 'model_parallel' flag is set in Megatron-LM:
+        # https://github.com/NVIDIA/Megatron-LM/blob/19301985dd31c8b612095cbad15bd903e8ddd497/megatron/mpu/layers.py#L54
+        if filter_params_fn is not None:
+            local_params = filter_params_fn(local_params)
+
         # Compute the norm on this grad set,
         # then sync all the norms from all ranks
         if norm_type == inf:
             total_norm = max(p.grad.detach().abs().max().to(self._device) for p in local_params)  # type: ignore
-            dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=self.group)
+            # all reduce over data parallel and model parallel workers
+            dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=dist.group.WORLD)
         else:
             local_norm = torch.norm(
                 input=torch.stack(
                     [torch.norm(input=p.grad.detach(), p=norm_type, dtype=torch.float32).to(self._device) for p in local_params]
                 ),
                 # type: ignore
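The comment added above points at Megatron-LM's `model_parallel` parameter attribute as the way to avoid double counting. A minimal sketch of what such a `filter_params_fn` could look like under that assumption; the helper name, the `model_parallel_rank` argument, and the attribute-based convention are illustrative and not part of this commit:

```python
def make_megatron_grad_norm_filter(model_parallel_rank: int):
    """Build a filter_params_fn that avoids double counting across model-parallel ranks.

    Assumption (not from this commit): parameters that are split across model-parallel
    ranks carry a truthy ``model_parallel`` attribute (as Megatron-LM sets), while
    replicated parameters have no flag and should be counted only once, on rank 0.
    """

    def filter_params(params):
        return [
            p
            for p in params
            if getattr(p, "model_parallel", False) or model_parallel_rank == 0
        ]

    return filter_params


# Usage sketch with the extended signature introduced in this commit:
# optimizer.clip_grad_norm(max_norm=1.0, filter_params_fn=make_megatron_grad_norm_filter(mp_rank))
```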
@@ -266,12 +276,12 @@ class OSS(Optimizer):
             # local norm result can be accumulated with the remote ones if put to the right power
             # n_i = sum_rank(a^p)^1/p
             # -> n_total = all_reduce(n_i^p)^(1/p) = sum_i(n_i^p)^1/p = sum_i(sum_rank(a^p))^1/p
+            # all reduce over data parallel and model parallel workers
             total_norm = local_norm ** norm_type
-            dist.all_reduce(total_norm, group=self.group)
+            dist.all_reduce(total_norm)
             total_norm = total_norm ** (1.0 / norm_type)

         clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6)
         if clip_coef < 1:
             for device, device_params in self.per_device_params.items():
                 for p in filter(lambda x: x.grad is not None, device_params[self.rank]):
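The accumulation trick in the comments above (n_total = (sum_i n_i^p)^(1/p)) can be checked locally without a process group; the snippet below is an illustrative verification written for this note, not code from the commit, with the two lists standing in for the gradient shards held by two ranks:

```python
import torch

# Two disjoint "gradient shards", as if they lived on different ranks.
shard_a = [torch.randn(3, 4), torch.randn(5)]
shard_b = [torch.randn(2, 2)]
p = 2.0

# Per-shard norms, as each rank would compute them locally.
n_a = torch.norm(torch.stack([torch.norm(t, p=p) for t in shard_a]), p=p)
n_b = torch.norm(torch.stack([torch.norm(t, p=p) for t in shard_b]), p=p)

# Raise to the p-th power, sum (this is what the all_reduce contributes), take the 1/p root ...
combined = (n_a ** p + n_b ** p) ** (1.0 / p)

# ... and the result matches the norm of everything concatenated into a single vector.
reference = torch.norm(torch.cat([t.flatten() for t in shard_a + shard_b]), p=p)
assert torch.allclose(combined, reference, atol=1e-5)
```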
@@ -354,7 +364,6 @@ class OSS(Optimizer):
             rank (int): rank to get local_state_dict for
             state_dict (dict): global state_dict
         """
-
         # Get this optimizer's param_groups shard
         param_groups = state_dict["param_groups"][state_dict["partition"][rank][0] : state_dict["partition"][rank][1]]
         return {"state": state_dict["state"][rank], "param_groups": param_groups}
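For reference, the slicing above relies on `state_dict["partition"]` holding one (start, end) index pair per rank into the flat `param_groups` list. A toy illustration of that layout; the concrete values and the string stand-ins are made up for this note:

```python
# Toy global state_dict layout (illustrative values, not taken from a real run):
state_dict = {
    "state": {0: {"step": 10}, 1: {"step": 10}},
    "param_groups": ["pg0", "pg1", "pg2"],  # stand-ins for real param-group dicts
    "partition": [(0, 2), (2, 3)],          # rank 0 owns groups [0, 2), rank 1 owns [2, 3)
}

rank = 1
start, end = state_dict["partition"][rank]
local = {"state": state_dict["state"][rank], "param_groups": state_dict["param_groups"][start:end]}
assert local["param_groups"] == ["pg2"]
```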
@@ -402,7 +411,7 @@ class OSS(Optimizer):
         """
         # Check whether we got a local or global dict
-        if state_dict["local_state_dict"]:
+        if "local_state_dict" in state_dict and state_dict["local_state_dict"]:
             self.load_local_state_dict(state_dict)
         else:
             # Dispatch this rank's state dictionary to the wrapped shard optimizer
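A short note on the guard added here: indexing a dictionary that lacks the key raises, while the membership test evaluates to False and routes such dictionaries to the global-path branch. The plain dict below is an assumption used to show the failure mode, not something produced by this code:

```python
state_dict = {"state": {}, "param_groups": []}  # no "local_state_dict" key

# Old check: raises KeyError for a dict without the flag.
try:
    is_local = state_dict["local_state_dict"]
except KeyError:
    is_local = None  # this is where load_state_dict would have crashed

# New check: short-circuits to False and falls through to the global branch.
is_local = "local_state_dict" in state_dict and state_dict["local_state_dict"]
assert is_local is False
```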
pyproject.toml
@@ -28,4 +28,4 @@ use_parentheses = true
 skip_glob = ["build/*", "stubs/*"]
 # Don't split "import" and "from".
 force_sort_within_sections = true
-known_third_party = ["benchmark_dataset", "dataclasses", "helpers", "numpy", "pytest", "recommonmark", "setuptools", "torch", "torch_pg", "torchtext", "torchvision"]
+known_third_party = ["benchmark_dataset", "dataclasses", "datasets", "helpers", "models", "numpy", "pytest", "recommonmark", "setuptools", "torch", "torch_pg", "torchtext", "torchvision"]
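For context, `known_third_party` tells isort which modules to keep in the third-party import block. With the two new entries, imports from `datasets` and `models` (presumably benchmark helper modules; the member names below are hypothetical) sort alongside `torch` instead of being treated as first-party, roughly:

```python
# Third-party block (force_sort_within_sections sorts "import" and "from" lines together):
from datasets import wikitext2_data  # hypothetical helper module
from models import transformer_lm    # hypothetical helper module
import torch

# First-party block:
from fairscale.optim import OSS
```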
tests/optim/test_oss.py
@@ -604,6 +604,21 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):
         optimizer.step()
         optimizer.zero_grad()

+    # save and reload without taking any steps
+    sharded_optimizer2.consolidate_state_dict()
+    state_dict2 = sharded_optimizer2.state_dict()
+    sharded_optimizer2 = optim.OSS(model_oss2.parameters(), lr=0.1, momentum=0.99)
+    sharded_optimizer2.load_state_dict(state_dict2)
+
+    # now take a step and check that parameters are equal
+    # take a step
+    run_grad_step(device, model_oss1, sharded_optimizer1)
+    run_grad_step(device, model_oss2, sharded_optimizer2)
+
+    # check that model parameters are equal
+    for param1, param2 in zip(model_oss1.parameters(), model_oss2.parameters()):
+        assert torch.allclose(param1, param2), "parameters of the two identical models have diverged (before any steps)"
+
     # take a step
     run_grad_step(device, model_oss1, sharded_optimizer1)
     run_grad_step(device, model_oss2, sharded_optimizer2)
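The pattern the new test exercises (consolidate, snapshot, rebuild, reload) is also the usual way to checkpoint an OSS optimizer. A condensed sketch, assuming the process group is already initialised and `model` is defined; the file name is arbitrary:

```python
import torch
from fairscale.optim import OSS

# Assumes torch.distributed is initialised and `model` exists on every rank.
optimizer = OSS(model.parameters(), lr=0.1, momentum=0.99)

# ... train for a while ...

# Every rank must take part in consolidation; only the recipient rank ends up with the full dict.
optimizer.consolidate_state_dict(recipient_rank=0)
if torch.distributed.get_rank() == 0:
    torch.save(optimizer.state_dict(), "oss_checkpoint.pt")

# Later: rebuild the optimizer the same way and restore the consolidated state on each rank.
optimizer = OSS(model.parameters(), lr=0.1, momentum=0.99)
optimizer.load_state_dict(torch.load("oss_checkpoint.pt"))
```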