Commit 26f8ae57 authored by PanZezhong's avatar PanZezhong
Browse files

issue/60: to_tensor存储原torch张量,增加INFINI_ROOT默认路径

parent 89e49e31
......@@ -5,11 +5,12 @@ import ctypes
from ctypes import c_int, c_int64, c_uint64, Structure, POINTER, c_size_t
from .datatypes import *
from .devices import *
from pathlib import Path
Device = c_int
Optype = c_int
INFINI_ROOT = os.environ.get("INFINI_ROOT")
INFINI_ROOT = os.getenv("INFINI_ROOT") or str(Path.home() / ".infini")
class TensorDescriptor(Structure):
......@@ -30,9 +31,10 @@ infiniopTensorDescriptor_t = ctypes.POINTER(TensorDescriptor)
class CTensor:
def __init__(self, desc, data):
def __init__(self, desc, torch_tensor):
self.descriptor = desc
self.data = data
self.torch_tensor_ = torch_tensor
self.data = torch_tensor.data_ptr()
class Handle(Structure):
......
......@@ -19,7 +19,6 @@ def to_tensor(tensor, lib):
ndim = tensor.ndimension()
shape = (ctypes.c_size_t * ndim)(*tensor.shape)
strides = (ctypes.c_int64 * ndim)(*(tensor.stride()))
data_ptr = tensor.data_ptr()
# fmt: off
dt = (
InfiniDtype.I8 if tensor.dtype == torch.int8 else
......@@ -46,7 +45,7 @@ def to_tensor(tensor, lib):
ctypes.byref(tensor_desc), ndim, shape, strides, dt
)
# Create Tensor
return CTensor(tensor_desc, data_ptr)
return CTensor(tensor_desc, tensor)
def create_workspace(size, torch_device):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment