Unverified Commit bcab2495 authored by Jiarui Fang's avatar Jiarui Fang Committed by GitHub
Browse files

fix issue #1080 (#1071)

parent 1b178593
...@@ -4,9 +4,7 @@ ...@@ -4,9 +4,7 @@
import os import os
import os.path as osp import os.path as osp
import torch
from typing import List from typing import List
from decimal import Decimal
from colossalai.context import ParallelMode from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.registry import HOOKS from colossalai.registry import HOOKS
...@@ -15,6 +13,7 @@ from colossalai.utils import report_memory_usage, is_dp_rank_0, \ ...@@ -15,6 +13,7 @@ from colossalai.utils import report_memory_usage, is_dp_rank_0, \
is_tp_rank_0, is_no_pp_or_last_stage, MultiTimer is_tp_rank_0, is_no_pp_or_last_stage, MultiTimer
from ._base_hook import BaseHook from ._base_hook import BaseHook
from ._commons_ import _format_number from ._commons_ import _format_number
from colossalai.trainer.hooks._metric_hook import ThroughputMetric
class LogByEpochHook(BaseHook): class LogByEpochHook(BaseHook):
...@@ -53,12 +52,18 @@ class LogMetricByStepHook(BaseHook): ...@@ -53,12 +52,18 @@ class LogMetricByStepHook(BaseHook):
def after_train_iter(self, trainer, *args): def after_train_iter(self, trainer, *args):
trainer.states['step_metrics'] = dict() trainer.states['step_metrics'] = dict()
for metric_name, metric_calculator in trainer.states['metrics']['train'].items(): for metric_name, metric_calculator in trainer.states['metrics']['train'].items():
trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_value() if isinstance(metric_calculator, ThroughputMetric):
trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_info()
else:
trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_value()
def after_test_iter(self, trainer, *args): def after_test_iter(self, trainer, *args):
trainer.states['step_metrics'] = dict() trainer.states['step_metrics'] = dict()
for metric_name, metric_calculator in trainer.states['metrics']['test'].items(): for metric_name, metric_calculator in trainer.states['metrics']['test'].items():
trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_value() if isinstance(metric_calculator, ThroughputMetric):
trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_info()
else:
trainer.states['step_metrics'][metric_name.lower()] = metric_calculator.get_last_step_value()
@HOOKS.register_module @HOOKS.register_module
......
...@@ -52,7 +52,7 @@ class Metric(ABC): ...@@ -52,7 +52,7 @@ class Metric(ABC):
pass pass
@abstractmethod @abstractmethod
def get_last_step_value(self) -> str: def get_last_step_value(self) -> float:
"""Returns the metric value in the last iteration. """Returns the metric value in the last iteration.
""" """
pass pass
...@@ -121,10 +121,10 @@ class LossMetric(Metric): ...@@ -121,10 +121,10 @@ class LossMetric(Metric):
self.accum_loss.div_(self.count) self.accum_loss.div_(self.count)
return self.accum_loss.item() return self.accum_loss.item()
def get_last_step_value(self) -> str: def get_last_step_value(self) -> float:
"""Returns :attr:`last_step_loss`. """Returns :attr:`last_step_loss`.
""" """
return str(self.last_step_loss.cpu().item()) return self.last_step_loss.cpu().item()
@staticmethod @staticmethod
def is_better(a, b): def is_better(a, b):
...@@ -149,8 +149,8 @@ class LearningRateMetric(Metric): ...@@ -149,8 +149,8 @@ class LearningRateMetric(Metric):
def update(self, lr) -> None: def update(self, lr) -> None:
self.lr = lr self.lr = lr
def get_last_step_value(self) -> str: def get_last_step_value(self) -> float:
return str(self.lr) return self.lr
def get_accumulated_value(self): def get_accumulated_value(self):
return self.lr return self.lr
...@@ -204,10 +204,10 @@ class AccuracyMetric(Metric): ...@@ -204,10 +204,10 @@ class AccuracyMetric(Metric):
self.accumulated_sum += self.last_step_sum self.accumulated_sum += self.last_step_sum
self.accumulated_correct += self.last_step_correct self.accumulated_correct += self.last_step_correct
def get_last_step_value(self) -> str: def get_last_step_value(self) -> float:
self.last_step_sum = all_reduce(self.last_step_sum, ParallelMode.DATA) self.last_step_sum = all_reduce(self.last_step_sum, ParallelMode.DATA)
self.last_step_correct = all_reduce(self.last_step_correct, ParallelMode.DATA) self.last_step_correct = all_reduce(self.last_step_correct, ParallelMode.DATA)
return str(_format_number((self.last_step_correct / self.last_step_sum).cpu().item())) return _format_number((self.last_step_correct / self.last_step_sum).cpu().item())
def get_accumulated_value(self): def get_accumulated_value(self):
self.accumulated_sum = all_reduce(self.accumulated_sum, ParallelMode.DATA) self.accumulated_sum = all_reduce(self.accumulated_sum, ParallelMode.DATA)
...@@ -350,7 +350,18 @@ class ThroughputMetric(Metric): ...@@ -350,7 +350,18 @@ class ThroughputMetric(Metric):
self.accumulated_num_samples += self.last_step_num_samples self.accumulated_num_samples += self.last_step_num_samples
self.accumulated_used_time += self.last_step_used_time self.accumulated_used_time += self.last_step_used_time
def get_last_step_value(self) -> str: def get_last_step_value(self) -> float:
if self._use_local:
self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA)
else:
self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \
gpc.get_world_size(ParallelMode.DATA)
self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA)
sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item())
return sample_per_sec
def get_last_step_info(self) -> str:
if self._use_local: if self._use_local:
self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA) self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment