Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
bcab2495
Unverified
Commit
bcab2495
authored
Jun 07, 2022
by
Jiarui Fang
Committed by
GitHub
Jun 07, 2022
Browse files
fix issue #1080 (#1071)
parent
1b178593
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
28 additions
and
12 deletions
+28
-12
colossalai/trainer/hooks/_log_hook.py
colossalai/trainer/hooks/_log_hook.py
+9
-4
colossalai/trainer/hooks/_metric_hook.py
colossalai/trainer/hooks/_metric_hook.py
+19
-8
No files found.
colossalai/trainer/hooks/_log_hook.py
View file @
bcab2495
...
...
@@ -4,9 +4,7 @@
import
os
import
os.path
as
osp
import
torch
from
typing
import
List
from
decimal
import
Decimal
from
colossalai.context
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.registry
import
HOOKS
...
...
@@ -15,6 +13,7 @@ from colossalai.utils import report_memory_usage, is_dp_rank_0, \
is_tp_rank_0
,
is_no_pp_or_last_stage
,
MultiTimer
from
._base_hook
import
BaseHook
from
._commons_
import
_format_number
from
colossalai.trainer.hooks._metric_hook
import
ThroughputMetric
class
LogByEpochHook
(
BaseHook
):
...
...
@@ -53,11 +52,17 @@ class LogMetricByStepHook(BaseHook):
def
after_train_iter
(
self
,
trainer
,
*
args
):
trainer
.
states
[
'step_metrics'
]
=
dict
()
for
metric_name
,
metric_calculator
in
trainer
.
states
[
'metrics'
][
'train'
].
items
():
if
isinstance
(
metric_calculator
,
ThroughputMetric
):
trainer
.
states
[
'step_metrics'
][
metric_name
.
lower
()]
=
metric_calculator
.
get_last_step_info
()
else
:
trainer
.
states
[
'step_metrics'
][
metric_name
.
lower
()]
=
metric_calculator
.
get_last_step_value
()
def
after_test_iter
(
self
,
trainer
,
*
args
):
trainer
.
states
[
'step_metrics'
]
=
dict
()
for
metric_name
,
metric_calculator
in
trainer
.
states
[
'metrics'
][
'test'
].
items
():
if
isinstance
(
metric_calculator
,
ThroughputMetric
):
trainer
.
states
[
'step_metrics'
][
metric_name
.
lower
()]
=
metric_calculator
.
get_last_step_info
()
else
:
trainer
.
states
[
'step_metrics'
][
metric_name
.
lower
()]
=
metric_calculator
.
get_last_step_value
()
...
...
colossalai/trainer/hooks/_metric_hook.py
View file @
bcab2495
...
...
@@ -52,7 +52,7 @@ class Metric(ABC):
pass
@
abstractmethod
def
get_last_step_value
(
self
)
->
str
:
def
get_last_step_value
(
self
)
->
float
:
"""Returns the metric value in the last iteration.
"""
pass
...
...
@@ -121,10 +121,10 @@ class LossMetric(Metric):
self
.
accum_loss
.
div_
(
self
.
count
)
return
self
.
accum_loss
.
item
()
def
get_last_step_value
(
self
)
->
str
:
def
get_last_step_value
(
self
)
->
float
:
"""Returns :attr:`last_step_loss`.
"""
return
str
(
self
.
last_step_loss
.
cpu
().
item
()
)
return
self
.
last_step_loss
.
cpu
().
item
()
@
staticmethod
def
is_better
(
a
,
b
):
...
...
@@ -149,8 +149,8 @@ class LearningRateMetric(Metric):
def
update
(
self
,
lr
)
->
None
:
self
.
lr
=
lr
def
get_last_step_value
(
self
)
->
str
:
return
str
(
self
.
lr
)
def
get_last_step_value
(
self
)
->
float
:
return
self
.
lr
def
get_accumulated_value
(
self
):
return
self
.
lr
...
...
@@ -204,10 +204,10 @@ class AccuracyMetric(Metric):
self
.
accumulated_sum
+=
self
.
last_step_sum
self
.
accumulated_correct
+=
self
.
last_step_correct
def
get_last_step_value
(
self
)
->
str
:
def
get_last_step_value
(
self
)
->
float
:
self
.
last_step_sum
=
all_reduce
(
self
.
last_step_sum
,
ParallelMode
.
DATA
)
self
.
last_step_correct
=
all_reduce
(
self
.
last_step_correct
,
ParallelMode
.
DATA
)
return
str
(
_format_number
((
self
.
last_step_correct
/
self
.
last_step_sum
).
cpu
().
item
())
)
return
_format_number
((
self
.
last_step_correct
/
self
.
last_step_sum
).
cpu
().
item
())
def
get_accumulated_value
(
self
):
self
.
accumulated_sum
=
all_reduce
(
self
.
accumulated_sum
,
ParallelMode
.
DATA
)
...
...
@@ -350,7 +350,18 @@ class ThroughputMetric(Metric):
self
.
accumulated_num_samples
+=
self
.
last_step_num_samples
self
.
accumulated_used_time
+=
self
.
last_step_used_time
def
get_last_step_value
(
self
)
->
str
:
def
get_last_step_value
(
self
)
->
float
:
if
self
.
_use_local
:
self
.
last_step_num_samples
*=
gpc
.
get_world_size
(
ParallelMode
.
DATA
)
else
:
self
.
last_step_used_time
=
all_reduce
(
self
.
last_step_used_time
,
ParallelMode
.
DATA
)
/
\
gpc
.
get_world_size
(
ParallelMode
.
DATA
)
self
.
last_step_num_samples
=
all_reduce
(
self
.
last_step_num_samples
,
ParallelMode
.
DATA
)
sample_per_sec
=
_format_number
(
self
.
last_step_num_samples
/
(
self
.
last_step_used_time
+
1e-12
).
item
())
return
sample_per_sec
def
get_last_step_info
(
self
)
->
str
:
if
self
.
_use_local
:
self
.
last_step_num_samples
*=
gpc
.
get_world_size
(
ParallelMode
.
DATA
)
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment