Commit f221ff39 authored by You Jiacheng's avatar You Jiacheng Committed by LeiWang1999
Browse files

[Language] make linter and type checker happy with mocking (#407)



* [Language] make linter and type checker happy with mocking

* Apply suggestions from code review
Co-authored-by: default avatarCopilot <175728472+Copilot@users.noreply.github.com>

* Refactor BaseTensor class in proxy.py to implement __getitem__ and __setitem__ methods, enhancing type checking and linting compliance. Added method stubs for from_ptr and other subclasses for improved clarity and maintainability.

* Refactor type imports in proxy.py to enhance clarity and maintainability by replacing the built-in Self type with the Self type from typing_extensions.

---------
Co-authored-by: default avatarCopilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: default avatarLeiWang1999 <leiwang1999@outlook.com>
parent 5c7e2fa8
"""The language interface for tl programs."""
from __future__ import annotations
from typing import Optional
from typing import Any, Optional, Sequence, SupportsIndex, TYPE_CHECKING
from typing_extensions import Self
from tvm import tir
from tvm.tir import Var, PrimExpr
......@@ -171,10 +172,57 @@ Buffer = BufferProxy() # pylint: disable=invalid-name
# Tensor is an alias for Buffer
# Because when user do jit compile, the input and output will
# be mapped with torch.Tensor.
Tensor = TensorProxy() # pylint: disable=invalid-name
FragmentBuffer = FragmentBufferProxy() # pylint: disable=invalid-name
SharedBuffer = SharedBufferProxy() # pylint: disable=invalid-name
LocalBuffer = LocalBufferProxy() # pylint: disable=invalid-name
if TYPE_CHECKING:
class BaseTensor:
def __class_getitem__(cls, key):
return cls
def __getitem__(self, key) -> Any:
...
def __setitem__(self, key, value) -> None:
...
def __init__(
self,
shape: Sequence[SupportsIndex],
dtype="float32",
data=None,
strides=None,
elem_offset=None,
scope=None, # Changed to None to use class default
align=None,
offset_factor=None,
buffer_type="",
axis_separators=None,
):
...
@classmethod
def from_ptr(cls,
pointer_var: Var,
shape: Sequence[PrimExpr, ...],
dtype: str = "float32") -> Self:
...
class Tensor(BaseTensor):
...
class FragmentBuffer(BaseTensor):
...
class SharedBuffer(BaseTensor):
...
class LocalBuffer(BaseTensor):
...
else:
Tensor = TensorProxy() # pylint: disable=invalid-name
FragmentBuffer = FragmentBufferProxy() # pylint: disable=invalid-name
SharedBuffer = SharedBufferProxy() # pylint: disable=invalid-name
LocalBuffer = LocalBufferProxy() # pylint: disable=invalid-name
def ptr(dtype: Optional[str] = None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment