OpenDAS / ColossalAI

Unverified commit 4d9332b4, authored Apr 19, 2022 by Jiarui Fang, committed by GitHub on Apr 19, 2022
Parent: 8711c706

[refactor] moving memtracer to gemini (#801)
Changes: 24. Showing 20 changed files with 98 additions and 81 deletions (+98 -81) on this page; the remaining files are on the second page of the diff.
Files changed on this page:

  colossalai/engine/ophooks/_memtracer_ophook.py                           +1  -2
  colossalai/engine/schedule/_pipeline_schedule.py                         +9  -7
  colossalai/gemini/memory_tracer/__init__.py                              +2  -1
  colossalai/gemini/memory_tracer/memory_monitor.py                        +1  -1
  colossalai/gemini/memory_tracer/memstats_collector.py                    +6  -2
  colossalai/gemini/memory_tracer/model_data_memtracer.py                  +0  -0
  colossalai/gemini/tensor_placement_policy.py                             +2  -2
  colossalai/trainer/hooks/_commons_.py                                    +9  -0
  colossalai/trainer/hooks/_log_hook.py                                    +21 -35
  colossalai/trainer/hooks/_mem_tracer_hook.py                             +1  -1
  colossalai/trainer/hooks/_metric_hook.py                                 +30 -16
  colossalai/zero/sharded_model/sharded_model_v2.py                        +2  -2
  colossalai/zero/sharded_optim/sharded_optim_v2.py                        +1  -1
  colossalai/zero/utils/zero_hook.py                                       +4  -3
  docs/colossalai/colossalai.utils.memory_tracer.async_memtracer.rst       +1  -1
  docs/colossalai/colossalai.utils.memory_tracer.memstats_collector.rst    +1  -1
  docs/colossalai/colossalai.utils.memory_tracer.model_data_memtracer.rst  +1  -1
  docs/colossalai/colossalai.utils.memory_tracer.rst                       +4  -4
  docs/colossalai/colossalai.utils.rst                                     +1  -1
  tests/test_data/test_deterministic_dataloader.py                         +1  -0
colossalai/engine/ophooks/_memtracer_ophook.py

@@ -8,8 +8,6 @@ from colossalai.registry import OPHOOKS
 from colossalai.logging import get_dist_logger
 from colossalai.core import global_context as gpc
 from typing import Union
-from colossalai.utils.memory_tracer import AsyncMemoryMonitor
-import os
 import math

@@ -25,6 +23,7 @@ class MemTracerOpHook(BaseOpHook):
     """

     def __init__(self, warmup: int = 50, refreshrate: int = 10, data_prefix: str = "memstats"):
+        from colossalai.gemini.memory_tracer import AsyncMemoryMonitor
         super().__init__()
         self.async_mem_monitor = AsyncMemoryMonitor()
         self._curiter = 0
 ...
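The hunk above moves the AsyncMemoryMonitor import from module scope into __init__, a deferred-import pattern that lets the ophook module and the new colossalai.gemini package reference each other without an import cycle. A minimal, self-contained sketch of that pattern follows; it is illustrative only and not the ColossalAI class, with collections.deque standing in for AsyncMemoryMonitor.

import sys


class LazyTracerHook:
    """Illustrative stand-in for MemTracerOpHook; not ColossalAI code."""

    def __init__(self, data_prefix: str = "memstats"):
        # Deferred import: nothing from the (potentially cyclic) dependency is
        # touched until a hook instance is actually constructed.
        from collections import deque  # stands in for AsyncMemoryMonitor
        self._data_prefix = data_prefix
        self._samples = deque()

    def record(self, value: float) -> None:
        self._samples.append(value)


hook = LazyTracerHook()
hook.record(1.5)
print(len(hook._samples), file=sys.stdout)  # -> 1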
colossalai/engine/schedule/_pipeline_schedule.py

@@ -12,10 +12,10 @@ from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.utils import switch_virtual_pipeline_parallel_rank
 from colossalai.utils.cuda import get_current_device
-from colossalai.zero.sharded_model import ShardedModelV2
 from ._base_schedule import BaseSchedule
 ...

 def get_tensor_shape():
     if hasattr(gpc.config, 'TENSOR_SHAPE'):
         return gpc.config.TENSOR_SHAPE

@@ -23,7 +23,8 @@ def get_tensor_shape():
     if not gpc.is_initialized(ParallelMode.PIPELINE):
         return None
     if hasattr(gpc.config, 'SEQ_LENGTH') and hasattr(gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr(gpc.config, 'GLOBAL_BATCH_SIZE') and hasattr(gpc.config, 'HIDDEN_SIZE'):
         if gpc.is_initialized(ParallelMode.DATA):
             dp_size = gpc.get_world_size(ParallelMode.DATA)
         else:
 ...

@@ -34,12 +35,12 @@ def get_tensor_shape():
             seq_size = 1
         tensor_shape = (gpc.config.SEQ_LENGTH // seq_size,
                         gpc.config.GLOBAL_BATCH_SIZE // dp_size // gpc.config.NUM_MICRO_BATCHES,
                         gpc.config.HIDDEN_SIZE)
         return tensor_shape
     else:
         return None


 def pack_return_tensors(return_tensors):
     output, label = tuple(zip(*return_tensors))
     if isinstance(output[0], torch.Tensor):
 ...

@@ -114,7 +115,7 @@ class PipelineSchedule(BaseSchedule):
     def pre_processing(self, engine):
         # TODO: remove this after testing new zero with pipeline parallelism
         model = engine.model
-        if isinstance(model, (NaiveAMPModel, ShardedModelV2)):
+        if isinstance(model, (NaiveAMPModel)) or hasattr(model, 'colo_attr'):
             self.dtype = torch.half
             model = model.model
         sig = inspect.signature(model.forward)

@@ -125,7 +126,7 @@ class PipelineSchedule(BaseSchedule):
     def _call_engine(model, input_tensor, batch_data):
         if isinstance(model, NaiveAMPModel):
             sig = inspect.signature(model.model.forward)
-        elif isinstance(model, ShardedModelV2):
+        elif hasattr(model, 'colo_attr'):
             sig = inspect.signature(model.module.forward)
         else:
             sig = inspect.signature(model.forward)

@@ -385,7 +386,8 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         self.num_model_chunks = num_model_chunks

     def pre_processing(self, engine):
-        if isinstance(engine.model, ShardedModelV2):
+        # FIXME(jiaruifang) we shall not use ShardedModelV2 in pipeline mode, due to circular dependency.
+        if hasattr(engine.model, 'colo_attr'):
             self.dtype = torch.half
         elif isinstance(engine.model[0], NaiveAMPModel):
             self.dtype = torch.half
 ...
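This diff replaces the isinstance() checks against ShardedModelV2 with a duck-typing probe for the 'colo_attr' attribute, so the pipeline schedule no longer has to import ShardedModelV2 at all and the circular dependency noted in the FIXME disappears. A self-contained sketch of that idea is below; the classes are illustrative stand-ins, not ColossalAI's.

class PlainModel:
    pass


class ShardedLikeModel:

    def __init__(self):
        self.colo_attr = object()    # marker attribute a sharded wrapper would set
        self.module = PlainModel()   # wrapped module, accessed as model.module in the diff


def uses_sharded_path(model) -> bool:
    # same spirit as the diff's: elif hasattr(model, 'colo_attr'):
    return hasattr(model, 'colo_attr')


print(uses_sharded_path(PlainModel()))        # False
print(uses_sharded_path(ShardedLikeModel()))  # True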
colossalai/utils/memory_tracer/__init__.py → colossalai/gemini/memory_tracer/__init__.py

+from .model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor
 from .memstats_collector import MemStatsCollector

-__all__ = ['AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector']
+__all__ = ['AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'GLOBAL_MODEL_DATA_TRACER']
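After the move, callers import the tracer utilities from colossalai.gemini.memory_tracer instead of colossalai.utils.memory_tracer. The before/after below is a sketch of that change for downstream code; the import targets are exactly the names re-exported by the new __init__.py above.

# Old location (removed by this commit):
# from colossalai.utils.memory_tracer import AsyncMemoryMonitor, MemStatsCollector

# New location, as used throughout this diff:
from colossalai.gemini.memory_tracer import (AsyncMemoryMonitor, SyncCudaMemoryMonitor,
                                             MemStatsCollector, GLOBAL_MODEL_DATA_TRACER)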
colossalai/utils/memory_tracer/memory_monitor.py → colossalai/gemini/memory_tracer/memory_monitor.py

@@ -5,7 +5,7 @@ import json
 import torch

-from colossalai.utils.memory import colo_device_memory_used
+from colossalai.utils import colo_device_memory_used
 from colossalai.utils import get_current_device
 ...
colossalai/utils/memory_tracer/memstats_collector.py → colossalai/gemini/memory_tracer/memstats_collector.py

-from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
+from colossalai.gemini.memory_tracer import GLOBAL_MODEL_DATA_TRACER
+from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
 from colossalai.utils.memory import colo_device_memory_used
-from colossalai.utils.memory_tracer import SyncCudaMemoryMonitor
 import torch
 import time
 from typing import List
 ...

@@ -138,6 +139,9 @@ class MemStatsCollector:
         self._model_data_cpu_list = []
         self._overall_cpu_list = []

+        self._non_model_data_cpu_list = []
+        self._non_model_data_cuda_list = []

         self._start_flag = False
         self._step_idx = 0
         self._step_total = 0
colossalai/utils/memory_tracer/model_data_memtracer.py → colossalai/gemini/memory_tracer/model_data_memtracer.py

File moved (no content changes).
colossalai/gemini/tensor_placement_policy.py

@@ -5,8 +5,8 @@ from colossalai.utils import get_current_device
 from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
 from colossalai.utils.memory import colo_device_memory_capacity
 from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
-from colossalai.utils.memory_tracer import MemStatsCollector
-from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
+from colossalai.gemini.memory_tracer import MemStatsCollector
+from colossalai.gemini.memory_tracer import GLOBAL_MODEL_DATA_TRACER
 from typing import Type
 ...
colossalai/trainer/hooks/_commons_.py  (new file, 0 → 100644)

+import torch
+
+
+def _format_number(val, prec=5):
+    if isinstance(val, float):
+        return f'{val:.{prec}g}'
+    elif torch.is_tensor(val) and torch.is_floating_point(val):
+        return f'{val.item():.{prec}g}'
+    return val
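_format_number rounds floats (and floating-point tensors) to prec significant digits and passes anything else through unchanged; the log and metric hooks in this commit use it to build their metric strings. A quick usage sketch, assuming a colossalai checkout at this commit is importable; the expected outputs follow directly from the implementation shown above.

import torch

from colossalai.trainer.hooks._commons_ import _format_number

print(_format_number(0.123456789))                      # '0.12346' (5 significant digits by default)
print(_format_number(torch.tensor(3.141592653589793)))  # '3.1416'
print(_format_number(42))                               # 42 (non-float values pass through unchanged)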
colossalai/trainer/hooks/_log_hook.py

@@ -14,14 +14,7 @@ from colossalai.logging import DistributedLogger
 from colossalai.utils import report_memory_usage, is_dp_rank_0, \
     is_tp_rank_0, is_no_pp_or_last_stage, MultiTimer
 from ._base_hook import BaseHook
+from ._commons_ import _format_number
-
-
-def _format_number(val, prec=5):
-    if isinstance(val, float):
-        return f'{val:.{prec}g}'
-    elif torch.is_tensor(val) and torch.is_floating_point(val):
-        return f'{val.item():.{prec}g}'
-    return val


 class LogByEpochHook(BaseHook):

@@ -35,10 +28,7 @@ class LogByEpochHook(BaseHook):
         depend on the hooks order in the hook list.
     """

     def __init__(self,
                  logger, interval: int = 1, priority: int = 1):
         super().__init__(priority)
         self.logger = logger
         self._interval = interval
 ...

@@ -63,14 +53,12 @@ class LogMetricByStepHook(BaseHook):
     def after_train_iter(self, trainer, *args):
         trainer.states['step_metrics'] = dict()
         for metric_name, metric_calculator in trainer.states['metrics']['train'].items():
-            trainer.states['step_metrics'][metric_name.lower()] = \
-                f'{_format_number(metric_calculator.get_last_step_value())}'
+            trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_value()

     def after_test_iter(self, trainer, *args):
         trainer.states['step_metrics'] = dict()
         for metric_name, metric_calculator in trainer.states['metrics']['test'].items():
-            trainer.states['step_metrics'][metric_name.lower()] = \
-                f'{_format_number(metric_calculator.get_last_step_value())}'
+            trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_value()


 @HOOKS.register_module

@@ -85,18 +73,14 @@ class LogMetricByEpochHook(LogByEpochHook):
         depend on the hooks order in the hook list.
     """

     def __init__(self,
                  logger, interval: int = 1, priority: int = 10) -> None:
         super().__init__(logger, interval, priority)
         self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()

     def _get_str(self, trainer, mode):
         msg = []
         for metric_name, metric_calculator in trainer.states['metrics'][mode].items():
             msg.append(
                 f'{metric_name} = {_format_number(metric_calculator.get_accumulated_value())}')
         msg = ' | '.join(msg)
         return msg
 ...

@@ -130,12 +114,13 @@ class TensorboardHook(BaseHook):
         depend on the hooks order in the hook list.
     """

-    def __init__(self,
-                 log_dir: str,
-                 ranks: List = None,
-                 parallel_mode: ParallelMode = ParallelMode.GLOBAL,
-                 priority: int = 10,
-                 ) -> None:
+    def __init__(
+        self,
+        log_dir: str,
+        ranks: List = None,
+        parallel_mode: ParallelMode = ParallelMode.GLOBAL,
+        priority: int = 10,
+    ) -> None:
         super().__init__(priority=priority)
         from torch.utils.tensorboard import SummaryWriter
 ...

@@ -280,13 +265,14 @@ class LogMemoryByEpochHook(LogByEpochHook):
         log_eval (bool, optional): Whether writes in evaluation, defaults to True.
     """

-    def __init__(self,
-                 logger: DistributedLogger,
-                 interval: int = 1,
-                 priority: int = 10,
-                 log_eval: bool = True,
-                 report_cpu: bool = False,  # no reference
-                 ) -> None:
+    def __init__(
+        self,
+        logger: DistributedLogger,
+        interval: int = 1,
+        priority: int = 10,
+        log_eval: bool = True,
+        report_cpu: bool = False,  # no reference
+    ) -> None:
         super().__init__(logger=logger, interval=interval, priority=priority)
         self._log_eval = log_eval
         self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0()
 ...
colossalai/trainer/hooks/_mem_tracer_hook.py

 from colossalai.registry import HOOKS
 from torch import Tensor
 from colossalai.trainer.hooks import BaseHook
-from colossalai.utils.memory_tracer import AsyncMemoryMonitor
+from colossalai.gemini.memory_tracer import AsyncMemoryMonitor


 @HOOKS.register_module
 ...
colossalai/trainer/hooks/_metric_hook.py

@@ -13,6 +13,7 @@ from colossalai.registry import HOOKS
 from colossalai.utils import get_current_device, is_no_pp_or_last_stage

 from ._base_hook import BaseHook
+from ._commons_ import _format_number


 class Metric(ABC):

@@ -51,7 +52,7 @@ class Metric(ABC):
         pass

     @abstractmethod
-    def get_last_step_value(self):
+    def get_last_step_value(self) -> str:
         """Returns the metric value in the last iteration.
         """
         pass

@@ -120,10 +121,10 @@ class LossMetric(Metric):
         self.accum_loss.div_(self.count)
         return self.accum_loss.item()

-    def get_last_step_value(self):
+    def get_last_step_value(self) -> str:
         """Returns :attr:`last_step_loss`.
         """
-        return self.last_step_loss
+        return str(self.last_step_loss)

     @staticmethod
     def is_better(a, b):

@@ -148,8 +149,8 @@ class LearningRateMetric(Metric):
     def update(self, lr) -> None:
         self.lr = lr

-    def get_last_step_value(self):
-        return self.lr
+    def get_last_step_value(self) -> str:
+        return str(self.lr)

     def get_accumulated_value(self):
         return self.lr

@@ -203,10 +204,10 @@ class AccuracyMetric(Metric):
         self.accumulated_sum += self.last_step_sum
         self.accumulated_correct += self.last_step_correct

-    def get_last_step_value(self):
+    def get_last_step_value(self) -> str:
         self.last_step_sum = all_reduce(self.last_step_sum, ParallelMode.DATA)
         self.last_step_correct = all_reduce(self.last_step_correct, ParallelMode.DATA)
-        return (self.last_step_correct / self.last_step_sum).item()
+        return str(_format_number((self.last_step_correct / self.last_step_sum).item()))

     def get_accumulated_value(self):
         self.accumulated_sum = all_reduce(self.accumulated_sum, ParallelMode.DATA)
 ...

@@ -322,7 +323,8 @@ class ThroughputMetric(Metric):
     Args:
         epoch_only (bool): Whether the metric only read for the full epoch.
     """

-    def __init__(self, epoch_only: bool, ignored_steps: int = 0):
+    def __init__(self, epoch_only: bool, ignored_steps: int = 0, tflop_per_step: int = 0):
         super().__init__(epoch_only=epoch_only)
         self.ignored_steps = ignored_steps
         self.cur_steps = 0

@@ -330,6 +332,7 @@ class ThroughputMetric(Metric):
         self.accumulated_used_time = torch.zeros(1, device=get_current_device())
         self.last_step_num_samples = torch.zeros(1, device=get_current_device())
         self.last_step_used_time = torch.zeros(1, device=get_current_device())
+        self._tflop_per_step = tflop_per_step

     def reset(self) -> None:
         # self.cur_steps = 0

@@ -346,13 +349,18 @@ class ThroughputMetric(Metric):
         self.accumulated_num_samples += self.last_step_num_samples
         self.accumulated_used_time += self.last_step_used_time

-    def get_last_step_value(self):
+    def get_last_step_value(self) -> str:
         self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \
             gpc.get_world_size(ParallelMode.DATA)
         self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA)
-        return (self.last_step_num_samples / (self.last_step_used_time + 1e-12)).item()
+        sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item())
+        if self._tflop_per_step > 0:
+            tflops = _format_number(self._tflop_per_step / (self.last_step_used_time.item() + 1e-12))
+            return f"{sample_per_sec} sample_per_sec, {tflops} Tflops"
+        else:
+            return f"{sample_per_sec} sample_per_sec"

-    def get_accumulated_value(self):
+    def get_accumulated_value(self) -> float:
         self.accumulated_used_time = all_reduce(self.accumulated_used_time, ParallelMode.DATA) / \
             gpc.get_world_size(ParallelMode.DATA)
         self.accumulated_num_samples = all_reduce(self.accumulated_num_samples, ParallelMode.DATA)
 ...

@@ -373,14 +381,18 @@ class ThroughputHook(MetricHook):
         defaults to 10. If different hooks share same priority, the order of printing would
         depend on the hooks order in the hook list.
     """

-    def __init__(self, ignored_steps: int = 0, priority: int = 10):
+    def __init__(self, ignored_steps: int = 0, priority: int = 10, tflop_per_step: int = 0):
         super().__init__(priority)
         self.ignored_steps = ignored_steps
+        self._tflop_per_step = tflop_per_step

     def after_hook_is_attached(self, trainer):
         self._check_metric_states_initialization(trainer)
         if self._is_stage_to_compute:
-            self.metric = ThroughputMetric(epoch_only=True, ignored_steps=self.ignored_steps)
+            self.metric = ThroughputMetric(epoch_only=True, ignored_steps=self.ignored_steps, tflop_per_step=self._tflop_per_step)

             # register the metric
             trainer.states['metrics']['train']['Throughput'] = self.metric

@@ -392,7 +404,8 @@ class ThroughputHook(MetricHook):
     def after_train_iter(self, trainer, *args):
         if self._is_stage_to_compute:
             self.metric.update(trainer.engine.schedule.batch_size, trainer._timer.get_timer('Train-step').get_elapsed_time())

     def before_test(self, trainer):
         if self._is_stage_to_compute:
 ...

@@ -400,4 +413,5 @@ class ThroughputHook(MetricHook):
     def after_test_iter(self, trainer, *args):
         if self._is_stage_to_compute:
             self.metric.update(trainer.engine.schedule.batch_size, trainer._timer.get_timer('Test-step').get_elapsed_time())
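With these changes, get_last_step_value returns a ready-to-print string for every metric, and ThroughputMetric appends a TFLOPS figure when a positive tflop_per_step is passed down from ThroughputHook. The standalone function below (not the library class) sketches how that string is assembled, reusing the 5-significant-digit formatting of _format_number.

def throughput_string(num_samples: float, used_time: float, tflop_per_step: float = 0.0) -> str:
    """Illustrative sketch of the string ThroughputMetric.get_last_step_value builds."""
    sample_per_sec = f'{num_samples / (used_time + 1e-12):.5g}'
    if tflop_per_step > 0:
        tflops = f'{tflop_per_step / (used_time + 1e-12):.5g}'
        return f"{sample_per_sec} sample_per_sec, {tflops} Tflops"
    return f"{sample_per_sec} sample_per_sec"


print(throughput_string(256, 0.125))         # '2048 sample_per_sec'
print(throughput_string(256, 0.125, 30.0))   # '2048 sample_per_sec, 240 Tflops'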
colossalai/zero/sharded_model/sharded_model_v2.py

@@ -12,8 +12,8 @@ from colossalai.zero.utils import ZeroHook
 from colossalai.engine.paramhooks import BaseParamHookMgr
 from colossalai.logging import get_dist_logger
 from colossalai.utils import get_current_device, disposable
-from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
-from colossalai.utils.memory_tracer.model_data_memtracer import \
+from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollector
+from colossalai.gemini.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
 from colossalai.utils.memory import colo_device_memory_capacity
 from colossalai.zero.shard_utils import BaseShardStrategy
 ...
colossalai/zero/sharded_optim/sharded_optim_v2.py

@@ -10,7 +10,7 @@ from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
-from colossalai.utils.memory_tracer.model_data_memtracer import \
+from colossalai.gemini.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
 from colossalai.zero.sharded_param.tensor_utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone,
                                                         colo_tensor_mem_usage)
 ...
colossalai/zero/utils/zero_hook.py

@@ -5,14 +5,15 @@ import torch.distributed as dist
 from colossalai.registry import OPHOOKS
 from colossalai.utils import get_current_device
-from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_param.tensorful_state import TensorState
-from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
 from colossalai.engine.ophooks import BaseOpHook
+from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
+from colossalai.gemini.memory_tracer import MemStatsCollector
+
+from typing import Any


 @OPHOOKS.register_module
 class ZeroHook(BaseOpHook):
 ...
docs/colossalai/colossalai.utils.memory_tracer.async_memtracer.rst

 colossalai.utils.memory\_tracer.async\_memtracer
 ================================================

-.. automodule:: colossalai.utils.memory_tracer.async_memtracer
+.. automodule:: colossalai.gemini.memory_tracer.async_memtracer
    :members:
docs/colossalai/colossalai.utils.memory_tracer.memstats_collector.rst

 colossalai.utils.memory\_tracer.memstats\_collector
 ===================================================

-.. automodule:: colossalai.utils.memory_tracer.memstats_collector
+.. automodule:: colossalai.gemini.memory_tracer.memstats_collector
    :members:
docs/colossalai/colossalai.utils.memory_tracer.model_data_memtracer.rst

 colossalai.utils.memory\_tracer.model\_data\_memtracer
 ======================================================

-.. automodule:: colossalai.utils.memory_tracer.model_data_memtracer
+.. automodule:: colossalai.gemini.memory_tracer.model_data_memtracer
    :members:
docs/colossalai/colossalai.utils.memory_tracer.rst

 colossalai.utils.memory\_tracer
 ===============================

-.. automodule:: colossalai.utils.memory_tracer
+.. automodule:: colossalai.gemini.memory_tracer
    :members:

 .. toctree::
    :maxdepth: 2

-   colossalai.utils.memory_tracer.async_memtracer
-   colossalai.utils.memory_tracer.memstats_collector
-   colossalai.utils.memory_tracer.model_data_memtracer
+   colossalai.gemini.memory_tracer.async_memtracer
+   colossalai.gemini.memory_tracer.memstats_collector
+   colossalai.gemini.memory_tracer.model_data_memtracer
docs/colossalai/colossalai.utils.rst

@@ -9,7 +9,7 @@ colossalai.utils
    colossalai.utils.data_sampler
    colossalai.utils.gradient_accumulation
-   colossalai.utils.memory_tracer
+   colossalai.gemini.memory_tracer
    colossalai.utils.memory_utils
    colossalai.utils.multi_tensor_apply
    colossalai.utils.profiler
 ...
tests/test_data/test_deterministic_dataloader.py

@@ -78,6 +78,7 @@ def run_data_sampler(rank, world_size, port):
     torch.cuda.empty_cache()


+@pytest.mark.skip
 @pytest.mark.cpu
 @rerun_if_address_is_in_use()
 def test_data_sampler():
 ...