Unverified Commit 48b3afd7 authored by Jiacheng Huang's avatar Jiacheng Huang Committed by GitHub
Browse files

issue/561: 将调用 ntops 算子设为默认

* 以变量形式使用 `use_ntops`

* 移除不需要的 `ntops.py`
parent c76c0645
import contextlib
from infinicore.device import device from infinicore.device import device
from infinicore.dtype import ( from infinicore.dtype import (
bfloat16, bfloat16,
...@@ -24,7 +26,6 @@ from infinicore.dtype import ( ...@@ -24,7 +26,6 @@ from infinicore.dtype import (
short, short,
uint8, uint8,
) )
from infinicore.ntops import use_ntops
from infinicore.ops.add import add from infinicore.ops.add import add
from infinicore.ops.attention import attention from infinicore.ops.attention import attention
from infinicore.ops.matmul import matmul from infinicore.ops.matmul import matmul
...@@ -68,8 +69,6 @@ __all__ = [ ...@@ -68,8 +69,6 @@ __all__ = [
"long", "long",
"short", "short",
"uint8", "uint8",
# `ntops` integration.
"use_ntops",
# Operations. # Operations.
"add", "add",
"attention", "attention",
...@@ -82,3 +81,15 @@ __all__ = [ ...@@ -82,3 +81,15 @@ __all__ = [
"strided_from_blob", "strided_from_blob",
"zeros", "zeros",
] ]
use_ntops = False
with contextlib.suppress(ImportError, ModuleNotFoundError):
import sys
import ntops
for op_name in ntops.torch.__all__:
getattr(ntops.torch, op_name).__globals__["torch"] = sys.modules[__name__]
use_ntops = True
import sys
import infinicore
def use_ntops():
import ntops
return _TemporaryAttributes(
tuple(
(f"infinicore.{op_name}", getattr(ntops.torch, op_name))
for op_name in ntops.torch.__all__
)
+ tuple(
(f"ntops.torch.{op_name}.__globals__['torch']", infinicore)
for op_name in ntops.torch.__all__
)
)
class _TemporaryAttributes:
def __init__(self, attribute_mappings):
self._attribute_mappings = attribute_mappings
self._original_values = {}
def __enter__(self):
for attr_path, new_value in self._attribute_mappings:
parent, attr_name, is_dict_key = self._resolve_path(attr_path)
try:
if is_dict_key:
self._original_values[attr_path] = parent.__globals__[attr_name]
else:
self._original_values[attr_path] = getattr(parent, attr_name)
except (AttributeError, KeyError):
pass
if is_dict_key:
parent.__globals__[attr_name] = new_value
else:
setattr(parent, attr_name, new_value)
return self
def __exit__(self, exc_type, exc_value, traceback):
for attr_path, _ in self._attribute_mappings:
parent, attr_name, is_dict_key = self._resolve_path(attr_path)
if attr_path in self._original_values:
original_value = self._original_values[attr_path]
if is_dict_key:
parent.__globals__[attr_name] = original_value
else:
setattr(parent, attr_name, original_value)
else:
if is_dict_key:
if attr_name in parent.__globals__.keys():
del parent.__globals__[attr_name]
else:
if parent is not None and attr_name is not None:
delattr(parent, attr_name)
@staticmethod
def _resolve_path(path):
is_dict_key = False
dict_key_match = None
if path.endswith("']"):
try:
start_index = path.rindex("['")
end_index = path.rindex("']")
if start_index > 0 and end_index == len(path) - 2:
is_dict_key = True
dict_key_match = path[start_index + 2 : end_index]
path = path[:start_index]
except ValueError:
pass
*parent_parts, attr_name = path.split(".")
curr = sys.modules[parent_parts[0]]
for part in parent_parts[1:]:
curr = getattr(curr, part)
parent = curr
if is_dict_key:
return parent, dict_key_match, True
else:
return parent, attr_name, False
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment