Unverified Commit b862d89d authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

[doc] improved docstring in the logging module (#861)

parent 8004c8e9
...@@ -6,22 +6,20 @@ from .logger import DistributedLogger ...@@ -6,22 +6,20 @@ from .logger import DistributedLogger
__all__ = ['get_dist_logger', 'DistributedLogger', 'disable_existing_loggers'] __all__ = ['get_dist_logger', 'DistributedLogger', 'disable_existing_loggers']
def get_dist_logger(name='colossalai'): def get_dist_logger(name: str = 'colossalai') -> DistributedLogger:
"""Get logger instance based on name. The DistributedLogger will create singleton instances, """Get logger instance based on name. The DistributedLogger will create singleton instances,
which means that only one logger instance is created per name. which means that only one logger instance is created per name.
Args: Args:
name (str): name of the logger, name must be unique
:param name: name of the logger, name must be unique Returns:
:type name: str :class:`colossalai.logging.DistributedLogger`: A distributed logger singleton instance.
:return: a distributed logger instance
:rtype: :class:`colossalai.logging.DistributedLogger`
""" """
return DistributedLogger.get_instance(name=name) return DistributedLogger.get_instance(name=name)
def disable_existing_loggers(include: Optional[List[str]] = None, exclude: List[str] = ['colossalai']): def disable_existing_loggers(include: Optional[List[str]] = None, exclude: List[str] = ['colossalai']) -> None:
"""Set the level of existing loggers to `WARNING`. By default, it will "disable" all existing loggers except the logger named "colossalai". """Set the level of existing loggers to `WARNING`. By default, it will "disable" all existing loggers except the logger named "colossalai".
Args: Args:
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import colossalai import colossalai
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union, List
import inspect import inspect
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
...@@ -40,6 +40,7 @@ class DistributedLogger: ...@@ -40,6 +40,7 @@ class DistributedLogger:
Args: Args:
name (str): The name of the logger. name (str): The name of the logger.
Returns: Returns:
DistributedLogger: A DistributedLogger object DistributedLogger: A DistributedLogger object
""" """
...@@ -75,7 +76,7 @@ class DistributedLogger: ...@@ -75,7 +76,7 @@ class DistributedLogger:
def _check_valid_logging_level(level: str): def _check_valid_logging_level(level: str):
assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR'], 'found invalid logging level' assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR'], 'found invalid logging level'
def set_level(self, level: str): def set_level(self, level: str) -> None:
"""Set the logging level """Set the logging level
Args: Args:
...@@ -84,7 +85,7 @@ class DistributedLogger: ...@@ -84,7 +85,7 @@ class DistributedLogger:
self._check_valid_logging_level(level) self._check_valid_logging_level(level)
self._logger.setLevel(getattr(logging, level)) self._logger.setLevel(getattr(logging, level))
def log_to_file(self, path: Union[str, Path], mode: str = 'a', level: str = 'INFO', suffix: str = None): def log_to_file(self, path: Union[str, Path], mode: str = 'a', level: str = 'INFO', suffix: str = None) -> None:
"""Save the logs to file """Save the logs to file
Args: Args:
...@@ -122,7 +123,11 @@ class DistributedLogger: ...@@ -122,7 +123,11 @@ class DistributedLogger:
file_handler.setFormatter(formatter) file_handler.setFormatter(formatter)
self._logger.addHandler(file_handler) self._logger.addHandler(file_handler)
def _log(self, level, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None): def _log(self,
level,
message: str,
parallel_mode: ParallelMode = ParallelMode.GLOBAL,
ranks: List[int] = None) -> None:
if ranks is None: if ranks is None:
getattr(self._logger, level)(message) getattr(self._logger, level)(message)
else: else:
...@@ -130,53 +135,53 @@ class DistributedLogger: ...@@ -130,53 +135,53 @@ class DistributedLogger:
if local_rank in ranks: if local_rank in ranks:
getattr(self._logger, level)(message) getattr(self._logger, level)(message)
def info(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None): def info(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
"""Log an info message. """Log an info message.
Args: Args:
message (str): The message to be logged. message (str): The message to be logged.
parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`):
The parallel mode used for logging. Defaults to ParallelMode.GLOBAL. The parallel mode used for logging. Defaults to ParallelMode.GLOBAL.
ranks (List): List of parallel ranks. ranks (List[int]): List of parallel ranks.
""" """
message_prefix = "{}:{} {}".format(*self.__get_call_info()) message_prefix = "{}:{} {}".format(*self.__get_call_info())
self._log('info', message_prefix, parallel_mode, ranks) self._log('info', message_prefix, parallel_mode, ranks)
self._log('info', message, parallel_mode, ranks) self._log('info', message, parallel_mode, ranks)
def warning(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None): def warning(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
"""Log a warning message. """Log a warning message.
Args: Args:
message (str): The message to be logged. message (str): The message to be logged.
parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`):
The parallel mode used for logging. Defaults to ParallelMode.GLOBAL. The parallel mode used for logging. Defaults to ParallelMode.GLOBAL.
ranks (List): List of parallel ranks. ranks (List[int]): List of parallel ranks.
""" """
message_prefix = "{}:{} {}".format(*self.__get_call_info()) message_prefix = "{}:{} {}".format(*self.__get_call_info())
self._log('warning', message_prefix, parallel_mode, ranks) self._log('warning', message_prefix, parallel_mode, ranks)
self._log('warning', message, parallel_mode, ranks) self._log('warning', message, parallel_mode, ranks)
def debug(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None): def debug(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
"""Log a debug message. """Log a debug message.
Args: Args:
message (str): The message to be logged. message (str): The message to be logged.
parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`):
The parallel mode used for logging. Defaults to ParallelMode.GLOBAL. The parallel mode used for logging. Defaults to ParallelMode.GLOBAL.
ranks (List): List of parallel ranks. ranks (List[int]): List of parallel ranks.
""" """
message_prefix = "{}:{} {}".format(*self.__get_call_info()) message_prefix = "{}:{} {}".format(*self.__get_call_info())
self._log('debug', message_prefix, parallel_mode, ranks) self._log('debug', message_prefix, parallel_mode, ranks)
self._log('debug', message, parallel_mode, ranks) self._log('debug', message, parallel_mode, ranks)
def error(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None): def error(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: List[int] = None) -> None:
"""Log an error message. """Log an error message.
Args: Args:
message (str): The message to be logged. message (str): The message to be logged.
parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`):
The parallel mode used for logging. Defaults to ParallelMode.GLOBAL. The parallel mode used for logging. Defaults to ParallelMode.GLOBAL.
ranks (List): List of parallel ranks. ranks (List[int]): List of parallel ranks.
""" """
message_prefix = "{}:{} {}".format(*self.__get_call_info()) message_prefix = "{}:{} {}".format(*self.__get_call_info())
self._log('error', message_prefix, parallel_mode, ranks) self._log('error', message_prefix, parallel_mode, ranks)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment