Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
a9ecb4b2
Unverified
Commit
a9ecb4b2
authored
Mar 22, 2022
by
ver217
Committed by
GitHub
Mar 22, 2022
Browse files
[zero] polish sharded optimizer v2 (#490)
parent
62b0a8d6
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
13 deletions
+16
-13
colossalai/zero/sharded_optim/sharded_optim_v2.py
colossalai/zero/sharded_optim/sharded_optim_v2.py
+16
-13
No files found.
colossalai/zero/sharded_optim/sharded_optim_v2.py
View file @
a9ecb4b2
...
@@ -110,19 +110,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
...
@@ -110,19 +110,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
self
.
shard_strategy
.
gather
([
p
.
col_attr
.
sharded_data_tensor
],
self
.
dp_process_group
)
self
.
shard_strategy
.
gather
([
p
.
col_attr
.
sharded_data_tensor
],
self
.
dp_process_group
)
def
step
(
self
,
*
args
,
**
kwargs
):
def
step
(
self
,
*
args
,
**
kwargs
):
if
self
.
_should_move_fp32_shards_h2d
:
self
.
_maybe_move_fp32_shards
()
self
.
_should_move_fp32_shards_h2d
=
False
available_cuda_margin_mem
=
self
.
model
.
cuda_margin_space
*
self
.
gpu_margin_mem_ratio
fp32_shards_available_cuda_margin_mem
=
available_cuda_margin_mem
/
self
.
optim
.
num_fp32_shards_per_param
fp32_shards_used_cuda_margin_mem
=
0
for
group
in
self
.
optim
.
param_groups
:
for
p
in
group
[
'params'
]:
shard_mem
=
self
.
master_params
[
p
].
numel
()
*
self
.
master_params
[
p
].
element_size
()
if
fp32_shards_used_cuda_margin_mem
+
shard_mem
<
fp32_shards_available_cuda_margin_mem
:
self
.
master_params
[
p
]
=
self
.
master_params
[
p
].
to
(
torch
.
cuda
.
current_device
())
p
.
grad
.
data
=
p
.
grad
.
data
.
to
(
torch
.
cuda
.
current_device
())
p
.
col_attr
.
offload_fp32_grad
=
False
fp32_shards_used_cuda_margin_mem
+=
shard_mem
# unscale grads if scaled
# unscale grads if scaled
if
self
.
optim_state
==
OptimState
.
SCALED
:
if
self
.
optim_state
==
OptimState
.
SCALED
:
...
@@ -223,3 +211,18 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
...
@@ -223,3 +211,18 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
def sync_grad(self):
    """No-op gradient synchronization hook.

    Kept to satisfy the optimizer-wrapper interface; gradient
    synchronization is handled elsewhere in this sharded design,
    so there is nothing to do here.
    """
    pass
def
_maybe_move_fp32_shards
(
self
):
if
self
.
_should_move_fp32_shards_h2d
:
self
.
_should_move_fp32_shards_h2d
=
False
available_cuda_margin_mem
=
self
.
model
.
cuda_margin_space
*
self
.
gpu_margin_mem_ratio
fp32_shards_available_cuda_margin_mem
=
available_cuda_margin_mem
/
self
.
optim
.
num_fp32_shards_per_param
fp32_shards_used_cuda_margin_mem
=
0
for
group
in
self
.
optim
.
param_groups
:
for
p
in
group
[
'params'
]:
shard_mem
=
self
.
master_params
[
p
].
numel
()
*
self
.
master_params
[
p
].
element_size
()
if
fp32_shards_used_cuda_margin_mem
+
shard_mem
<
fp32_shards_available_cuda_margin_mem
:
self
.
master_params
[
p
]
=
self
.
master_params
[
p
].
to
(
torch
.
cuda
.
current_device
())
p
.
grad
.
data
=
p
.
grad
.
data
.
to
(
torch
.
cuda
.
current_device
())
p
.
col_attr
.
offload_fp32_grad
=
False
fp32_shards_used_cuda_margin_mem
+=
shard_mem
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment