Commit e00dc7c6 authored by catchyrime's avatar catchyrime
Browse files

Add Seek

parents
Pipeline #121 canceled with stages
# We disable these warnings from flake8:
# F401: 'xxx' imported but unused
# F403: 'from xxx import *' used; unable to detect undefined names
# Well, there are too many constants from const.py
# We would never bother to list them one by one here...
from .const import * # noqa: F401, F403
from .exception import SeekException, SeekInstrException, SeekTODOException # noqa: F401
from .register import GprType, Align, Gpr, sext, new, GprSet # noqa: F401
from .register import exec, exec_lo, exec_hi, execz # noqa: F401
from .register import vcc, vcc_lo, vcc_hi, vccz, vcc_exec # noqa: F401
from .register import flat_scratch, flat_scratch_lo, flat_scratch_hi # noqa: F401
from .register import xnack_mask, xnack_mask_lo, xnack_mask_hi # noqa: F401
from .register import scc, m0, lds_direct, s_kernarg, load_to_lds, off, threadIdx, blockIdx # noqa: F401
from .instr import explicit_wait, explicit_uses # noqa: F401
"""
WARNING:
This file is auto-generated from `automation/generate_const.py`
If you are to modify anything, please modify that script.
"""
# flake8: noqa
SIZEOF_U8 = 1
SIZEOF_I8 = 1
SIZEOF_U16 = 2
SIZEOF_I16 = 2
SIZEOF_F16 = 2
SIZEOF_U32 = 4
SIZEOF_I32 = 4
SIZEOF_F32 = 4
SIZEOF_U64 = 8
SIZEOF_I64 = 8
SIZEOF_F64 = 8
SIZEOF_F32X2 = 8
SIZEOF_F64X2 = 16
SHIFT_U8 = 0
SHIFT_I8 = 0
SHIFT_U16 = 1
SHIFT_I16 = 1
SHIFT_F16 = 1
SHIFT_U32 = 2
SHIFT_I32 = 2
SHIFT_F32 = 2
SHIFT_U64 = 3
SHIFT_I64 = 3
SHIFT_F64 = 3
SHIFT_F32X2 = 3
SHIFT_F64X2 = 4
DWORD = 'DWORD'
BYTE_0 = 'BYTE_0'
BYTE_1 = 'BYTE_1'
BYTE_2 = 'BYTE_2'
BYTE_3 = 'BYTE_3'
WORD_0 = 'WORD_0'
WORD_1 = 'WORD_1'
UNUSED_PAD = 'UNUSED_PAD'
UNUSED_SEXT = 'UNUSED_SEXT'
UNUSED_PRESERVE = 'UNUSED_PRESERVE'
SHARED_BASE = 'shared_base'
SHARED_LIMIT = 'shared_limit'
PRIVATE_BASE = 'private_base'
PRIVATE_LIMIT = 'private_limit'
POPS_EXITING_WAVE_ID = 'pops_exiting_wave_id'
_INSTR_STR_WIDTH = 24
# EOF
class SeekException(Exception):
def __init__(self, *args):
super().__init__(*args)
class SeekInstrException(SeekException):
def __init__(self, *args):
super().__init__(*args)
class SeekTODOException(SeekException):
def __init__(self, *args):
super().__init__(*args)
def check(expr, *args, exc_type=SeekException):
if not expr:
raise exc_type(*args)
from __future__ import annotations
from typing import Optional, Dict, Any, List, Union, Hashable, Set, Tuple
from .const import _INSTR_STR_WIDTH
from .exception import SeekTODOException, check
from .utility import _Global, SrcLoc
from .register import Gpr, GprSet, vcc, vccz, exec, execz, m0, lds_direct, load_to_lds
import enum
class MemTokenObject(object):
pass
class MemToken:
def __init__(self, token_object: Optional[Hashable],
store_mem_from_gprs: Set[Gpr],
load_mem_to_gprs: Set[Gpr],
srcloc: SrcLoc,
/,
inc_vector: int = 0,
inc_lds: int = 0,
inc_gds: int = 0,
inc_scalar: int = 0,
inc_msg: int = 0,
inc_export: int = 0):
if token_object is None:
token_object = MemTokenObject() # create a unique object
check(not isinstance(token_object, MemToken), "Don't use MemToken as token_object")
self.__token_object = token_object # type: Hashable
for gpr in store_mem_from_gprs:
assert isinstance(gpr, Gpr)
assert not gpr.is_neg and not gpr.is_abs and not gpr.is_sext
self.__store_mem_from_gprs = list(sorted(store_mem_from_gprs))
for gpr in load_mem_to_gprs:
assert isinstance(gpr, Gpr)
assert not gpr.is_neg and not gpr.is_abs and not gpr.is_sext
self.__load_mem_to_gprs = list(sorted(load_mem_to_gprs))
assert srcloc is not None
assert inc_vector >= 0
assert inc_lds >= 0
assert inc_gds >= 0
assert inc_scalar >= 0
assert inc_msg >= 0
assert inc_export >= 0
self.__srcloc = srcloc # type: SrcLoc
self.__inc_vector = inc_vector # type: int
self.__inc_lds = inc_lds # type: int
self.__inc_gds = inc_gds # type: int
self.__inc_scalar = inc_scalar # type: int
self.__inc_msg = inc_msg # type: int
self.__inc_export = inc_export # type: int
@property
def token_object(self) -> Hashable:
assert self.__token_object is not None
return self.__token_object
@property
def srcloc(self) -> SrcLoc:
return self.__srcloc
@property
def inc_vector(self) -> int:
return self.__inc_vector
@property
def inc_lds(self) -> int:
return self.__inc_lds
@property
def inc_gds(self) -> int:
return self.__inc_gds
@property
def inc_scalar(self) -> int:
return self.__inc_scalar
@property
def inc_msg(self) -> int:
return self.__inc_msg
@property
def inc_export(self) -> int:
return self.__inc_export
@property
def total_inc_vmcnt(self):
return self.inc_vector
@property
def total_inc_lgkmcnt(self):
return self.inc_lds + self.inc_gds + self.inc_scalar + self.inc_msg
@property
def total_inc_expcnt(self):
return self.inc_export
@property
def store_mem_from_gprs(self) -> List[Gpr]:
return self.__store_mem_from_gprs
@property
def load_mem_to_gprs(self) -> List[Gpr]:
return self.__load_mem_to_gprs
def __repr__(self):
parts = [] # type: List[str]
if not isinstance(self.token_object, MemTokenObject):
parts.append(f"token_object={self.token_object!r}")
if self.inc_vector > 0:
parts.append(f"vmcnt={self.inc_vector}")
if self.inc_lds > 0:
parts.append(f"lgkmcnt_lds={self.inc_lds}")
if self.inc_gds > 0:
parts.append(f"lgkmcnt_gds={self.inc_gds}")
if self.inc_msg > 0:
parts.append(f"lgkmcnt_msg={self.inc_msg}")
if self.inc_scalar > 0:
parts.append(f"lgkmcnt_smem={self.inc_scalar}")
if self.inc_export > 0:
parts.append(f"expcnt={self.inc_export}")
if self.store_mem_from_gprs:
parts.append(f"store_mem_from_gprs={self.store_mem_from_gprs}")
if self.load_mem_to_gprs:
parts.append(f"load_mem_to_gprs={self.load_mem_to_gprs}")
parts.append(f"srcloc={self.srcloc!r}")
return f"MemToken({', '.join(parts)})"
class Waitcnt:
__slots__ = ["__vmcnt", "__lgkmcnt", "__expcnt"]
def __init__(self, *,
vmcnt: Optional[int] = None,
lgkmcnt: Optional[int] = None,
expcnt: Optional[int] = None):
check(vmcnt is not None or lgkmcnt is not None or expcnt is not None,
"vmcnt, lgkmcnt and expcnt can't all be None")
if vmcnt is not None:
check(0 <= vmcnt < (1 << 6), f"Invalid vmcnt: {vmcnt}")
if lgkmcnt is not None:
check(0 <= lgkmcnt < (1 << 4), f"Invalid lgkmcnt: {lgkmcnt}")
if expcnt is not None:
check(0 <= expcnt < (1 << 3), f"Invalid expcnt: {expcnt}")
self.__vmcnt = vmcnt
self.__lgkmcnt = lgkmcnt
self.__expcnt = expcnt
@property
def vmcnt(self) -> Optional[int]:
return self.__vmcnt
@property
def lgkmcnt(self) -> Optional[int]:
return self.__lgkmcnt
@property
def expcnt(self) -> Optional[int]:
return self.__expcnt
def waitcnt_str(self) -> str:
parts = [] # type: List[str]
if self.vmcnt is not None:
parts.append(f"vmcnt({self.vmcnt})")
if self.lgkmcnt is not None:
parts.append(f"lgkmcnt({self.lgkmcnt})")
if self.expcnt is not None:
parts.append(f"expcnt({self.expcnt})")
return ' & '.join(parts)
def __repr__(self):
waits = [] # type: List[str]
if self.vmcnt is not None:
waits.append(f"vmcnt={self.vmcnt}")
if self.lgkmcnt is not None:
waits.append(f"lgkmcnt={self.lgkmcnt}")
if self.expcnt is not None:
waits.append(f"expcnt={self.expcnt}")
return f"Waitcnt({', '.join(waits)})"
def __hash__(self):
return hash((self.vmcnt, self.lgkmcnt, self.expcnt))
def __eq__(self, other):
if not isinstance(other, Waitcnt):
return NotImplemented
return self.vmcnt == other.vmcnt and self.lgkmcnt == other.lgkmcnt and self.expcnt == other.expcnt
@staticmethod
def from_int(num: int) -> Waitcnt:
"""
The bits of this operand have the following meaning:
High Bits | Low Bits | Description | Value Range
> --- |--- |--- |---
> 15:14 | 3:0 | VM_CNT: vector memory operations count. | 0..63
> - | 6:4 | EXP_CNT: export count. | 0..7
> - | 11:8 | LGKM_CNT: LDS, GDS, Constant and Message count. | 0..15
"""
assert 0 <= num <= 65535 # totally 16 bits
vmcnt = ((num >> 14) & 0b11) << 4 | (num & 0b1111)
expcnt = (num >> 4) & 0b111
lgkmcnt = (num >> 8) & 0b1111
return Waitcnt(vmcnt=vmcnt, lgkmcnt=lgkmcnt, expcnt=expcnt)
@staticmethod
def lcm(wc1: Optional[Waitcnt], wc2: Optional[Waitcnt], /) -> Waitcnt:
if wc1 is None:
return wc2
if wc2 is None:
return wc1
if wc1.vmcnt is None:
final_vmcnt = wc2.vmcnt
elif wc2.vmcnt is None:
final_vmcnt = wc1.vmcnt
else:
final_vmcnt = min(wc1.vmcnt, wc2.vmcnt)
if wc1.lgkmcnt is None:
final_lgkmcnt = wc2.lgkmcnt
elif wc2.lgkmcnt is None:
final_lgkmcnt = wc1.lgkmcnt
else:
final_lgkmcnt = min(wc1.lgkmcnt, wc2.lgkmcnt)
if wc1.expcnt is None:
final_expcnt = wc2.expcnt
elif wc2.expcnt is None:
final_expcnt = wc1.expcnt
else:
final_expcnt = min(wc1.expcnt, wc2.expcnt)
return Waitcnt(vmcnt=final_vmcnt, lgkmcnt=final_lgkmcnt, expcnt=final_expcnt)
class ControlFlowEnum(enum.Enum):
AlwaysJump = enum.auto() # Unconditional branching, s_branch
CondJump = enum.auto() # Conditional branching, s_cbranch_xxx
Terminate = enum.auto() # Terminate the program, s_endpgm
Fallthrough = enum.auto() # Fall through to next basic-block
class InstrCall:
def __init__(self,
category: str,
instr_name: str,
operands: Dict[str, Any],
modifiers: Dict[str, Any],
mem_token: Optional[MemToken],
comment: Optional[str],
srcloc: SrcLoc):
assert srcloc is not None
self.category = category
self.instr_name = instr_name
self.operands = operands
self.modifiers = modifiers
self.mem_token = mem_token
self.comment = comment
self.srcloc = srcloc
self.gpr_uses = {} # type: Dict[Gpr, Dict[int, int]] # base_gpr -> {offset -> count}
self.gpr_holds = {} # type: Dict[Gpr, Dict[int, int]] # base_gpr -> {offset -> count}
self.gpr_defs = {} # type: Dict[Gpr, Dict[int, int]] # base_gpr -> {offset -> count}
def add_to_current_block(self):
# A few sanity checks
# Gpr in `gpr_holds` must exists in both `gpr_uses` and `gpr_defs`
holds_gprset = self.gpr_holds_to_gprset()
assert self.gpr_uses_to_gprset().is_superset(holds_gprset)
assert self.gpr_defs_to_gprset().is_superset(holds_gprset)
# Add to current block
_Global.current_block().add_instr_call(self)
# noinspection PyMethodMayBeStatic
def __add_gpr_uses_holds_defs(self, gpr_dict: Dict[Gpr, Dict[int, int]], *values: Any):
for value in values:
if isinstance(value, Gpr):
if value.base_gpr not in gpr_dict:
gpr_dict[value.base_gpr] = dict() # type: Dict[int, int] # {offset -> count}
gpr_dict_for_value = gpr_dict[value.base_gpr]
# If `value.base_offset` does not exist yet, add {value.base_offset -> value.count}
# Otherwise, update the count if `value.count` is larger than the original count
if gpr_dict_for_value.setdefault(value.base_offset, value.count) < value.count:
gpr_dict_for_value[value.base_offset] = value.count
@staticmethod
def gpr_uses_holds_defs_to_gprset(gpr_dict: Dict[Gpr, Dict[int, int]]) -> GprSet:
result = GprSet()
for base_gpr, dict_offset_count in gpr_dict.items(): # base_gpr -> {offset -> count}
result.union_update(*[base_gpr[offset:offset+count-1]
for offset, count in dict_offset_count.items()])
return result
def gpr_uses_to_gprset(self) -> GprSet:
return InstrCall.gpr_uses_holds_defs_to_gprset(self.gpr_uses)
def gpr_holds_to_gprset(self) -> GprSet:
return InstrCall.gpr_uses_holds_defs_to_gprset(self.gpr_holds)
def gpr_defs_to_gprset(self) -> GprSet:
return InstrCall.gpr_uses_holds_defs_to_gprset(self.gpr_defs)
def add_gpr_uses(self, *values: Any):
self.__add_gpr_uses_holds_defs(self.gpr_uses, *values)
def add_gpr_holds(self, *values: Any):
self.__add_gpr_uses_holds_defs(self.gpr_holds, *values)
def add_gpr_defs(self, *values: Any):
self.__add_gpr_uses_holds_defs(self.gpr_defs, *values)
# Some special treats
more_defs = []
for value in (x for x in values if isinstance(x, Gpr)):
if value.base_gpr is vcc:
more_defs.append(vccz) # If we defined `vcc`, we defined `vccz` too
if value.base_gpr is exec:
more_defs.append(execz) # If we defined `exec`, we defined `execz` too
if value.base_gpr is m0:
more_defs.append(lds_direct) # If we defined `m0`, we defined `lds_direct` `load_to_lds` too
more_defs.append(load_to_lds)
self.__add_gpr_uses_holds_defs(self.gpr_defs, more_defs)
@property
def control_flow_enum(self) -> ControlFlowEnum:
if self.instr_name in {"s_cbranch_g_fork", "s_cbranch_i_fork", "s_cbranch_join"}:
raise SeekTODOException("fork & join not supported yet")
if self.instr_name == "s_branch":
return ControlFlowEnum.AlwaysJump
elif self.instr_name.startswith("s_cbranch"):
return ControlFlowEnum.CondJump
elif self.instr_name.startswith("s_endpgm"):
return ControlFlowEnum.Terminate
else:
return ControlFlowEnum.Fallthrough
def __repr__(self):
def __repr_gpr_uses_holds_defs(gpr_dict: Dict[Gpr, Dict[int, int]]) -> str:
gpr_parts = []
for base_gpr, gpr_dict_for_value in gpr_dict.items():
assert not base_gpr.is_view
assert len(gpr_dict_for_value) > 0
sub_parts = []
for base_offset, count in sorted(gpr_dict_for_value.items()):
assert count >= 1
if count == 1:
sub_parts.append(f"{base_offset}")
else:
sub_parts.append(f"{base_offset}:{base_offset+count-1}")
gpr_parts.append(f"{base_gpr}[{'|'.join(sub_parts)}]")
return "{" + ", ".join(gpr_parts) + "}"
# noinspection PyListCreation
parts = [] # type: List[str]
parts.append(f"{self.instr_name!r}")
if self.operands:
parts.append(f"operands={{{', '.join(f'/*{k}*/{v!r}' for k, v in self.operands.items())}}}")
if self.modifiers:
parts.append(f"modifiers={{{', '.join(f'{k}={v!r}' for k, v in self.modifiers.items())}}}")
if self.mem_token is not None:
parts.append(f"mem_token={self.mem_token}")
if self.comment is not None:
parts.append(f"comment={self.comment!r}") # string will be properly escaped by repr()
parts.append(f"gpr_uses={__repr_gpr_uses_holds_defs(self.gpr_uses)}")
parts.append(f"gpr_holds={__repr_gpr_uses_holds_defs(self.gpr_holds)}")
parts.append(f"gpr_defs={__repr_gpr_uses_holds_defs(self.gpr_defs)}")
parts.append(f"srcloc={self.srcloc!r}")
return f"InstrCall({', '.join(parts)})"
def generate(self, program, wr):
result = f"{self.instr_name} ".ljust(_INSTR_STR_WIDTH)
def operand_str(v: Any) -> str:
if isinstance(v, int):
return str(v)
if isinstance(v, float):
return str(v)
if isinstance(v, Gpr):
return f"{v.to_physical_str(program)}/*{v!r}*/"
if self.operands:
result += ', '.join(operand_str(v) for v in self.operands.values())
if any(True for v in self.modifiers.values() if v is not None):
result += " " + ' '.join(f'{k}={v!r}' for k, v in self.modifiers.items() if v is not None)
if self.comment is not None:
result += f" // {self.comment}"
# if self.mem_token is not None:
# parts.append(f"mem_token={self.mem_token}")
wr(result)
class ExplicitWaitCall:
__slots__ = ["__mem_token_or_token_objects", "gpr_uses", "gpr_holds", "gpr_defs"]
def __init__(self, *mem_token_or_token_objects: Union[MemToken, Any]):
assert mem_token_or_token_objects is not None
self.__mem_token_or_token_objects = mem_token_or_token_objects # type: Tuple[Union[MemToken, Any]]
self.gpr_uses = dict() # type: Dict[Gpr, Dict[int, int]] # base_gpr -> {offset -> count}
self.gpr_holds = dict() # type: Dict[Gpr, Dict[int, int]] # base_gpr -> {offset -> count}
self.gpr_defs = dict() # type: Dict[Gpr, Dict[int, int]] # base_gpr -> {offset -> count}
def __repr__(self):
return f"ExplicitWaitCall({{{', '.join(map(repr, self.__mem_token_or_token_objects))}}})"
@property
def mem_token_or_token_objects(self) -> Tuple[Union[MemToken, Any]]:
return self.__mem_token_or_token_objects
def gpr_uses_to_gprset(self) -> GprSet:
return InstrCall.gpr_uses_holds_defs_to_gprset(self.gpr_uses)
def gpr_holds_to_gprset(self) -> GprSet:
return InstrCall.gpr_uses_holds_defs_to_gprset(self.gpr_holds)
def gpr_defs_to_gprset(self) -> GprSet:
return InstrCall.gpr_uses_holds_defs_to_gprset(self.gpr_defs)
def generate(self, program, wr):
wr(f"/* explicit_wait: {', '.join(map(repr, self.__mem_token_or_token_objects))} */")
class ExplicitUsesCall:
__slots__ = ["__uses", "gpr_uses", "gpr_holds", "gpr_defs"]
def __init__(self, *uses: Gpr):
assert uses is not None
self.__uses = uses # type: Tuple[Gpr]
self.gpr_uses = dict() # type: Dict[Gpr, Dict[int, int]] # base_gpr -> {offset -> count}
self.gpr_holds = dict() # type: Dict[Gpr, Dict[int, int]] # base_gpr -> {offset -> count}
self.gpr_defs = dict() # type: Dict[Gpr, Dict[int, int]] # base_gpr -> {offset -> count}
for value in uses:
assert isinstance(value, Gpr)
if value.base_gpr not in self.gpr_uses:
self.gpr_uses[value.base_gpr] = dict() # type: Dict[int, int] # {offset -> count}
gpr_dict_for_value = self.gpr_uses[value.base_gpr]
# If `value.base_offset` does not exist yet, add {value.base_offset -> value.count}
# Otherwise, update the count if `value.count` is larger than the original count
if gpr_dict_for_value.setdefault(value.base_offset, value.count) < value.count:
gpr_dict_for_value[value.base_offset] = value.count
def __repr__(self):
return f"ExplicitUsesCall({{{', '.join(map(repr, self.__uses))}}})"
@property
def uses(self) -> Tuple[Gpr]:
return self.__uses
def gpr_uses_to_gprset(self) -> GprSet:
return InstrCall.gpr_uses_holds_defs_to_gprset(self.gpr_uses)
def gpr_holds_to_gprset(self) -> GprSet:
return InstrCall.gpr_uses_holds_defs_to_gprset(self.gpr_holds)
def gpr_defs_to_gprset(self) -> GprSet:
return InstrCall.gpr_uses_holds_defs_to_gprset(self.gpr_defs)
def generate(self, program, wr):
wr(f"/* explicit_uses: {list(self.uses)} */")
class Block:
def __init__(self, block_name: str):
self.block_name = block_name # type: str
self.clauses = [] # type: List[Union[InstrCall, ExplicitWaitCall, ExplicitUsesCall]]
def __enter__(self) -> Block:
# noinspection PyProtectedMember
_Global._current_block_stack.append(self)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# noinspection PyProtectedMember
block = _Global._current_block_stack.pop()
assert block is self
def __repr__(self):
return f"Block({self.block_name!r})"
def add_instr_call(self, call: InstrCall):
assert call is not None
assert isinstance(call, InstrCall)
self.clauses.append(call)
def add_explicit_wait(self, call: ExplicitWaitCall):
assert isinstance(call, ExplicitWaitCall)
self.clauses.append(call)
def add_explicit_uses(self, call: ExplicitUsesCall):
assert isinstance(call, ExplicitUsesCall)
self.clauses.append(call)
def explicit_wait(*mem_token_or_token_objects: Union[MemToken, Any]):
for mem_token_or_token_object in mem_token_or_token_objects:
check(mem_token_or_token_object is not None)
call = ExplicitWaitCall(*mem_token_or_token_objects)
_Global.current_block().add_explicit_wait(call)
def explicit_uses(*gprs: Gpr):
for gpr in gprs:
check(isinstance(gpr, Gpr))
call = ExplicitUsesCall(*gprs)
_Global.current_block().add_explicit_uses(call)
This diff is collapsed.
from __future__ import annotations
import os.path
from .exception import SeekException, SeekTODOException
from typing import List, Union, Tuple, Any, Optional, TextIO
import inspect
import textwrap
import ast
import astpretty
class _Global:
"""
A global namespace for Seek internal usage
"""
_current_block_stack = [] # type: List["Block"] # noqa: F821 (undefined name 'Block')
@staticmethod
def current_block() -> "Block": # noqa: F821 (undefined name 'Block')
if len(_Global._current_block_stack) == 0:
raise SeekException("No active current block entered")
return _Global._current_block_stack[-1]
def _inspect_stack_frame(prev_at: int):
frame = inspect.currentframe()
assert frame is not None, "Unimplemented inspect.currentframe()"
for _ in range(0, prev_at + 1):
frame = frame.f_back
if frame is None:
return None
return inspect.FrameInfo(frame, *inspect.getframeinfo(frame, 1))
class SrcLoc:
__slots__ = ["__filename", "__lineno"]
def __init__(self, filename: str, lineno: int):
self.__filename = os.path.basename(filename)
self.__lineno = lineno
@property
def filename(self) -> str:
return self.__filename
@property
def lineno(self) -> int:
return self.__lineno
def __repr__(self):
return f"SrcLoc(filename={self.filename!r}, lineno={self.lineno})"
def __str__(self):
return f"{self.filename}:{self.lineno}"
def __eq__(self, other):
if not isinstance(other, SrcLoc):
return NotImplemented
return (self.filename, self.lineno) == (other.filename, other.lineno)
def __hash__(self):
return hash((self.filename, self.lineno))
@staticmethod
def get_caller_srcloc(stack: int = 1) -> SrcLoc:
assert stack >= 0, f"Unknown stack depth: {stack}"
frame = _inspect_stack_frame(stack + 1) # plus myself
if frame is None:
raise SeekException(f"Can't get frame at stack depth {stack} from bottom")
return SrcLoc(frame.filename, frame.lineno)
def get_caller_assignments(stack: int = 1) -> Optional[Union[str, Tuple, List]]:
"""
Get the caller line; try to split it and get assignment targets
"""
assert stack >= 0, f"Unknown stack depth: {stack}"
frame = _inspect_stack_frame(stack + 1) # plus myself
if frame is None:
raise SeekException(f"Can't get frame at stack depth {stack} from bottom")
# Seems `code_context` contains just 1 line (at least for Python3)
# See: https://stackoverflow.com/questions/58720279
source = "".join(frame.code_context)
# De-indent this line, if necessary
source = textwrap.dedent(source)
try:
root_node = ast.parse(source, filename=frame.filename, mode="exec")
except SyntaxError:
return None
assert isinstance(root_node, ast.Module)
body = root_node.body
assert isinstance(body, list)
assert len(body) == 1
body0 = body[0]
def get_target(target: Any):
if isinstance(target, ast.Name):
return target.id
elif isinstance(target, ast.Attribute) and isinstance(target.value, ast.Name):
return f"{target.value.id}.{target.attr}"
elif isinstance(target, ast.Tuple):
return tuple(get_target(elt) for elt in target.elts)
elif isinstance(target, ast.List):
return list(get_target(elt) for elt in target.elts)
else:
astpretty.pprint(target)
raise SeekTODOException(target)
if isinstance(body0, ast.Assign): # (simple) assignment
assign_targets = body0.targets
assert len(assign_targets) > 0
if len(assign_targets) > 1:
# Case like: `a = b = get_caller_assignments(stack=0)`
raise SeekException("Assignment to multiple targets is not supported")
return get_target(assign_targets[0])
elif isinstance(body0, ast.AnnAssign): # annotated assignment
return get_target(body0.target)
else:
# This is not an assignment
return None
class IndentedWriter:
def __init__(self, fp: TextIO, indent: int = 0):
self._fp = fp
self._indent = indent
def __call__(self, s: str = ""):
s = textwrap.indent(s + "\n", ' ' * self._indent)
self._fp.write(s)
def indent(self, num: int = 4):
class Indenter:
def __init__(self, writer: IndentedWriter):
self.writer = writer
def __enter__(self):
self.writer._indent += num
def __exit__(self, exc_type, exc_val, exc_tb):
self.writer._indent -= num
return Indenter(self)
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment