Unverified Commit 37c76a90 authored by Jiacheng Huang's avatar Jiacheng Huang Committed by GitHub
Browse files

issue/500: 在 `infinicore` Python 包中接入 `ntops`

parent a3c5f3aa
...@@ -24,6 +24,7 @@ from infinicore.dtype import ( ...@@ -24,6 +24,7 @@ from infinicore.dtype import (
short, short,
uint8, uint8,
) )
from infinicore.ntops import use_ntops
from infinicore.ops.matmul import matmul from infinicore.ops.matmul import matmul
from infinicore.ops.rearrange import rearrange from infinicore.ops.rearrange import rearrange
from infinicore.tensor import ( from infinicore.tensor import (
...@@ -62,6 +63,8 @@ __all__ = [ ...@@ -62,6 +63,8 @@ __all__ = [
"long", "long",
"short", "short",
"uint8", "uint8",
# `ntops` integration.
"use_ntops",
# Operations. # Operations.
"matmul", "matmul",
"rearrange", "rearrange",
......
import sys
import infinicore
def use_ntops():
import ntops
return _TemporaryAttributes(
(("ntops.torch.torch", infinicore),)
+ tuple(
(f"infinicore.{op_name}", getattr(ntops.torch, op_name))
for op_name in ntops.torch.__all__
)
)
class _TemporaryAttributes:
def __init__(self, attribute_mappings):
self._attribute_mappings = attribute_mappings
self._original_values = {}
def __enter__(self):
for attr_path, new_value in self._attribute_mappings:
parent, attr_name = self._resolve_path(attr_path)
try:
self._original_values[attr_path] = getattr(parent, attr_name)
except AttributeError:
pass
setattr(parent, attr_name, new_value)
return self
def __exit__(self, exc_type, exc_value, traceback):
for attr_path, _ in self._attribute_mappings:
parent, attr_name = self._resolve_path(attr_path)
if attr_path in self._original_values:
setattr(parent, attr_name, self._original_values[attr_path])
else:
delattr(parent, attr_name)
@staticmethod
def _resolve_path(path):
*parent_parts, attr_name = path.split(".")
curr = sys.modules[parent_parts[0]]
for part in parent_parts[1:]:
curr = getattr(curr, part)
return curr, attr_name
...@@ -32,7 +32,7 @@ class Tensor: ...@@ -32,7 +32,7 @@ class Tensor:
return self._underlying.ndim return self._underlying.ndim
def data_ptr(self): def data_ptr(self):
return self._underlying.data_ptr return self._underlying.data_ptr()
def size(self, dim=None): def size(self, dim=None):
if dim is None: if dim is None:
......
...@@ -17,7 +17,7 @@ inline void bind(py::module &m) { ...@@ -17,7 +17,7 @@ inline void bind(py::module &m) {
.def_property_readonly("dtype", [](const Tensor &tensor) { return tensor->dtype(); }) .def_property_readonly("dtype", [](const Tensor &tensor) { return tensor->dtype(); })
.def_property_readonly("device", [](const Tensor &tensor) { return tensor->device(); }) .def_property_readonly("device", [](const Tensor &tensor) { return tensor->device(); })
.def("data_ptr", [](const Tensor &tensor) { return reinterpret_cast<uintptr_t>(tensor->data()); }) .def("data_ptr", [](const Tensor &tensor) { return reinterpret_cast<std::uintptr_t>(tensor->data()); })
.def("size", [](const Tensor &tensor, std::size_t dim) { return tensor->size(dim); }) .def("size", [](const Tensor &tensor, std::size_t dim) { return tensor->size(dim); })
.def("stride", [](const Tensor &tensor, std::size_t dim) { return tensor->stride(dim); }) .def("stride", [](const Tensor &tensor, std::size_t dim) { return tensor->stride(dim); })
.def("numel", [](const Tensor &tensor) { return tensor->numel(); }) .def("numel", [](const Tensor &tensor) { return tensor->numel(); })
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment