OpenDAS / ColossalAI
"googlemock/include/vscode:/vscode.git/clone" did not exist on "490554aa0f3618e1e5dd217f11fe0c3f188ed615"
Commit 3d7dc46d (unverified)
Authored Apr 14, 2022 by Jiarui Fang; committed by GitHub on Apr 14, 2022

[zero] use factory pattern for tensor_placement_policy (#752)
Parent: 4b048a87
Showing 3 changed files with 21 additions and 14 deletions:

- colossalai/zero/sharded_model/sharded_model_v2.py (+3, -5)
- colossalai/zero/utils/__init__.py (+2, -1)
- colossalai/zero/utils/tensor_placement_policy.py (+16, -8)
colossalai/zero/sharded_model/sharded_model_v2.py (view file @ 3d7dc46d)
@@ -23,7 +23,7 @@ from colossalai.zero.sharded_param.tensorful_state import TensorState
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 from colossalai.zero.utils.stateful_tensor_mgr import StatefulTensorMgr
-from colossalai.zero.utils.tensor_placement_policy import TENSOR_PLACEMENT_POLICIES, TensorPlacementPolicy
+from colossalai.zero.utils.tensor_placement_policy import TensorPlacementPolicyFactory, TensorPlacementPolicy
 from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
                      get_gradient_predivide_factor)
@@ -105,8 +105,6 @@ class ShardedModelV2(nn.Module):
         self.rank = dist.get_rank(self.process_group)
         self.shard_strategy = shard_strategy
-        assert tensor_placement_policy in TENSOR_PLACEMENT_POLICIES, \
-            f'Invalid tensor_placement_policy, got {tensor_placement_policy}'
         # Init Memory Statistics Collector
         self._use_memory_tracer = tensor_placement_policy == 'auto'
         if self._use_memory_tracer:
             GLOBAL_MODEL_DATA_TRACER.register_model(self)
@@ -115,8 +113,8 @@ class ShardedModelV2(nn.Module):
             self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection)
         else:
             self._memstats_collector = None
-        self._tensor_placement_policy: TensorPlacementPolicy = TENSOR_PLACEMENT_POLICIES[tensor_placement_policy](
-            mem_stats_collector=self._memstats_collector)
+        self._tensor_placement_policy: TensorPlacementPolicy = TensorPlacementPolicyFactory.create(
+            tensor_placement_policy)(mem_stats_collector=self._memstats_collector)
         self._stateful_tensor_mgr = StatefulTensorMgr(self._tensor_placement_policy)
         for param in module.parameters():
             if hasattr(param, 'colo_attr'):
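The hunk above is the call-site half of the refactor: instead of indexing a module-level dict of policy classes, ShardedModelV2 now asks the factory for the class and instantiates it itself. A minimal self-contained sketch of this two-step pattern, with illustrative names rather than ColossalAI's actual classes:

from typing import Optional, Type

class Policy:
    # Stand-in for TensorPlacementPolicy.
    def __init__(self, mem_stats_collector: Optional[object] = None) -> None:
        self.mem_stats_collector = mem_stats_collector

class CPUPolicy(Policy):
    pass

class PolicyFactory:
    # Stand-in for TensorPlacementPolicyFactory: returns a class, not an instance.
    @staticmethod
    def create(name: str) -> Type[Policy]:
        if name == 'cpu':
            return CPUPolicy
        raise TypeError(f"Unknown policy {name}")

# The two-step call mirrors create(tensor_placement_policy)(mem_stats_collector=...):
# first resolve the class, then construct it with caller-supplied kwargs.
policy = PolicyFactory.create('cpu')(mem_stats_collector=None)

Returning the class rather than an instance keeps construction arguments at the call site, which is why ShardedModelV2 can pass its own _memstats_collector.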
colossalai/zero/utils/__init__.py (view file @ 3d7dc46d)
 from .stateful_tensor_mgr import StatefulTensorMgr
+from .tensor_placement_policy import TensorPlacementPolicyFactory
 from .zero_hook import ZeroHook

-__all__ = ['StatefulTensorMgr', 'ZeroHook']
\ No newline at end of file
+__all__ = ['StatefulTensorMgr', 'ZeroHook', 'TensorPlacementPolicyFactory']
\ No newline at end of file
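With the added re-export, downstream code can import the factory from the package rather than from the submodule. A hedged usage sketch, assuming the package as patched in this commit:

from colossalai.zero.utils import TensorPlacementPolicyFactory

policy_cls = TensorPlacementPolicyFactory.create('cuda')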
colossalai/zero/utils/tensor_placement_policy.py (view file @ 3d7dc46d)
+from abc import ABC, abstractmethod
 from typing import List, Optional, Dict
 import torch
 from colossalai.utils import get_current_device
@@ -6,16 +7,16 @@ from colossalai.utils.memory import colo_device_memory_capacity
 from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
 from colossalai.utils.memory_tracer import MemStatsCollector
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
+from typing import Type

-__all__ = ['TENSOR_PLACEMENT_POLICIES']
-
-class TensorPlacementPolicy:
+class TensorPlacementPolicy(ABC):

     def __init__(self, device: Optional[torch.device], mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
         self.device: Optional[torch.device] = device
         self.mem_stats_collector: Optional[MemStatsCollector] = mem_stats_collector

+    @abstractmethod
     def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> None:
         raise NotImplementedError
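Making the base class an ABC with an @abstractmethod changes when misuse is caught: the base class can no longer be instantiated at all, rather than failing only once evict_tensors is first called. A small self-contained illustration of that behavior, using generic names rather than the ColossalAI classes:

from abc import ABC, abstractmethod

class Base(ABC):
    @abstractmethod
    def evict_tensors(self) -> None:
        ...

class Concrete(Base):
    def evict_tensors(self) -> None:
        pass

Concrete().evict_tensors()    # fine: the subclass implements the hook
try:
    Base()                    # rejected at construction time
except TypeError as e:
    print(e)                  # can't instantiate abstract class Base ...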
@@ -87,8 +88,15 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):
         )

-TENSOR_PLACEMENT_POLICIES = {
-    'cpu': CPUTensorPlacementPolicy,
-    'cuda': CUDATensorPlacementPolicy,
-    'auto': AutoTensorPlacementPolicy
-}
+class TensorPlacementPolicyFactory:
+
+    @staticmethod
+    def create(policy_name: str) -> Type[TensorPlacementPolicy]:
+        if policy_name == 'cpu':
+            return CPUTensorPlacementPolicy
+        elif policy_name == 'cuda':
+            return CUDATensorPlacementPolicy
+        elif policy_name == 'auto':
+            return AutoTensorPlacementPolicy
+        else:
+            raise TypeError(f"Unknown tensor placement policy {policy_name}")
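Taken together with the ShardedModelV2 change, the factory centralizes both lookup and validation: an unknown policy name now raises TypeError inside create() instead of tripping an assert in the model constructor. A hedged usage sketch against the classes defined in this file (the mem_stats_collector argument is optional per the base-class signature):

from colossalai.zero.utils.tensor_placement_policy import TensorPlacementPolicyFactory

# Resolve the policy class by name, then instantiate it.
policy = TensorPlacementPolicyFactory.create('cpu')(mem_stats_collector=None)

# Unknown names fail fast inside the factory.
try:
    TensorPlacementPolicyFactory.create('tpu')
except TypeError as e:
    print(e)    # Unknown tensor placement policy tpu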