Unverified Commit 93e7d887 authored by Jiacheng Huang's avatar Jiacheng Huang Committed by GitHub
Browse files

issue/492: 修复 infinicore.Tensor.dtype 和 infinicore.Tensor.device 返回类型的问题

parent 9a05446f
...@@ -20,7 +20,7 @@ public: ...@@ -20,7 +20,7 @@ public:
MOORE = INFINI_DEVICE_MOORE, MOORE = INFINI_DEVICE_MOORE,
ILUVATAR = INFINI_DEVICE_ILUVATAR, ILUVATAR = INFINI_DEVICE_ILUVATAR,
KUNLUN = INFINI_DEVICE_KUNLUN, KUNLUN = INFINI_DEVICE_KUNLUN,
SUGON = INFINI_DEVICE_SUGON, HYGON = INFINI_DEVICE_HYGON,
COUNT = INFINI_DEVICE_TYPE_COUNT, COUNT = INFINI_DEVICE_TYPE_COUNT,
}; };
......
...@@ -9,6 +9,7 @@ from infinicore.dtype import ( ...@@ -9,6 +9,7 @@ from infinicore.dtype import (
complex64, complex64,
complex128, complex128,
double, double,
dtype,
float, float,
float16, float16,
float32, float32,
...@@ -37,6 +38,7 @@ from infinicore.tensor import ( ...@@ -37,6 +38,7 @@ from infinicore.tensor import (
__all__ = [ __all__ = [
# Classes. # Classes.
"device", "device",
"dtype",
# Data Types. # Data Types.
"bfloat16", "bfloat16",
"bool", "bool",
......
...@@ -66,6 +66,23 @@ class device: ...@@ -66,6 +66,23 @@ class device:
index -= 1 index -= 1
@staticmethod
def _from_infinicore_device(infinicore_device):
type = _TORCH_DEVICE_MAP[infinicore_device.type]
base_index = 0
for infinicore_type, torch_type in _TORCH_DEVICE_MAP.items():
if torch_type != type:
continue
if infinicore_type == infinicore_device.type:
break
base_index += _infinicore.get_device_count(infinicore_device)
return device(type, base_index + infinicore_device.index)
_TORCH_DEVICE_MAP = { _TORCH_DEVICE_MAP = {
_infinicore.Device.Type.CPU: "cpu", _infinicore.Device.Type.CPU: "cpu",
...@@ -76,5 +93,5 @@ _TORCH_DEVICE_MAP = { ...@@ -76,5 +93,5 @@ _TORCH_DEVICE_MAP = {
_infinicore.Device.Type.MOORE: "musa", _infinicore.Device.Type.MOORE: "musa",
_infinicore.Device.Type.ILUVATAR: "cuda", _infinicore.Device.Type.ILUVATAR: "cuda",
_infinicore.Device.Type.KUNLUN: "cuda", _infinicore.Device.Type.KUNLUN: "cuda",
_infinicore.Device.Type.SUGON: "cuda", _infinicore.Device.Type.HYGON: "cuda",
} }
import infinicore.device
import infinicore.dtype
from . import _infinicore from . import _infinicore
class Tensor: class Tensor:
def __init__(self, tensor): def __init__(self, underlying):
"""An internal method. Please do not use this directly.""" """An internal method. Please do not use this directly."""
self._underlying = tensor self._underlying = underlying
self._dtype = infinicore.dtype(self._underlying.dtype)
self._device = infinicore.device._from_infinicore_device(
self._underlying.device
)
@property @property
def shape(self): def shape(self):
...@@ -13,11 +22,11 @@ class Tensor: ...@@ -13,11 +22,11 @@ class Tensor:
@property @property
def dtype(self): def dtype(self):
return self._underlying.dtype return self._dtype
@property @property
def device(self): def device(self):
return self._underlying.device return self._device
@property @property
def ndim(self): def ndim(self):
......
...@@ -37,8 +37,8 @@ std::string Device::toString(const Type &type) { ...@@ -37,8 +37,8 @@ std::string Device::toString(const Type &type) {
return "ILUVATAR"; return "ILUVATAR";
case Type::KUNLUN: case Type::KUNLUN:
return "KUNLUN"; return "KUNLUN";
case Type::SUGON: case Type::HYGON:
return "SUGON"; return "HYGON";
} }
// TODO: Add error handling. // TODO: Add error handling.
......
...@@ -20,7 +20,7 @@ inline void bind(py::module &m) { ...@@ -20,7 +20,7 @@ inline void bind(py::module &m) {
.value("MOORE", Device::Type::MOORE) .value("MOORE", Device::Type::MOORE)
.value("ILUVATAR", Device::Type::ILUVATAR) .value("ILUVATAR", Device::Type::ILUVATAR)
.value("KUNLUN", Device::Type::KUNLUN) .value("KUNLUN", Device::Type::KUNLUN)
.value("SUGON", Device::Type::SUGON) .value("HYGON", Device::Type::HYGON)
.value("COUNT", Device::Type::COUNT); .value("COUNT", Device::Type::COUNT);
device device
......
...@@ -15,6 +15,7 @@ inline void bind(py::module &m) { ...@@ -15,6 +15,7 @@ inline void bind(py::module &m) {
.def_property_readonly("strides", [](const Tensor &tensor) { return tensor->strides(); }) .def_property_readonly("strides", [](const Tensor &tensor) { return tensor->strides(); })
.def_property_readonly("ndim", [](const Tensor &tensor) { return tensor->ndim(); }) .def_property_readonly("ndim", [](const Tensor &tensor) { return tensor->ndim(); })
.def_property_readonly("dtype", [](const Tensor &tensor) { return tensor->dtype(); }) .def_property_readonly("dtype", [](const Tensor &tensor) { return tensor->dtype(); })
.def_property_readonly("device", [](const Tensor &tensor) { return tensor->device(); })
.def("data_ptr", [](const Tensor &tensor) { return tensor->data(); }) .def("data_ptr", [](const Tensor &tensor) { return tensor->data(); })
.def("size", [](const Tensor &tensor, std::size_t dim) { return tensor->size(dim); }) .def("size", [](const Tensor &tensor, std::size_t dim) { return tensor->size(dim); })
......
...@@ -7,7 +7,7 @@ class InfiniDeviceEnum: ...@@ -7,7 +7,7 @@ class InfiniDeviceEnum:
MOORE = 5 MOORE = 5
ILUVATAR = 6 ILUVATAR = 6
KUNLUN = 7 KUNLUN = 7
SUGON = 8 HYGON = 8
InfiniDeviceNames = { InfiniDeviceNames = {
...@@ -19,7 +19,7 @@ InfiniDeviceNames = { ...@@ -19,7 +19,7 @@ InfiniDeviceNames = {
InfiniDeviceEnum.MOORE: "Moore", InfiniDeviceEnum.MOORE: "Moore",
InfiniDeviceEnum.ILUVATAR: "Iluvatar", InfiniDeviceEnum.ILUVATAR: "Iluvatar",
InfiniDeviceEnum.KUNLUN: "Kunlun", InfiniDeviceEnum.KUNLUN: "Kunlun",
InfiniDeviceEnum.SUGON: "Sugon", InfiniDeviceEnum.HYGON: "Hygon",
} }
# Mapping that maps InfiniDeviceEnum to torch device string # Mapping that maps InfiniDeviceEnum to torch device string
...@@ -32,5 +32,5 @@ torch_device_map = { ...@@ -32,5 +32,5 @@ torch_device_map = {
InfiniDeviceEnum.MOORE: "musa", InfiniDeviceEnum.MOORE: "musa",
InfiniDeviceEnum.ILUVATAR: "cuda", InfiniDeviceEnum.ILUVATAR: "cuda",
InfiniDeviceEnum.KUNLUN: "cuda", InfiniDeviceEnum.KUNLUN: "cuda",
InfiniDeviceEnum.SUGON: "cuda", InfiniDeviceEnum.HYGON: "cuda",
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment