Unverified Commit 0974c9bc authored by Yuan Tang's avatar Yuan Tang Committed by GitHub
Browse files

[Bugfix] Fix incorrect types in LayerwiseProfileResults (#12196)


Signed-off-by: default avatarYuan Tang <terrytangyuan@gmail.com>
parent d2643128
import copy import copy
from collections import defaultdict from collections import defaultdict
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from typing import Callable, Dict, List, Optional, Tuple, TypeAlias, Union from typing import Any, Callable, Dict, List, Optional, Tuple, TypeAlias, Union
import pandas as pd import pandas as pd
from torch._C._autograd import DeviceType, _KinetoEvent, _ProfilerResult from torch._C._autograd import DeviceType, _KinetoEvent, _ProfilerResult
...@@ -128,7 +128,7 @@ class LayerwiseProfileResults(profile): ...@@ -128,7 +128,7 @@ class LayerwiseProfileResults(profile):
]) ])
df.to_csv(filename) df.to_csv(filename)
def convert_stats_to_dict(self) -> str: def convert_stats_to_dict(self) -> dict[str, Any]:
return { return {
"metadata": { "metadata": {
"num_running_seqs": self.num_running_seqs "num_running_seqs": self.num_running_seqs
...@@ -227,7 +227,7 @@ class LayerwiseProfileResults(profile): ...@@ -227,7 +227,7 @@ class LayerwiseProfileResults(profile):
[self._cumulative_cuda_time(root) for root in self._module_tree]) [self._cumulative_cuda_time(root) for root in self._module_tree])
def _build_stats_trees(self): def _build_stats_trees(self):
summary_dict: Dict[str, self.StatsTreeNode] = {} summary_dict: Dict[str, _StatsTreeNode] = {}
total_cuda_time = self._total_cuda_time() total_cuda_time = self._total_cuda_time()
def pct_cuda_time(cuda_time_us): def pct_cuda_time(cuda_time_us):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment