OpenDAS / fairscale, commit 9faad392 (unverified), authored Jan 08, 2021 by Joshua Meier, committed by GitHub Jan 07, 2021.

[feat] Support model parallelism in OSS (#287)

* add additional unit test
* support model parallelism in oss

Parent: 53a912c3
Showing 3 changed files, with 34 additions and 10 deletions:

* fairscale/optim/oss.py: +18 / -9
* pyproject.toml: +1 / -1
* tests/optim/test_oss.py: +15 / -0
fairscale/optim/oss.py
@@ -216,7 +216,12 @@ class OSS(Optimizer):
         return loss
 
-    def clip_grad_norm(self, max_norm: Union[float, int], norm_type: Union[float, int] = 2.0) -> torch.Tensor:
+    def clip_grad_norm(
+        self,
+        max_norm: Union[float, int],
+        norm_type: Union[float, int] = 2.0,
+        filter_params_fn: Callable[[Any], Any] = None,
+    ) -> torch.Tensor:
         """
         Clip all gradients at this point in time. The norm is computed over all gradients together, as if they were
         concatenated into a single vector. Gradients are modified in-place.
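The new optional filter_params_fn argument lets a caller restrict which local parameters enter the gradient-norm computation. A minimal usage sketch, not part of this commit; the single-process gloo group, toy model, and identity filter are illustrative assumptions:

    import torch
    import torch.distributed as dist
    from fairscale.optim import OSS

    # OSS shards optimizer state across a process group, so one must be initialized;
    # a single-process gloo group is enough for this sketch.
    dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

    model = torch.nn.Linear(8, 8)
    sharded_optimizer = OSS(model.parameters(), lr=0.1)
    model(torch.randn(4, 8)).sum().backward()

    # filter_params_fn receives the local parameter list and returns the subset whose
    # gradients should contribute to the norm; the identity filter below keeps everything.
    total_norm = sharded_optimizer.clip_grad_norm(max_norm=1.0, filter_params_fn=lambda params: params)
    sharded_optimizer.step()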
@@ -237,9 +242,6 @@ class OSS(Optimizer):
         .. warning: Model paralelism -groups other than world- are not yet supported
         """
-        if self.group != dist.group.WORLD:
-            raise NotImplementedError("Clip norm not yet supported for model parallelism (coming soon!)")
-
         # Compute the max norm for this shards's worth of gradients
         max_norm = float(max_norm)
         norm_type = float(norm_type)
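With the guard above removed, clip_grad_norm no longer rejects OSS instances whose sharding group is a sub-group of the world, which is the data-parallel group in a model-parallel setup. A hedged sketch of that situation, assuming the constructor's group argument that self.group above points to; the single-process group stands in for a real data-parallel sub-group:

    import torch
    import torch.distributed as dist
    from fairscale.optim import OSS

    dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1)
    # In a real model-parallel job this would be the data-parallel sub-group,
    # e.g. dist.new_group(ranks=[0, 2, 4, 6]); here it merely differs from dist.group.WORLD.
    data_parallel_group = dist.new_group(ranks=[0])

    model = torch.nn.Linear(8, 8)
    sharded_optimizer = OSS(model.parameters(), lr=0.1, group=data_parallel_group)
    model(torch.randn(4, 8)).sum().backward()

    # Before this commit, this call raised NotImplementedError whenever the group was not dist.group.WORLD.
    total_norm = sharded_optimizer.clip_grad_norm(max_norm=1.0)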
@@ -252,11 +254,19 @@ class OSS(Optimizer):
             ]
         )
 
+        # Option to filter parameters from the grad_norm calculation. This is useful for model parallelism.
+        # To avoid double counting, only consider parameters on rank zero + anything marked 'model_parallel'
+        # 'model_parallel' flag is set in Megatron-LM:
+        # https://github.com/NVIDIA/Megatron-LM/blob/19301985dd31c8b612095cbad15bd903e8ddd497/megatron/mpu/layers.py#L54
+        if filter_params_fn is not None:
+            local_params = filter_params_fn(local_params)
+
         # Compute the norm on this grad set,
         # then sync all the norms from all ranks
         if norm_type == inf:
             total_norm = max(p.grad.detach().abs().max().to(self._device) for p in local_params)  # type: ignore
-            dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=self.group)
+            # all reduce over data parallel and model parallel workers
+            dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=dist.group.WORLD)
         else:
             local_norm = torch.norm(
                 input=torch.stack([torch.norm(input=p.grad.detach(), p=norm_type, dtype=torch.float32).to(self._device) for p in local_params]),  # type: ignore
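The comment added above describes the intended filter for Megatron-LM style model parallelism: count replicated parameters only on rank zero of the model-parallel group, and always count parameters tagged with the model_parallel attribute. A hedged sketch of such a filter; the megatron.mpu import and helper names follow the Megatron-LM conventions referenced in the comment and are not part of this commit:

    def megatron_filter_params_fn(params):
        # Assumes a Megatron-LM install; mpu exposes this worker's model-parallel rank.
        from megatron import mpu

        if mpu.get_model_parallel_rank() == 0:
            # Rank zero of the model-parallel group counts every local parameter once.
            return params
        # Other model-parallel ranks keep only parameters that are actually sharded,
        # i.e. carry the 'model_parallel' flag set in megatron/mpu/layers.py, to avoid double counting.
        return [p for p in params if getattr(p, "model_parallel", False)]

    # total_norm = sharded_optimizer.clip_grad_norm(max_norm=1.0, filter_params_fn=megatron_filter_params_fn)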
@@ -266,12 +276,12 @@ class OSS(Optimizer):
             # local norm result can be accumulated with the remote ones if put to the right power
             # n_i = sum_rank(a^p)^1/p
             # -> n_total = all_reduce(n_i^p)^(1/p) = sum_i(n_i^p)^1/p = sum_i(sum_rank(a^p))^1/p
+            # all reduce over data parallel and model parallel workers
             total_norm = local_norm ** norm_type
-            dist.all_reduce(total_norm, group=self.group)
+            dist.all_reduce(total_norm)
             total_norm = total_norm ** (1.0 / norm_type)
 
         clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6)
         if clip_coef < 1:
             for device, device_params in self.per_device_params.items():
                 for p in filter(lambda x: x.grad is not None, device_params[self.rank]):
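The comment block in this hunk relies on the identity that per-shard p-norms combine through their p-th powers, n_total = (sum_i n_i^p)^(1/p), which is what the all_reduce of total_norm implements. A small self-contained check of that identity, with plain tensors standing in for gradient shards (not part of this commit):

    import torch

    # Two "shards" of a gradient vector and the norm order p.
    shard_a, shard_b, p = torch.randn(10), torch.randn(7), 2.0

    # Global norm computed directly over the concatenated gradients.
    direct = torch.norm(torch.cat([shard_a, shard_b]), p=p)

    # Global norm recovered from per-shard norms, mirroring the code: sum the p-th powers
    # (the role of the all_reduce above), then take the 1/p power.
    per_shard = torch.stack([torch.norm(shard_a, p=p), torch.norm(shard_b, p=p)])
    combined = (per_shard ** p).sum() ** (1.0 / p)

    assert torch.allclose(direct, combined)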
@@ -354,7 +364,6 @@ class OSS(Optimizer):
             rank (int): rank to get local_state_dict for
             state_dict (dict): global state_dict
         """
-        # Get this optimizer's param_groups shard
         param_groups = state_dict["param_groups"][state_dict["partition"][rank][0] : state_dict["partition"][rank][1]]
         return {"state": state_dict["state"][rank], "param_groups": param_groups}
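The slicing above selects each rank's share of the consolidated param_groups list using the [start, end) bounds stored under the "partition" key. A toy illustration with invented contents, mirroring the indexing in the code:

    # Hypothetical consolidated state dict: four param groups split across two ranks.
    state_dict = {
        "state": {0: {"momentum_buffer": None}, 1: {"momentum_buffer": None}},  # per-rank optimizer state
        "param_groups": ["pg0", "pg1", "pg2", "pg3"],  # placeholders for real param group dicts
        "partition": [(0, 2), (2, 4)],  # rank -> [start, end) indices into param_groups
    }

    rank = 1
    start, end = state_dict["partition"][rank]
    local = {"state": state_dict["state"][rank], "param_groups": state_dict["param_groups"][start:end]}
    assert local["param_groups"] == ["pg2", "pg3"]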
@@ -402,7 +411,7 @@ class OSS(Optimizer):
         """
         # Check whether we got a local or global dict
-        if state_dict["local_state_dict"]:
+        if "local_state_dict" in state_dict and state_dict["local_state_dict"]:
             self.load_local_state_dict(state_dict)
         else:
             # Dispatch this rank's state dictionary to the wrapped shard optimizer
pyproject.toml

@@ -28,4 +28,4 @@ use_parentheses = true
 skip_glob = ["build/*", "stubs/*"]
 # Don't split "import" and "from".
 force_sort_within_sections = true
-known_third_party = ["benchmark_dataset", "dataclasses", "helpers", "numpy", "pytest", "recommonmark", "setuptools", "torch", "torch_pg", "torchtext", "torchvision"]
+known_third_party = ["benchmark_dataset", "dataclasses", "datasets", "helpers", "models", "numpy", "pytest", "recommonmark", "setuptools", "torch", "torch_pg", "torchtext", "torchvision"]
tests/optim/test_oss.py

@@ -604,6 +604,21 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):
     optimizer.step()
     optimizer.zero_grad()
 
+    # save and reload without taking any steps
+    sharded_optimizer2.consolidate_state_dict()
+    state_dict2 = sharded_optimizer2.state_dict()
+    sharded_optimizer2 = optim.OSS(model_oss2.parameters(), lr=0.1, momentum=0.99)
+    sharded_optimizer2.load_state_dict(state_dict2)
+
+    # now take a step and check that parameters are equal
+    # take a step
+    run_grad_step(device, model_oss1, sharded_optimizer1)
+    run_grad_step(device, model_oss2, sharded_optimizer2)
+
+    # check that model parameters are equal
+    for param1, param2 in zip(model_oss1.parameters(), model_oss2.parameters()):
+        assert torch.allclose(param1, param2), "parameters of the two identical models have diverged (before any steps)"
+
     # take a step
     run_grad_step(device, model_oss1, sharded_optimizer1)
     run_grad_step(device, model_oss2, sharded_optimizer2)