Unverified Commit 93e7d887 authored by Jiacheng Huang's avatar Jiacheng Huang Committed by GitHub
Browse files

issue/492: 修复 infinicore.Tensor.dtype 和 infinicore.Tensor.device 返回类型的问题

parent 9a05446f
......@@ -20,7 +20,7 @@ public:
MOORE = INFINI_DEVICE_MOORE,
ILUVATAR = INFINI_DEVICE_ILUVATAR,
KUNLUN = INFINI_DEVICE_KUNLUN,
SUGON = INFINI_DEVICE_SUGON,
HYGON = INFINI_DEVICE_HYGON,
COUNT = INFINI_DEVICE_TYPE_COUNT,
};
......
......@@ -9,6 +9,7 @@ from infinicore.dtype import (
complex64,
complex128,
double,
dtype,
float,
float16,
float32,
......@@ -37,6 +38,7 @@ from infinicore.tensor import (
__all__ = [
# Classes.
"device",
"dtype",
# Data Types.
"bfloat16",
"bool",
......
......@@ -66,6 +66,23 @@ class device:
index -= 1
@staticmethod
def _from_infinicore_device(infinicore_device):
type = _TORCH_DEVICE_MAP[infinicore_device.type]
base_index = 0
for infinicore_type, torch_type in _TORCH_DEVICE_MAP.items():
if torch_type != type:
continue
if infinicore_type == infinicore_device.type:
break
base_index += _infinicore.get_device_count(infinicore_device)
return device(type, base_index + infinicore_device.index)
_TORCH_DEVICE_MAP = {
_infinicore.Device.Type.CPU: "cpu",
......@@ -76,5 +93,5 @@ _TORCH_DEVICE_MAP = {
_infinicore.Device.Type.MOORE: "musa",
_infinicore.Device.Type.ILUVATAR: "cuda",
_infinicore.Device.Type.KUNLUN: "cuda",
_infinicore.Device.Type.SUGON: "cuda",
_infinicore.Device.Type.HYGON: "cuda",
}
import infinicore.device
import infinicore.dtype
from . import _infinicore
class Tensor:
def __init__(self, tensor):
def __init__(self, underlying):
"""An internal method. Please do not use this directly."""
self._underlying = tensor
self._underlying = underlying
self._dtype = infinicore.dtype(self._underlying.dtype)
self._device = infinicore.device._from_infinicore_device(
self._underlying.device
)
@property
def shape(self):
......@@ -13,11 +22,11 @@ class Tensor:
@property
def dtype(self):
return self._underlying.dtype
return self._dtype
@property
def device(self):
return self._underlying.device
return self._device
@property
def ndim(self):
......
......@@ -37,8 +37,8 @@ std::string Device::toString(const Type &type) {
return "ILUVATAR";
case Type::KUNLUN:
return "KUNLUN";
case Type::SUGON:
return "SUGON";
case Type::HYGON:
return "HYGON";
}
// TODO: Add error handling.
......
......@@ -20,7 +20,7 @@ inline void bind(py::module &m) {
.value("MOORE", Device::Type::MOORE)
.value("ILUVATAR", Device::Type::ILUVATAR)
.value("KUNLUN", Device::Type::KUNLUN)
.value("SUGON", Device::Type::SUGON)
.value("HYGON", Device::Type::HYGON)
.value("COUNT", Device::Type::COUNT);
device
......
......@@ -15,6 +15,7 @@ inline void bind(py::module &m) {
.def_property_readonly("strides", [](const Tensor &tensor) { return tensor->strides(); })
.def_property_readonly("ndim", [](const Tensor &tensor) { return tensor->ndim(); })
.def_property_readonly("dtype", [](const Tensor &tensor) { return tensor->dtype(); })
.def_property_readonly("device", [](const Tensor &tensor) { return tensor->device(); })
.def("data_ptr", [](const Tensor &tensor) { return tensor->data(); })
.def("size", [](const Tensor &tensor, std::size_t dim) { return tensor->size(dim); })
......
......@@ -7,7 +7,7 @@ class InfiniDeviceEnum:
MOORE = 5
ILUVATAR = 6
KUNLUN = 7
SUGON = 8
HYGON = 8
InfiniDeviceNames = {
......@@ -19,7 +19,7 @@ InfiniDeviceNames = {
InfiniDeviceEnum.MOORE: "Moore",
InfiniDeviceEnum.ILUVATAR: "Iluvatar",
InfiniDeviceEnum.KUNLUN: "Kunlun",
InfiniDeviceEnum.SUGON: "Sugon",
InfiniDeviceEnum.HYGON: "Hygon",
}
# Mapping that maps InfiniDeviceEnum to torch device string
......@@ -32,5 +32,5 @@ torch_device_map = {
InfiniDeviceEnum.MOORE: "musa",
InfiniDeviceEnum.ILUVATAR: "cuda",
InfiniDeviceEnum.KUNLUN: "cuda",
InfiniDeviceEnum.SUGON: "cuda",
InfiniDeviceEnum.HYGON: "cuda",
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment