Unverified Commit 1a618ff0 authored by pengcheng888's avatar pengcheng888 Committed by GitHub
Browse files

Merge pull request #590 from pengcheng888/main

issue/588 - 为Tensor添加from_torch函数, +*运算符重载
parents 576b7552 e4b7e5ee
......@@ -37,6 +37,7 @@ from infinicore.tensor import (
empty,
empty_like,
from_blob,
from_torch,
ones,
strided_empty,
strided_from_blob,
......@@ -82,6 +83,7 @@ __all__ = [
"empty",
"empty_like",
"from_blob",
"from_torch",
"ones",
"strided_empty",
"strided_from_blob",
......
......@@ -2,9 +2,11 @@ import infinicore.device
import infinicore.dtype
from infinicore.lib import _infinicore
from .utils import to_infinicore_dtype
class Tensor:
def __init__(self, underlying):
def __init__(self, underlying, *, _torch_ref=None):
"""An internal method. Please do not use this directly."""
self._underlying = underlying
......@@ -15,6 +17,8 @@ class Tensor:
self._underlying.device
)
self._torch_ref = _torch_ref
@property
def shape(self):
return self._underlying.shape
......@@ -86,6 +90,12 @@ class Tensor:
else:
self._underlying.debug(filename)
def __add__(self, other):
return infinicore.add(self, other)
def __mul__(self, other):
return infinicore.mul(self, other)
def empty(size, *, dtype=None, device=None, pin_memory=False):
return Tensor(
......@@ -135,3 +145,17 @@ def strided_from_blob(data_ptr, size, strides, *, dtype=None, device=None):
data_ptr, size, strides, dtype._underlying, device._underlying
)
)
def from_torch(torch_tensor) -> Tensor:
infini_type = to_infinicore_dtype(torch_tensor.dtype)
infini_device = infinicore.device(torch_tensor.device.type, 0)
return Tensor(
_infinicore.from_blob(
torch_tensor.data_ptr(),
list(torch_tensor.shape),
dtype=infini_type._underlying,
device=infini_device._underlying,
),
torch_ref=torch_tensor,
)
import torch
import infinicore
def to_torch_dtype(infini_dtype):
"""Convert infinicore data type to PyTorch data type"""
if infini_dtype == infinicore.float16:
return torch.float16
elif infini_dtype == infinicore.float32:
return torch.float32
elif infini_dtype == infinicore.bfloat16:
return torch.bfloat16
elif infini_dtype == infinicore.int8:
return torch.int8
elif infini_dtype == infinicore.int16:
return torch.int16
elif infini_dtype == infinicore.int32:
return torch.int32
elif infini_dtype == infinicore.int64:
return torch.int64
elif infini_dtype == infinicore.uint8:
return torch.uint8
else:
raise ValueError(f"Unsupported infinicore dtype: {infini_dtype}")
def to_infinicore_dtype(torch_dtype):
"""Convert PyTorch data type to infinicore data type"""
if torch_dtype == torch.float32:
return infinicore.float32
elif torch_dtype == torch.float16:
return infinicore.float16
elif torch_dtype == torch.bfloat16:
return infinicore.bfloat16
elif torch_dtype == torch.int8:
return infinicore.int8
elif torch_dtype == torch.int16:
return infinicore.int16
elif torch_dtype == torch.int32:
return infinicore.int32
elif torch_dtype == torch.int64:
return infinicore.int64
elif torch_dtype == torch.uint8:
return infinicore.uint8
else:
raise ValueError(f"Unsupported torch dtype: {torch_dtype}")
import infinicore
import torch
import infinicore
def test():
shape = [2, 3, 4]
......@@ -40,5 +41,45 @@ def test():
print("Test passed")
def test2():
"测试infinicore.Tensor的from_torch, +* 运算符功能"
shape = [1, 2, 3]
x1_torch = torch.rand(shape, dtype=torch.float32, device="cpu")
x2_torch = torch.rand(shape, dtype=torch.float32, device="cpu")
x1_infini = infinicore.from_torch(x1_torch.clone())
x2_infini = infinicore.from_torch(x2_torch.clone())
ans1_infini = x1_infini + x2_infini
ans2_infini = x1_infini * x2_infini
ans1_torch_ref = x1_torch + x2_torch
ans2_torch_ref = x1_torch * x2_torch
print("----------------------------------------")
torch_ans1_result = torch.zeros(shape, dtype=torch.float32, device="cpu")
torch_ans2_result = torch.zeros(shape, dtype=torch.float32, device="cpu")
torch_ans1 = infinicore.from_blob(
torch_ans1_result.data_ptr(),
shape,
dtype=infinicore.float32,
device=infinicore.device("cpu", 0),
)
torch_ans2 = infinicore.from_blob(
torch_ans2_result.data_ptr(),
shape,
dtype=infinicore.float32,
device=infinicore.device("cpu", 0),
)
torch_ans1.copy_(ans1_infini)
torch_ans2.copy_(ans2_infini)
print("----------------------------------------")
print("abs error: ", torch.abs(ans1_torch_ref - torch_ans1_result).max())
print("abs error: ", torch.abs(ans2_torch_ref - torch_ans2_result).max())
if __name__ == "__main__":
test()
# test()
test2()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment