Commit cfb41297 authored by yuxuan-lou's avatar yuxuan-lou Committed by binmakeswell
Browse files

'fix/format' (#573)

parent b0f708df
...@@ -106,7 +106,7 @@ class MemTracerOpHook(BaseOpHook): ...@@ -106,7 +106,7 @@ class MemTracerOpHook(BaseOpHook):
# output file info # output file info
self._logger.info(f"dump a memory statistics as pickle to {self._data_prefix}-{self._rank}.pkl") self._logger.info(f"dump a memory statistics as pickle to {self._data_prefix}-{self._rank}.pkl")
home_dir = Path.home() home_dir = Path.home()
with open (home_dir.joinpath(f".cache/colossal/mem-{self._rank}.pkl"), "wb") as f: with open(home_dir.joinpath(f".cache/colossal/mem-{self._rank}.pkl"), "wb") as f:
pickle.dump(self.async_mem_monitor.state_dict, f) pickle.dump(self.async_mem_monitor.state_dict, f)
self._count += 1 self._count += 1
self._logger.debug(f"data file has been refreshed {self._count} times") self._logger.debug(f"data file has been refreshed {self._count} times")
...@@ -115,4 +115,4 @@ class MemTracerOpHook(BaseOpHook): ...@@ -115,4 +115,4 @@ class MemTracerOpHook(BaseOpHook):
def save_results(self, data_file: Union[str, Path]): def save_results(self, data_file: Union[str, Path]):
with open(data_file, "w") as f: with open(data_file, "w") as f:
f.write(json.dumps(self.async_mem_monitor.state_dict)) f.write(json.dumps(self.async_mem_monitor.state_dict))
\ No newline at end of file
...@@ -85,8 +85,7 @@ class BaseSchedule(ABC): ...@@ -85,8 +85,7 @@ class BaseSchedule(ABC):
data_iter: Iterable, data_iter: Iterable,
forward_only: bool, forward_only: bool,
return_loss: bool = True, return_loss: bool = True,
return_output_label: bool = True return_output_label: bool = True):
):
"""The process function over a batch of dataset for training or evaluation. """The process function over a batch of dataset for training or evaluation.
Args: Args:
...@@ -107,8 +106,9 @@ class BaseSchedule(ABC): ...@@ -107,8 +106,9 @@ class BaseSchedule(ABC):
@staticmethod @staticmethod
def _call_engine_criterion(engine, outputs, labels): def _call_engine_criterion(engine, outputs, labels):
assert isinstance(outputs, (torch.Tensor, list, tuple) assert isinstance(
), f'Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}' outputs,
(torch.Tensor, list, tuple)), f'Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}'
if isinstance(outputs, torch.Tensor): if isinstance(outputs, torch.Tensor):
outputs = (outputs,) outputs = (outputs,)
if isinstance(labels, torch.Tensor): if isinstance(labels, torch.Tensor):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment