Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
bcab2495
Unverified
Commit
bcab2495
authored
Jun 07, 2022
by
Jiarui Fang
Committed by
GitHub
Jun 07, 2022
Browse files
fix issue #1080 (#1071)
parent
1b178593
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
28 additions
and
12 deletions
+28
-12
colossalai/trainer/hooks/_log_hook.py
colossalai/trainer/hooks/_log_hook.py
+9
-4
colossalai/trainer/hooks/_metric_hook.py
colossalai/trainer/hooks/_metric_hook.py
+19
-8
No files found.
colossalai/trainer/hooks/_log_hook.py
View file @
bcab2495
...
@@ -4,9 +4,7 @@
...
@@ -4,9 +4,7 @@
import
os
import
os
import
os.path
as
osp
import
os.path
as
osp
import
torch
from
typing
import
List
from
typing
import
List
from
decimal
import
Decimal
from
colossalai.context
import
ParallelMode
from
colossalai.context
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.core
import
global_context
as
gpc
from
colossalai.registry
import
HOOKS
from
colossalai.registry
import
HOOKS
...
@@ -15,6 +13,7 @@ from colossalai.utils import report_memory_usage, is_dp_rank_0, \
...
@@ -15,6 +13,7 @@ from colossalai.utils import report_memory_usage, is_dp_rank_0, \
is_tp_rank_0
,
is_no_pp_or_last_stage
,
MultiTimer
is_tp_rank_0
,
is_no_pp_or_last_stage
,
MultiTimer
from
._base_hook
import
BaseHook
from
._base_hook
import
BaseHook
from
._commons_
import
_format_number
from
._commons_
import
_format_number
from
colossalai.trainer.hooks._metric_hook
import
ThroughputMetric
class
LogByEpochHook
(
BaseHook
):
class
LogByEpochHook
(
BaseHook
):
...
@@ -53,12 +52,18 @@ class LogMetricByStepHook(BaseHook):
...
@@ -53,12 +52,18 @@ class LogMetricByStepHook(BaseHook):
def
after_train_iter
(
self
,
trainer
,
*
args
):
def
after_train_iter
(
self
,
trainer
,
*
args
):
trainer
.
states
[
'step_metrics'
]
=
dict
()
trainer
.
states
[
'step_metrics'
]
=
dict
()
for
metric_name
,
metric_calculator
in
trainer
.
states
[
'metrics'
][
'train'
].
items
():
for
metric_name
,
metric_calculator
in
trainer
.
states
[
'metrics'
][
'train'
].
items
():
trainer
.
states
[
'step_metrics'
][
metric_name
.
lower
()]
=
metric_calculator
.
get_last_step_value
()
if
isinstance
(
metric_calculator
,
ThroughputMetric
):
trainer
.
states
[
'step_metrics'
][
metric_name
.
lower
()]
=
metric_calculator
.
get_last_step_info
()
else
:
trainer
.
states
[
'step_metrics'
][
metric_name
.
lower
()]
=
metric_calculator
.
get_last_step_value
()
def
after_test_iter
(
self
,
trainer
,
*
args
):
def
after_test_iter
(
self
,
trainer
,
*
args
):
trainer
.
states
[
'step_metrics'
]
=
dict
()
trainer
.
states
[
'step_metrics'
]
=
dict
()
for
metric_name
,
metric_calculator
in
trainer
.
states
[
'metrics'
][
'test'
].
items
():
for
metric_name
,
metric_calculator
in
trainer
.
states
[
'metrics'
][
'test'
].
items
():
trainer
.
states
[
'step_metrics'
][
metric_name
.
lower
()]
=
metric_calculator
.
get_last_step_value
()
if
isinstance
(
metric_calculator
,
ThroughputMetric
):
trainer
.
states
[
'step_metrics'
][
metric_name
.
lower
()]
=
metric_calculator
.
get_last_step_info
()
else
:
trainer
.
states
[
'step_metrics'
][
metric_name
.
lower
()]
=
metric_calculator
.
get_last_step_value
()
@
HOOKS
.
register_module
@
HOOKS
.
register_module
...
...
colossalai/trainer/hooks/_metric_hook.py
View file @
bcab2495
...
@@ -52,7 +52,7 @@ class Metric(ABC):
...
@@ -52,7 +52,7 @@ class Metric(ABC):
pass
pass
@
abstractmethod
@
abstractmethod
def
get_last_step_value
(
self
)
->
str
:
def
get_last_step_value
(
self
)
->
float
:
"""Returns the metric value in the last iteration.
"""Returns the metric value in the last iteration.
"""
"""
pass
pass
...
@@ -121,10 +121,10 @@ class LossMetric(Metric):
...
@@ -121,10 +121,10 @@ class LossMetric(Metric):
self
.
accum_loss
.
div_
(
self
.
count
)
self
.
accum_loss
.
div_
(
self
.
count
)
return
self
.
accum_loss
.
item
()
return
self
.
accum_loss
.
item
()
def
get_last_step_value
(
self
)
->
str
:
def
get_last_step_value
(
self
)
->
float
:
"""Returns :attr:`last_step_loss`.
"""Returns :attr:`last_step_loss`.
"""
"""
return
str
(
self
.
last_step_loss
.
cpu
().
item
()
)
return
self
.
last_step_loss
.
cpu
().
item
()
@
staticmethod
@
staticmethod
def
is_better
(
a
,
b
):
def
is_better
(
a
,
b
):
...
@@ -149,8 +149,8 @@ class LearningRateMetric(Metric):
...
@@ -149,8 +149,8 @@ class LearningRateMetric(Metric):
def
update
(
self
,
lr
)
->
None
:
def
update
(
self
,
lr
)
->
None
:
self
.
lr
=
lr
self
.
lr
=
lr
def
get_last_step_value
(
self
)
->
str
:
def
get_last_step_value
(
self
)
->
float
:
return
str
(
self
.
lr
)
return
self
.
lr
def
get_accumulated_value
(
self
):
def
get_accumulated_value
(
self
):
return
self
.
lr
return
self
.
lr
...
@@ -204,10 +204,10 @@ class AccuracyMetric(Metric):
...
@@ -204,10 +204,10 @@ class AccuracyMetric(Metric):
self
.
accumulated_sum
+=
self
.
last_step_sum
self
.
accumulated_sum
+=
self
.
last_step_sum
self
.
accumulated_correct
+=
self
.
last_step_correct
self
.
accumulated_correct
+=
self
.
last_step_correct
def
get_last_step_value
(
self
)
->
str
:
def
get_last_step_value
(
self
)
->
float
:
self
.
last_step_sum
=
all_reduce
(
self
.
last_step_sum
,
ParallelMode
.
DATA
)
self
.
last_step_sum
=
all_reduce
(
self
.
last_step_sum
,
ParallelMode
.
DATA
)
self
.
last_step_correct
=
all_reduce
(
self
.
last_step_correct
,
ParallelMode
.
DATA
)
self
.
last_step_correct
=
all_reduce
(
self
.
last_step_correct
,
ParallelMode
.
DATA
)
return
str
(
_format_number
((
self
.
last_step_correct
/
self
.
last_step_sum
).
cpu
().
item
())
)
return
_format_number
((
self
.
last_step_correct
/
self
.
last_step_sum
).
cpu
().
item
())
def
get_accumulated_value
(
self
):
def
get_accumulated_value
(
self
):
self
.
accumulated_sum
=
all_reduce
(
self
.
accumulated_sum
,
ParallelMode
.
DATA
)
self
.
accumulated_sum
=
all_reduce
(
self
.
accumulated_sum
,
ParallelMode
.
DATA
)
...
@@ -350,7 +350,18 @@ class ThroughputMetric(Metric):
...
@@ -350,7 +350,18 @@ class ThroughputMetric(Metric):
self
.
accumulated_num_samples
+=
self
.
last_step_num_samples
self
.
accumulated_num_samples
+=
self
.
last_step_num_samples
self
.
accumulated_used_time
+=
self
.
last_step_used_time
self
.
accumulated_used_time
+=
self
.
last_step_used_time
def
get_last_step_value
(
self
)
->
str
:
def
get_last_step_value
(
self
)
->
float
:
if
self
.
_use_local
:
self
.
last_step_num_samples
*=
gpc
.
get_world_size
(
ParallelMode
.
DATA
)
else
:
self
.
last_step_used_time
=
all_reduce
(
self
.
last_step_used_time
,
ParallelMode
.
DATA
)
/
\
gpc
.
get_world_size
(
ParallelMode
.
DATA
)
self
.
last_step_num_samples
=
all_reduce
(
self
.
last_step_num_samples
,
ParallelMode
.
DATA
)
sample_per_sec
=
_format_number
(
self
.
last_step_num_samples
/
(
self
.
last_step_used_time
+
1e-12
).
item
())
return
sample_per_sec
def
get_last_step_info
(
self
)
->
str
:
if
self
.
_use_local
:
if
self
.
_use_local
:
self
.
last_step_num_samples
*=
gpc
.
get_world_size
(
ParallelMode
.
DATA
)
self
.
last_step_num_samples
*=
gpc
.
get_world_size
(
ParallelMode
.
DATA
)
else
:
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment