Unverified Commit 37c76a90 authored by Jiacheng Huang's avatar Jiacheng Huang Committed by GitHub
Browse files

issue/500: 在 `infinicore` Python 包中接入 `ntops`

parent a3c5f3aa
......@@ -24,6 +24,7 @@ from infinicore.dtype import (
short,
uint8,
)
from infinicore.ntops import use_ntops
from infinicore.ops.matmul import matmul
from infinicore.ops.rearrange import rearrange
from infinicore.tensor import (
......@@ -62,6 +63,8 @@ __all__ = [
"long",
"short",
"uint8",
# `ntops` integration.
"use_ntops",
# Operations.
"matmul",
"rearrange",
......
import sys
import infinicore
def use_ntops():
import ntops
return _TemporaryAttributes(
(("ntops.torch.torch", infinicore),)
+ tuple(
(f"infinicore.{op_name}", getattr(ntops.torch, op_name))
for op_name in ntops.torch.__all__
)
)
class _TemporaryAttributes:
def __init__(self, attribute_mappings):
self._attribute_mappings = attribute_mappings
self._original_values = {}
def __enter__(self):
for attr_path, new_value in self._attribute_mappings:
parent, attr_name = self._resolve_path(attr_path)
try:
self._original_values[attr_path] = getattr(parent, attr_name)
except AttributeError:
pass
setattr(parent, attr_name, new_value)
return self
def __exit__(self, exc_type, exc_value, traceback):
for attr_path, _ in self._attribute_mappings:
parent, attr_name = self._resolve_path(attr_path)
if attr_path in self._original_values:
setattr(parent, attr_name, self._original_values[attr_path])
else:
delattr(parent, attr_name)
@staticmethod
def _resolve_path(path):
*parent_parts, attr_name = path.split(".")
curr = sys.modules[parent_parts[0]]
for part in parent_parts[1:]:
curr = getattr(curr, part)
return curr, attr_name
......@@ -32,7 +32,7 @@ class Tensor:
return self._underlying.ndim
def data_ptr(self):
return self._underlying.data_ptr
return self._underlying.data_ptr()
def size(self, dim=None):
if dim is None:
......
......@@ -17,7 +17,7 @@ inline void bind(py::module &m) {
.def_property_readonly("dtype", [](const Tensor &tensor) { return tensor->dtype(); })
.def_property_readonly("device", [](const Tensor &tensor) { return tensor->device(); })
.def("data_ptr", [](const Tensor &tensor) { return reinterpret_cast<uintptr_t>(tensor->data()); })
.def("data_ptr", [](const Tensor &tensor) { return reinterpret_cast<std::uintptr_t>(tensor->data()); })
.def("size", [](const Tensor &tensor, std::size_t dim) { return tensor->size(dim); })
.def("stride", [](const Tensor &tensor, std::size_t dim) { return tensor->stride(dim); })
.def("numel", [](const Tensor &tensor) { return tensor->numel(); })
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment