Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
a444633d
"vscode:/vscode.git/clone" did not exist on "81d817fe68e6ee09ba7e55bc1c203269a666dc7d"
Unverified
Commit
a444633d
authored
Jun 30, 2022
by
Jiarui Fang
Committed by
GitHub
Jun 30, 2022
Browse files
warmup ratio configration (#1192)
parent
dba7e0cf
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
1 deletion
+11
-1
colossalai/zero/sharded_model/sharded_model_v2.py
colossalai/zero/sharded_model/sharded_model_v2.py
+11
-1
No files found.
colossalai/zero/sharded_model/sharded_model_v2.py
View file @
a444633d
...
...
@@ -76,7 +76,9 @@ class ShardedModelV2(nn.Module):
fp32_reduce_scatter
:
bool
=
False
,
tensor_placement_policy
:
str
=
'cuda'
,
gradient_predivide_factor
:
Optional
[
float
]
=
1.0
,
reuse_fp16_shard
:
bool
=
False
):
reuse_fp16_shard
:
bool
=
False
,
*
args
,
**
kwargs
):
assert
not
isinstance
(
module
,
ShardedModelV2
),
'Nested ShardedModelV2 is not supported.'
super
().
__init__
()
self
.
logger
=
get_dist_logger
()
...
...
@@ -119,6 +121,14 @@ class ShardedModelV2(nn.Module):
self
.
_tensor_placement_policy
:
TensorPlacementPolicy
=
TensorPlacementPolicyFactory
.
create
(
tensor_placement_policy
)(
mem_stats_collector
=
self
.
_memstats_collector
)
if
'warmup_non_model_data_ratio'
in
kwargs
:
if
tensor_placement_policy
!=
'auto'
:
self
.
logger
.
warning
(
'setting warmup_non_model_data_ratio is useless if not use auto placement'
)
else
:
ratio
=
kwargs
[
'warmup_non_model_data_ratio'
]
self
.
_tensor_placement_policy
.
_warmup_non_model_data_ratio
=
ratio
self
.
logger
.
info
(
f
'setting warmup_non_model_data_ratio as
{
ratio
}
for auto placement'
)
self
.
_stateful_tensor_mgr
=
StatefulTensorMgr
(
self
.
_tensor_placement_policy
)
param_tensor_list
=
[
p
.
colo_attr
.
sharded_data_tensor
for
p
in
module
.
parameters
()
if
hasattr
(
p
,
'colo_attr'
)]
self
.
_stateful_tensor_mgr
.
register_stateful_tensor_list
(
param_tensor_list
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment