OpenDAS / ColossalAI / Commits / 4d9332b4

Commit 4d9332b4 (unverified)
[refactor] moving memtracer to gemini (#801)
Authored Apr 19, 2022 by Jiarui Fang; committed via GitHub on Apr 19, 2022.
Parent: 8711c706
Changes: 24 files in this commit; this page shows 20 of them, with 98 additions and 81 deletions (+98 −81).
colossalai/engine/ophooks/_memtracer_ophook.py  +1 −2
colossalai/engine/schedule/_pipeline_schedule.py  +9 −7
colossalai/gemini/memory_tracer/__init__.py  +2 −1
colossalai/gemini/memory_tracer/memory_monitor.py  +1 −1
colossalai/gemini/memory_tracer/memstats_collector.py  +6 −2
colossalai/gemini/memory_tracer/model_data_memtracer.py  +0 −0
colossalai/gemini/tensor_placement_policy.py  +2 −2
colossalai/trainer/hooks/_commons_.py  +9 −0
colossalai/trainer/hooks/_log_hook.py  +21 −35
colossalai/trainer/hooks/_mem_tracer_hook.py  +1 −1
colossalai/trainer/hooks/_metric_hook.py  +30 −16
colossalai/zero/sharded_model/sharded_model_v2.py  +2 −2
colossalai/zero/sharded_optim/sharded_optim_v2.py  +1 −1
colossalai/zero/utils/zero_hook.py  +4 −3
docs/colossalai/colossalai.utils.memory_tracer.async_memtracer.rst  +1 −1
docs/colossalai/colossalai.utils.memory_tracer.memstats_collector.rst  +1 −1
docs/colossalai/colossalai.utils.memory_tracer.model_data_memtracer.rst  +1 −1
docs/colossalai/colossalai.utils.memory_tracer.rst  +4 −4
docs/colossalai/colossalai.utils.rst  +1 −1
tests/test_data/test_deterministic_dataloader.py  +1 −0
colossalai/engine/ophooks/_memtracer_ophook.py (view file @ 4d9332b4)

```diff
@@ -8,8 +8,6 @@ from colossalai.registry import OPHOOKS
 from colossalai.logging import get_dist_logger
 from colossalai.core import global_context as gpc
 from typing import Union
-from colossalai.utils.memory_tracer import AsyncMemoryMonitor
 import os
 import math
@@ -25,6 +23,7 @@ class MemTracerOpHook(BaseOpHook):
     """

     def __init__(self, warmup: int = 50, refreshrate: int = 10, data_prefix: str = "memstats"):
+        from colossalai.gemini.memory_tracer import AsyncMemoryMonitor
         super().__init__()
         self.async_mem_monitor = AsyncMemoryMonitor()
         self._curiter = 0
```
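The change above replaces the module-level import of `AsyncMemoryMonitor` with a local import inside `__init__`, the usual way to break an import cycle between packages. A minimal two-file sketch of the pattern, with hypothetical module names (`monitors`, `hooks` — not ColossalAI's):

```python
# file: monitors.py -- stands in for colossalai.gemini.memory_tracer
import hooks  # this module needs the hook module at load time, creating a cycle


class Monitor:
    def sample(self) -> int:
        return 0


# file: hooks.py -- stands in for the op-hook module
class MemHook:
    def __init__(self):
        # Deferred import: by the time __init__ runs, both modules have
        # finished loading, so the ImportError a top-level
        # `from monitors import Monitor` would raise is avoided.
        from monitors import Monitor
        self.monitor = Monitor()
```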
colossalai/engine/schedule/_pipeline_schedule.py (view file @ 4d9332b4)

```diff
@@ -12,10 +12,10 @@ from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.utils import switch_virtual_pipeline_parallel_rank
 from colossalai.utils.cuda import get_current_device
-from colossalai.zero.sharded_model import ShardedModelV2
 from ._base_schedule import BaseSchedule


 def get_tensor_shape():
     if hasattr(gpc.config, 'TENSOR_SHAPE'):
         return gpc.config.TENSOR_SHAPE
@@ -23,7 +23,8 @@ def get_tensor_shape():
     if not gpc.is_initialized(ParallelMode.PIPELINE):
         return None
-    if hasattr(gpc.config, 'SEQ_LENGTH') and hasattr(gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr(gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr(gpc.config, 'HIDDEN_SIZE'):
+    if hasattr(gpc.config, 'SEQ_LENGTH') and hasattr(gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr(
+            gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr(gpc.config, 'HIDDEN_SIZE'):
         if gpc.is_initialized(ParallelMode.DATA):
             dp_size = gpc.get_world_size(ParallelMode.DATA)
         else:
@@ -34,12 +35,12 @@ def get_tensor_shape():
             seq_size = 1
-        tensor_shape = (gpc.config.SEQ_LENGTH // seq_size, gpc.config.GLOBAL_BATCH_SIZE // dp_size // gpc.config.NUM_MICRO_BATCHES, gpc.config.HIDDEN_SIZE)
+        tensor_shape = (gpc.config.SEQ_LENGTH // seq_size,
+                        gpc.config.GLOBAL_BATCH_SIZE // dp_size // gpc.config.NUM_MICRO_BATCHES,
+                        gpc.config.HIDDEN_SIZE)
         return tensor_shape
     else:
         return None


 def pack_return_tensors(return_tensors):
     output, label = tuple(zip(*return_tensors))
     if isinstance(output[0], torch.Tensor):
@@ -114,7 +115,7 @@ class PipelineSchedule(BaseSchedule):
     def pre_processing(self, engine):
         # TODO: remove this after testing new zero with pipeline parallelism
         model = engine.model
-        if isinstance(model, (NaiveAMPModel, ShardedModelV2)):
+        if isinstance(model, (NaiveAMPModel)) or hasattr(model, 'colo_attr'):
             self.dtype = torch.half
             model = model.model
         sig = inspect.signature(model.forward)
@@ -125,7 +126,7 @@ class PipelineSchedule(BaseSchedule):
     def _call_engine(model, input_tensor, batch_data):
         if isinstance(model, NaiveAMPModel):
             sig = inspect.signature(model.model.forward)
-        elif isinstance(model, ShardedModelV2):
+        elif hasattr(model, 'colo_attr'):
             sig = inspect.signature(model.module.forward)
         else:
             sig = inspect.signature(model.forward)
@@ -385,7 +386,8 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         self.num_model_chunks = num_model_chunks

     def pre_processing(self, engine):
-        if isinstance(engine.model, ShardedModelV2):
+        # FIXME(jiaruifang) we shall not use ShardedModelV2 in pipeline mode, due to circular dependency.
+        if hasattr(engine.model, 'colo_attr'):
             self.dtype = torch.half
         elif isinstance(engine.model[0], NaiveAMPModel):
             self.dtype = torch.half
```
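Throughout this file, `isinstance(model, ShardedModelV2)` becomes `hasattr(model, 'colo_attr')`: a duck-typed check that lets the scheduler drop its import of `ShardedModelV2`, the source of the circular dependency flagged in the FIXME. A small self-contained sketch of the idea, with stand-in classes rather than the real ColossalAI types:

```python
class PlainModel:
    pass


class ShardedLikeModel:
    def __init__(self):
        self.colo_attr = object()  # marker attribute the scheduler looks for


def needs_half_precision(model) -> bool:
    # Same intent as `isinstance(model, ShardedModelV2)`, but without the
    # import that created the circular dependency.
    return hasattr(model, 'colo_attr')


assert not needs_half_precision(PlainModel())
assert needs_half_precision(ShardedLikeModel())
```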
colossalai/utils/memory_tracer/__init__.py → colossalai/gemini/memory_tracer/__init__.py (view file @ 4d9332b4)

```diff
+from .model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor
 from .memstats_collector import MemStatsCollector

-__all__ = ['AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector']
+__all__ = ['AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'GLOBAL_MODEL_DATA_TRACER']
```
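For callers, the practical effect of the move is a new import path; the names re-exported by this `__init__.py` are exactly those in the updated `__all__`:

```python
# Old path (before this commit):
#   from colossalai.utils.memory_tracer import AsyncMemoryMonitor
# New path (after this commit):
from colossalai.gemini.memory_tracer import (
    AsyncMemoryMonitor,
    SyncCudaMemoryMonitor,
    MemStatsCollector,
    GLOBAL_MODEL_DATA_TRACER,
)
```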
colossalai/utils/memory_tracer/memory_monitor.py → colossalai/gemini/memory_tracer/memory_monitor.py (view file @ 4d9332b4)

```diff
@@ -5,7 +5,7 @@ import json
 import torch

-from colossalai.utils.memory import colo_device_memory_used
+from colossalai.utils import colo_device_memory_used
 from colossalai.utils import get_current_device
```
colossalai/utils/memory_tracer/memstats_collector.py → colossalai/gemini/memory_tracer/memstats_collector.py (view file @ 4d9332b4)

```diff
-from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
+from colossalai.gemini.memory_tracer import GLOBAL_MODEL_DATA_TRACER
+from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
 from colossalai.utils.memory import colo_device_memory_used
-from colossalai.utils.memory_tracer import SyncCudaMemoryMonitor

 import torch
 import time
 from typing import List
@@ -138,6 +139,9 @@ class MemStatsCollector:
         self._model_data_cpu_list = []
         self._overall_cpu_list = []

+        self._non_model_data_cpu_list = []
+        self._non_model_data_cuda_list = []

         self._start_flag = False
         self._step_idx = 0
         self._step_total = 0
```
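The two new `_non_model_data_*` lists presumably hold per-step memory that is not attributable to model data (activations, buffers): overall usage minus tracked model data. The arithmetic below is an illustrative assumption based on how the collector's other lists are used, not code from this commit:

```python
# Per-step device memory samples, in bytes (made-up numbers).
overall_cuda = [2.0e9, 2.4e9, 2.2e9]      # total usage sampled each step
model_data_cuda = [1.5e9, 1.5e9, 1.5e9]   # bytes attributed to params/grads

# Non-model data is the gap between the two series.
non_model_data_cuda = [o - m for o, m in zip(overall_cuda, model_data_cuda)]
print(non_model_data_cuda)  # [500000000.0, 900000000.0, 700000000.0]
```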
colossalai/utils/memory_tracer/model_data_memtracer.py → colossalai/gemini/memory_tracer/model_data_memtracer.py (view file @ 4d9332b4)

File moved without content changes.
colossalai/gemini/tensor_placement_policy.py (view file @ 4d9332b4)

```diff
@@ -5,8 +5,8 @@ from colossalai.utils import get_current_device
 from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
 from colossalai.utils.memory import colo_device_memory_capacity
 from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
-from colossalai.utils.memory_tracer import MemStatsCollector
-from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
+from colossalai.gemini.memory_tracer import MemStatsCollector
+from colossalai.gemini.memory_tracer import GLOBAL_MODEL_DATA_TRACER
 from typing import Type
```
colossalai/trainer/hooks/_commons_.py (new file, mode 100644) (view file @ 4d9332b4)

```python
import torch


def _format_number(val, prec=5):
    if isinstance(val, float):
        return f'{val:.{prec}g}'
    elif torch.is_tensor(val) and torch.is_floating_point(val):
        return f'{val.item():.{prec}g}'
    return val
```
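A quick usage sketch of the helper above: floats and floating-point scalar tensors are rendered with `prec` (default 5) significant digits; everything else passes through unchanged.

```python
import torch

from colossalai.trainer.hooks._commons_ import _format_number

print(_format_number(3.14159265))                 # '3.1416'
print(_format_number(torch.tensor(0.123456789)))  # '0.12346'
print(_format_number(42))                         # 42 (non-float, returned as-is)
```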
colossalai/trainer/hooks/_log_hook.py (view file @ 4d9332b4)

```diff
@@ -14,14 +14,7 @@ from colossalai.logging import DistributedLogger
 from colossalai.utils import report_memory_usage, is_dp_rank_0, \
     is_tp_rank_0, is_no_pp_or_last_stage, MultiTimer
 from ._base_hook import BaseHook
-
-
-def _format_number(val, prec=5):
-    if isinstance(val, float):
-        return f'{val:.{prec}g}'
-    elif torch.is_tensor(val) and torch.is_floating_point(val):
-        return f'{val.item():.{prec}g}'
-    return val
+from ._commons_ import _format_number


 class LogByEpochHook(BaseHook):
@@ -35,10 +28,7 @@ class LogByEpochHook(BaseHook):
         depend on the hooks order in the hook list.
     """

-    def __init__(self,
-                 logger,
-                 interval: int = 1,
-                 priority: int = 1):
+    def __init__(self, logger, interval: int = 1, priority: int = 1):
         super().__init__(priority)
         self.logger = logger
         self._interval = interval
@@ -63,14 +53,12 @@ class LogMetricByStepHook(BaseHook):
     def after_train_iter(self, trainer, *args):
         trainer.states['step_metrics'] = dict()
         for metric_name, metric_calculator in trainer.states['metrics']['train'].items():
-            trainer.states['step_metrics'][metric_name.lower()] = \
-                f'{_format_number(metric_calculator.get_last_step_value())}'
+            trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_value()

     def after_test_iter(self, trainer, *args):
         trainer.states['step_metrics'] = dict()
         for metric_name, metric_calculator in trainer.states['metrics']['test'].items():
-            trainer.states['step_metrics'][metric_name.lower()] = \
-                f'{_format_number(metric_calculator.get_last_step_value())}'
+            trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_value()


 @HOOKS.register_module
@@ -85,18 +73,14 @@ class LogMetricByEpochHook(LogByEpochHook):
         depend on the hooks order in the hook list.
     """

-    def __init__(self,
-                 logger,
-                 interval: int = 1,
-                 priority: int = 10) -> None:
+    def __init__(self, logger, interval: int = 1, priority: int = 10) -> None:
         super().__init__(logger, interval, priority)
         self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()

     def _get_str(self, trainer, mode):
         msg = []
         for metric_name, metric_calculator in trainer.states['metrics'][mode].items():
-            msg.append(f'{metric_name} = {_format_number(metric_calculator.get_accumulated_value())}')
+            msg.append(f'{metric_name} = {_format_number(metric_calculator.get_accumulated_value())}')
         msg = ' | '.join(msg)
         return msg
@@ -130,7 +114,8 @@ class TensorboardHook(BaseHook):
         depend on the hooks order in the hook list.
     """

-    def __init__(self,
+    def __init__(self,
                  log_dir: str,
                  ranks: List = None,
                  parallel_mode: ParallelMode = ParallelMode.GLOBAL,
@@ -280,7 +265,8 @@ class LogMemoryByEpochHook(LogByEpochHook):
         log_eval (bool, optional): Whether writes in evaluation, defaults to True.
     """

-    def __init__(self,
+    def __init__(self,
                  logger: DistributedLogger,
                  interval: int = 1,
                  priority: int = 10,
```
colossalai/trainer/hooks/_mem_tracer_hook.py (view file @ 4d9332b4)

```diff
 from colossalai.registry import HOOKS
 from torch import Tensor
 from colossalai.trainer.hooks import BaseHook
-from colossalai.utils.memory_tracer import AsyncMemoryMonitor
+from colossalai.gemini.memory_tracer import AsyncMemoryMonitor


 @HOOKS.register_module
```
colossalai/trainer/hooks/_metric_hook.py (view file @ 4d9332b4)

```diff
@@ -13,6 +13,7 @@ from colossalai.registry import HOOKS
 from colossalai.utils import get_current_device, is_no_pp_or_last_stage
 from ._base_hook import BaseHook
+from ._commons_ import _format_number


 class Metric(ABC):
@@ -51,7 +52,7 @@ class Metric(ABC):
         pass

     @abstractmethod
-    def get_last_step_value(self):
+    def get_last_step_value(self) -> str:
         """Returns the metric value in the last iteration.
         """
         pass
@@ -120,10 +121,10 @@ class LossMetric(Metric):
         self.accum_loss.div_(self.count)
         return self.accum_loss.item()

-    def get_last_step_value(self):
+    def get_last_step_value(self) -> str:
         """Returns :attr:`last_step_loss`.
         """
-        return self.last_step_loss
+        return str(self.last_step_loss)

     @staticmethod
     def is_better(a, b):
@@ -148,8 +149,8 @@ class LearningRateMetric(Metric):
     def update(self, lr) -> None:
         self.lr = lr

-    def get_last_step_value(self):
-        return self.lr
+    def get_last_step_value(self) -> str:
+        return str(self.lr)

     def get_accumulated_value(self):
         return self.lr
@@ -203,10 +204,10 @@ class AccuracyMetric(Metric):
         self.accumulated_sum += self.last_step_sum
         self.accumulated_correct += self.last_step_correct

-    def get_last_step_value(self):
+    def get_last_step_value(self) -> str:
         self.last_step_sum = all_reduce(self.last_step_sum, ParallelMode.DATA)
         self.last_step_correct = all_reduce(self.last_step_correct, ParallelMode.DATA)
-        return (self.last_step_correct / self.last_step_sum).item()
+        return str(_format_number((self.last_step_correct / self.last_step_sum).item()))

     def get_accumulated_value(self):
         self.accumulated_sum = all_reduce(self.accumulated_sum, ParallelMode.DATA)
@@ -322,7 +323,8 @@ class ThroughputMetric(Metric):
     Args:
         epoch_only (bool): Whether the metric only read for the full epoch.
     """
-    def __init__(self, epoch_only: bool, ignored_steps: int = 0):
+
+    def __init__(self, epoch_only: bool, ignored_steps: int = 0, tflop_per_step: int = 0):
         super().__init__(epoch_only=epoch_only)
         self.ignored_steps = ignored_steps
         self.cur_steps = 0
@@ -330,6 +332,7 @@ class ThroughputMetric(Metric):
         self.accumulated_used_time = torch.zeros(1, device=get_current_device())
         self.last_step_num_samples = torch.zeros(1, device=get_current_device())
         self.last_step_used_time = torch.zeros(1, device=get_current_device())
+        self._tflop_per_step = tflop_per_step

     def reset(self) -> None:
         # self.cur_steps = 0
@@ -346,13 +349,18 @@ class ThroughputMetric(Metric):
         self.accumulated_num_samples += self.last_step_num_samples
         self.accumulated_used_time += self.last_step_used_time

-    def get_last_step_value(self):
+    def get_last_step_value(self) -> str:
         self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \
             gpc.get_world_size(ParallelMode.DATA)
         self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA)
-        return (self.last_step_num_samples / (self.last_step_used_time + 1e-12)).item()
+        sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item())
+        if self._tflop_per_step > 0:
+            tflops = _format_number(self._tflop_per_step / (self.last_step_used_time.item() + 1e-12))
+            return f"{sample_per_sec} sample_per_sec, {tflops} Tflops"
+        else:
+            return f"{sample_per_sec} sample_per_sec"

-    def get_accumulated_value(self):
+    def get_accumulated_value(self) -> float:
         self.accumulated_used_time = all_reduce(self.accumulated_used_time, ParallelMode.DATA) / \
             gpc.get_world_size(ParallelMode.DATA)
         self.accumulated_num_samples = all_reduce(self.accumulated_num_samples, ParallelMode.DATA)
@@ -373,14 +381,18 @@ class ThroughputHook(MetricHook):
         defaults to 10. If different hooks share same priority, the order of printing would
         depend on the hooks order in the hook list.
     """
-    def __init__(self, ignored_steps: int = 0, priority: int = 10):
+
+    def __init__(self, ignored_steps: int = 0, priority: int = 10, tflop_per_step: int = 0):
         super().__init__(priority)
         self.ignored_steps = ignored_steps
+        self._tflop_per_step = tflop_per_step

     def after_hook_is_attached(self, trainer):
         self._check_metric_states_initialization(trainer)
         if self._is_stage_to_compute:
-            self.metric = ThroughputMetric(epoch_only=True, ignored_steps=self.ignored_steps)
+            self.metric = ThroughputMetric(epoch_only=True,
+                                           ignored_steps=self.ignored_steps,
+                                           tflop_per_step=self._tflop_per_step)

             # register the metric
             trainer.states['metrics']['train']['Throughput'] = self.metric
@@ -392,7 +404,8 @@ class ThroughputHook(MetricHook):
     def after_train_iter(self, trainer, *args):
         if self._is_stage_to_compute:
-            self.metric.update(trainer.engine.schedule.batch_size, trainer._timer.get_timer('Train-step').get_elapsed_time())
+            self.metric.update(trainer.engine.schedule.batch_size,
+                               trainer._timer.get_timer('Train-step').get_elapsed_time())

     def before_test(self, trainer):
         if self._is_stage_to_compute:
@@ -400,4 +413,5 @@ class ThroughputHook(MetricHook):
     def after_test_iter(self, trainer, *args):
         if self._is_stage_to_compute:
-            self.metric.update(trainer.engine.schedule.batch_size, trainer._timer.get_timer('Test-step').get_elapsed_time())
+            self.metric.update(trainer.engine.schedule.batch_size,
+                               trainer._timer.get_timer('Test-step').get_elapsed_time())
```
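The `tflop_per_step` addition lets `get_last_step_value()` report achieved TFLOPS alongside samples/sec. A hedged arithmetic sketch of what the new return value looks like, with made-up numbers and plain Python floats standing in for the reduced tensors:

```python
# Illustrative values only; in the real metric these come from all_reduce'd tensors.
last_step_num_samples = 256.0  # samples processed in the step
last_step_used_time = 0.125    # seconds, after averaging across data-parallel ranks
tflop_per_step = 12            # user-supplied FLOP budget per step, in TFLOP

sample_per_sec = last_step_num_samples / (last_step_used_time + 1e-12)
tflops = tflop_per_step / (last_step_used_time + 1e-12)
print(f"{sample_per_sec:.5g} sample_per_sec, {tflops:.5g} Tflops")
# -> 2048 sample_per_sec, 96 Tflops
```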
colossalai/zero/sharded_model/sharded_model_v2.py (view file @ 4d9332b4)

```diff
@@ -12,8 +12,8 @@ from colossalai.zero.utils import ZeroHook
 from colossalai.engine.paramhooks import BaseParamHookMgr
 from colossalai.logging import get_dist_logger
 from colossalai.utils import get_current_device, disposable
-from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
-from colossalai.utils.memory_tracer.model_data_memtracer import \
+from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollector
+from colossalai.gemini.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
 from colossalai.utils.memory import colo_device_memory_capacity
 from colossalai.zero.shard_utils import BaseShardStrategy
```
colossalai/zero/sharded_optim/sharded_optim_v2.py (view file @ 4d9332b4)

```diff
@@ -10,7 +10,7 @@ from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
-from colossalai.utils.memory_tracer.model_data_memtracer import \
+from colossalai.gemini.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
 from colossalai.zero.sharded_param.tensor_utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone,
                                                         colo_tensor_mem_usage)
```
colossalai/zero/utils/zero_hook.py (view file @ 4d9332b4)

```diff
@@ -5,14 +5,15 @@ import torch.distributed as dist
 from colossalai.registry import OPHOOKS
 from colossalai.utils import get_current_device
-from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_param.tensorful_state import TensorState
-from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr

 from colossalai.engine.ophooks import BaseOpHook
+from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
+from colossalai.gemini.memory_tracer import MemStatsCollector
+from typing import Any


 @OPHOOKS.register_module
 class ZeroHook(BaseOpHook):
```
docs/colossalai/colossalai.utils.memory_tracer.async_memtracer.rst (view file @ 4d9332b4)

```diff
 colossalai.utils.memory\_tracer.async\_memtracer
 ================================================

-.. automodule:: colossalai.utils.memory_tracer.async_memtracer
+.. automodule:: colossalai.gemini.memory_tracer.async_memtracer
    :members:
```
docs/colossalai/colossalai.utils.memory_tracer.memstats_collector.rst (view file @ 4d9332b4)

```diff
 colossalai.utils.memory\_tracer.memstats\_collector
 ===================================================

-.. automodule:: colossalai.utils.memory_tracer.memstats_collector
+.. automodule:: colossalai.gemini.memory_tracer.memstats_collector
    :members:
```
docs/colossalai/colossalai.utils.memory_tracer.model_data_memtracer.rst (view file @ 4d9332b4)

```diff
 colossalai.utils.memory\_tracer.model\_data\_memtracer
 ======================================================

-.. automodule:: colossalai.utils.memory_tracer.model_data_memtracer
+.. automodule:: colossalai.gemini.memory_tracer.model_data_memtracer
    :members:
```
docs/colossalai/colossalai.utils.memory_tracer.rst (view file @ 4d9332b4)

```diff
 colossalai.utils.memory\_tracer
 ===============================

-.. automodule:: colossalai.utils.memory_tracer
+.. automodule:: colossalai.gemini.memory_tracer
    :members:


 .. toctree::
    :maxdepth: 2

-   colossalai.utils.memory_tracer.async_memtracer
-   colossalai.utils.memory_tracer.memstats_collector
-   colossalai.utils.memory_tracer.model_data_memtracer
+   colossalai.gemini.memory_tracer.async_memtracer
+   colossalai.gemini.memory_tracer.memstats_collector
+   colossalai.gemini.memory_tracer.model_data_memtracer
```
docs/colossalai/colossalai.utils.rst (view file @ 4d9332b4)

```diff
@@ -9,7 +9,7 @@ colossalai.utils
    colossalai.utils.data_sampler
    colossalai.utils.gradient_accumulation
-   colossalai.utils.memory_tracer
+   colossalai.gemini.memory_tracer
    colossalai.utils.memory_utils
    colossalai.utils.multi_tensor_apply
    colossalai.utils.profiler
```
tests/test_data/test_deterministic_dataloader.py (view file @ 4d9332b4)

```diff
@@ -78,6 +78,7 @@ def run_data_sampler(rank, world_size, port):
     torch.cuda.empty_cache()


+@pytest.mark.skip
 @pytest.mark.cpu
 @rerun_if_address_is_in_use()
 def test_data_sampler():
```
This commit view is paginated; the remaining 4 of the 24 changed files appear on page 2.