Unverified Commit 1a618ff0 authored by pengcheng888's avatar pengcheng888 Committed by GitHub
Browse files

Merge pull request #590 from pengcheng888/main

issue/588 - 为Tensor添加from_torch函数, +*运算符重载
parents 576b7552 e4b7e5ee
...@@ -37,6 +37,7 @@ from infinicore.tensor import ( ...@@ -37,6 +37,7 @@ from infinicore.tensor import (
empty, empty,
empty_like, empty_like,
from_blob, from_blob,
from_torch,
ones, ones,
strided_empty, strided_empty,
strided_from_blob, strided_from_blob,
...@@ -82,6 +83,7 @@ __all__ = [ ...@@ -82,6 +83,7 @@ __all__ = [
"empty", "empty",
"empty_like", "empty_like",
"from_blob", "from_blob",
"from_torch",
"ones", "ones",
"strided_empty", "strided_empty",
"strided_from_blob", "strided_from_blob",
......
...@@ -2,9 +2,11 @@ import infinicore.device ...@@ -2,9 +2,11 @@ import infinicore.device
import infinicore.dtype import infinicore.dtype
from infinicore.lib import _infinicore from infinicore.lib import _infinicore
from .utils import to_infinicore_dtype
class Tensor: class Tensor:
def __init__(self, underlying): def __init__(self, underlying, *, _torch_ref=None):
"""An internal method. Please do not use this directly.""" """An internal method. Please do not use this directly."""
self._underlying = underlying self._underlying = underlying
...@@ -15,6 +17,8 @@ class Tensor: ...@@ -15,6 +17,8 @@ class Tensor:
self._underlying.device self._underlying.device
) )
self._torch_ref = _torch_ref
@property @property
def shape(self): def shape(self):
return self._underlying.shape return self._underlying.shape
...@@ -86,6 +90,12 @@ class Tensor: ...@@ -86,6 +90,12 @@ class Tensor:
else: else:
self._underlying.debug(filename) self._underlying.debug(filename)
def __add__(self, other):
return infinicore.add(self, other)
def __mul__(self, other):
return infinicore.mul(self, other)
def empty(size, *, dtype=None, device=None, pin_memory=False): def empty(size, *, dtype=None, device=None, pin_memory=False):
return Tensor( return Tensor(
...@@ -135,3 +145,17 @@ def strided_from_blob(data_ptr, size, strides, *, dtype=None, device=None): ...@@ -135,3 +145,17 @@ def strided_from_blob(data_ptr, size, strides, *, dtype=None, device=None):
data_ptr, size, strides, dtype._underlying, device._underlying data_ptr, size, strides, dtype._underlying, device._underlying
) )
) )
def from_torch(torch_tensor) -> Tensor:
infini_type = to_infinicore_dtype(torch_tensor.dtype)
infini_device = infinicore.device(torch_tensor.device.type, 0)
return Tensor(
_infinicore.from_blob(
torch_tensor.data_ptr(),
list(torch_tensor.shape),
dtype=infini_type._underlying,
device=infini_device._underlying,
),
torch_ref=torch_tensor,
)
import torch
import infinicore
def to_torch_dtype(infini_dtype):
"""Convert infinicore data type to PyTorch data type"""
if infini_dtype == infinicore.float16:
return torch.float16
elif infini_dtype == infinicore.float32:
return torch.float32
elif infini_dtype == infinicore.bfloat16:
return torch.bfloat16
elif infini_dtype == infinicore.int8:
return torch.int8
elif infini_dtype == infinicore.int16:
return torch.int16
elif infini_dtype == infinicore.int32:
return torch.int32
elif infini_dtype == infinicore.int64:
return torch.int64
elif infini_dtype == infinicore.uint8:
return torch.uint8
else:
raise ValueError(f"Unsupported infinicore dtype: {infini_dtype}")
def to_infinicore_dtype(torch_dtype):
"""Convert PyTorch data type to infinicore data type"""
if torch_dtype == torch.float32:
return infinicore.float32
elif torch_dtype == torch.float16:
return infinicore.float16
elif torch_dtype == torch.bfloat16:
return infinicore.bfloat16
elif torch_dtype == torch.int8:
return infinicore.int8
elif torch_dtype == torch.int16:
return infinicore.int16
elif torch_dtype == torch.int32:
return infinicore.int32
elif torch_dtype == torch.int64:
return infinicore.int64
elif torch_dtype == torch.uint8:
return infinicore.uint8
else:
raise ValueError(f"Unsupported torch dtype: {torch_dtype}")
import infinicore
import torch import torch
import infinicore
def test(): def test():
shape = [2, 3, 4] shape = [2, 3, 4]
...@@ -40,5 +41,45 @@ def test(): ...@@ -40,5 +41,45 @@ def test():
print("Test passed") print("Test passed")
def test2():
"测试infinicore.Tensor的from_torch, +* 运算符功能"
shape = [1, 2, 3]
x1_torch = torch.rand(shape, dtype=torch.float32, device="cpu")
x2_torch = torch.rand(shape, dtype=torch.float32, device="cpu")
x1_infini = infinicore.from_torch(x1_torch.clone())
x2_infini = infinicore.from_torch(x2_torch.clone())
ans1_infini = x1_infini + x2_infini
ans2_infini = x1_infini * x2_infini
ans1_torch_ref = x1_torch + x2_torch
ans2_torch_ref = x1_torch * x2_torch
print("----------------------------------------")
torch_ans1_result = torch.zeros(shape, dtype=torch.float32, device="cpu")
torch_ans2_result = torch.zeros(shape, dtype=torch.float32, device="cpu")
torch_ans1 = infinicore.from_blob(
torch_ans1_result.data_ptr(),
shape,
dtype=infinicore.float32,
device=infinicore.device("cpu", 0),
)
torch_ans2 = infinicore.from_blob(
torch_ans2_result.data_ptr(),
shape,
dtype=infinicore.float32,
device=infinicore.device("cpu", 0),
)
torch_ans1.copy_(ans1_infini)
torch_ans2.copy_(ans2_infini)
print("----------------------------------------")
print("abs error: ", torch.abs(ans1_torch_ref - torch_ans1_result).max())
print("abs error: ", torch.abs(ans2_torch_ref - torch_ans2_result).max())
if __name__ == "__main__": if __name__ == "__main__":
test() # test()
test2()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment