"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "27ca23dc002e06eade014ac6b801dc2dcbea40f3"
Unverified Commit ff334ca1 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Update deprecated type hinting in `vllm/profiler` (#18057)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 6223dd81
...@@ -84,7 +84,6 @@ exclude = [ ...@@ -84,7 +84,6 @@ exclude = [
"vllm/model_executor/models/**/*.py" = ["UP006", "UP035"] "vllm/model_executor/models/**/*.py" = ["UP006", "UP035"]
"vllm/platforms/**/*.py" = ["UP006", "UP035"] "vllm/platforms/**/*.py" = ["UP006", "UP035"]
"vllm/plugins/**/*.py" = ["UP006", "UP035"] "vllm/plugins/**/*.py" = ["UP006", "UP035"]
"vllm/profiler/**/*.py" = ["UP006", "UP035"]
"vllm/prompt_adapter/**/*.py" = ["UP006", "UP035"] "vllm/prompt_adapter/**/*.py" = ["UP006", "UP035"]
"vllm/spec_decode/**/*.py" = ["UP006", "UP035"] "vllm/spec_decode/**/*.py" = ["UP006", "UP035"]
"vllm/transformers_utils/**/*.py" = ["UP006", "UP035"] "vllm/transformers_utils/**/*.py" = ["UP006", "UP035"]
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import copy import copy
from collections import defaultdict from collections import defaultdict
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeAlias, Union from typing import Any, Callable, Optional, TypeAlias, Union
import pandas as pd import pandas as pd
from torch._C._autograd import DeviceType, _KinetoEvent, _ProfilerResult from torch._C._autograd import DeviceType, _KinetoEvent, _ProfilerResult
...@@ -20,7 +20,7 @@ from vllm.profiler.utils import (TablePrinter, event_has_module, ...@@ -20,7 +20,7 @@ from vllm.profiler.utils import (TablePrinter, event_has_module,
class _ModuleTreeNode: class _ModuleTreeNode:
event: _ProfilerEvent event: _ProfilerEvent
parent: Optional['_ModuleTreeNode'] = None parent: Optional['_ModuleTreeNode'] = None
children: List['_ModuleTreeNode'] = field(default_factory=list) children: list['_ModuleTreeNode'] = field(default_factory=list)
trace: str = "" trace: str = ""
@property @property
...@@ -60,19 +60,19 @@ StatsEntry: TypeAlias = Union[ModelStatsEntry, SummaryStatsEntry] ...@@ -60,19 +60,19 @@ StatsEntry: TypeAlias = Union[ModelStatsEntry, SummaryStatsEntry]
@dataclass @dataclass
class _StatsTreeNode: class _StatsTreeNode:
entry: StatsEntry entry: StatsEntry
children: List[StatsEntry] children: list[StatsEntry]
parent: Optional[StatsEntry] parent: Optional[StatsEntry]
@dataclass @dataclass
class LayerwiseProfileResults(profile): class LayerwiseProfileResults(profile):
_kineto_results: _ProfilerResult _kineto_results: _ProfilerResult
_kineto_event_correlation_map: Dict[int, _kineto_event_correlation_map: dict[int,
List[_KinetoEvent]] = field(init=False) list[_KinetoEvent]] = field(init=False)
_event_correlation_map: Dict[int, List[FunctionEvent]] = field(init=False) _event_correlation_map: dict[int, list[FunctionEvent]] = field(init=False)
_module_tree: List[_ModuleTreeNode] = field(init=False) _module_tree: list[_ModuleTreeNode] = field(init=False)
_model_stats_tree: List[_StatsTreeNode] = field(init=False) _model_stats_tree: list[_StatsTreeNode] = field(init=False)
_summary_stats_tree: List[_StatsTreeNode] = field(init=False) _summary_stats_tree: list[_StatsTreeNode] = field(init=False)
# profile metadata # profile metadata
num_running_seqs: Optional[int] = None num_running_seqs: Optional[int] = None
...@@ -82,7 +82,7 @@ class LayerwiseProfileResults(profile): ...@@ -82,7 +82,7 @@ class LayerwiseProfileResults(profile):
self._build_module_tree() self._build_module_tree()
self._build_stats_trees() self._build_stats_trees()
def print_model_table(self, column_widths: Dict[str, int] = None): def print_model_table(self, column_widths: dict[str, int] = None):
_column_widths = dict(name=60, _column_widths = dict(name=60,
cpu_time_us=12, cpu_time_us=12,
cuda_time_us=12, cuda_time_us=12,
...@@ -100,7 +100,7 @@ class LayerwiseProfileResults(profile): ...@@ -100,7 +100,7 @@ class LayerwiseProfileResults(profile):
filtered_model_table, filtered_model_table,
indent_style=lambda indent: "|" + "-" * indent + " ")) indent_style=lambda indent: "|" + "-" * indent + " "))
def print_summary_table(self, column_widths: Dict[str, int] = None): def print_summary_table(self, column_widths: dict[str, int] = None):
_column_widths = dict(name=80, _column_widths = dict(name=80,
cuda_time_us=12, cuda_time_us=12,
pct_cuda_time=12, pct_cuda_time=12,
...@@ -142,7 +142,7 @@ class LayerwiseProfileResults(profile): ...@@ -142,7 +142,7 @@ class LayerwiseProfileResults(profile):
} }
@staticmethod @staticmethod
def _indent_row_names_based_on_depth(depths_rows: List[Tuple[int, def _indent_row_names_based_on_depth(depths_rows: list[tuple[int,
StatsEntry]], StatsEntry]],
indent_style: Union[Callable[[int], indent_style: Union[Callable[[int],
str], str],
...@@ -229,7 +229,7 @@ class LayerwiseProfileResults(profile): ...@@ -229,7 +229,7 @@ class LayerwiseProfileResults(profile):
[self._cumulative_cuda_time(root) for root in self._module_tree]) [self._cumulative_cuda_time(root) for root in self._module_tree])
def _build_stats_trees(self): def _build_stats_trees(self):
summary_dict: Dict[str, _StatsTreeNode] = {} summary_dict: dict[str, _StatsTreeNode] = {}
total_cuda_time = self._total_cuda_time() total_cuda_time = self._total_cuda_time()
def pct_cuda_time(cuda_time_us): def pct_cuda_time(cuda_time_us):
...@@ -238,7 +238,7 @@ class LayerwiseProfileResults(profile): ...@@ -238,7 +238,7 @@ class LayerwiseProfileResults(profile):
def build_summary_stats_tree_df( def build_summary_stats_tree_df(
node: _ModuleTreeNode, node: _ModuleTreeNode,
parent: Optional[_StatsTreeNode] = None, parent: Optional[_StatsTreeNode] = None,
summary_trace: Tuple[str] = ()): summary_trace: tuple[str] = ()):
if event_has_module(node.event): if event_has_module(node.event):
name = event_module_repr(node.event) name = event_module_repr(node.event)
...@@ -313,8 +313,8 @@ class LayerwiseProfileResults(profile): ...@@ -313,8 +313,8 @@ class LayerwiseProfileResults(profile):
self._model_stats_tree.append(build_model_stats_tree_df(root)) self._model_stats_tree.append(build_model_stats_tree_df(root))
def _flatten_stats_tree( def _flatten_stats_tree(
self, tree: List[_StatsTreeNode]) -> List[Tuple[int, StatsEntry]]: self, tree: list[_StatsTreeNode]) -> list[tuple[int, StatsEntry]]:
entries: List[Tuple[int, StatsEntry]] = [] entries: list[tuple[int, StatsEntry]] = []
def df_traversal(node: _StatsTreeNode, depth=0): def df_traversal(node: _StatsTreeNode, depth=0):
entries.append((depth, node.entry)) entries.append((depth, node.entry))
...@@ -327,10 +327,10 @@ class LayerwiseProfileResults(profile): ...@@ -327,10 +327,10 @@ class LayerwiseProfileResults(profile):
return entries return entries
def _convert_stats_tree_to_dict(self, def _convert_stats_tree_to_dict(self,
tree: List[_StatsTreeNode]) -> List[Dict]: tree: list[_StatsTreeNode]) -> list[dict]:
root_dicts: List[Dict] = [] root_dicts: list[dict] = []
def df_traversal(node: _StatsTreeNode, curr_json_list: List[Dict]): def df_traversal(node: _StatsTreeNode, curr_json_list: list[dict]):
curr_json_list.append({ curr_json_list.append({
"entry": asdict(node.entry), "entry": asdict(node.entry),
"children": [] "children": []
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import dataclasses import dataclasses
from typing import Callable, Dict, List, Type, Union from typing import Callable, Union
from torch._C._profiler import _EventType, _ProfilerEvent, _TensorMetadata from torch._C._profiler import _EventType, _ProfilerEvent, _TensorMetadata
...@@ -30,14 +30,14 @@ def trim_string_back(string, width): ...@@ -30,14 +30,14 @@ def trim_string_back(string, width):
class TablePrinter: class TablePrinter:
def __init__(self, row_cls: Type[dataclasses.dataclass], def __init__(self, row_cls: type[dataclasses.dataclass],
column_widths: Dict[str, int]): column_widths: dict[str, int]):
self.row_cls = row_cls self.row_cls = row_cls
self.fieldnames = [x.name for x in dataclasses.fields(row_cls)] self.fieldnames = [x.name for x in dataclasses.fields(row_cls)]
self.column_widths = column_widths self.column_widths = column_widths
assert set(self.column_widths.keys()) == set(self.fieldnames) assert set(self.column_widths.keys()) == set(self.fieldnames)
def print_table(self, rows: List[dataclasses.dataclass]): def print_table(self, rows: list[dataclasses.dataclass]):
self._print_header() self._print_header()
self._print_line() self._print_line()
for row in rows: for row in rows:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment