Commit 2d393626 authored by zhuyue's avatar zhuyue Committed by zhuyue
Browse files

issue/618 - Add __eq__ to device and narrow method to Tensor.

parent 74934cdf
...@@ -35,6 +35,20 @@ class device: ...@@ -35,6 +35,20 @@ class device:
def __str__(self): def __str__(self):
return f"{self.type}{f':{self.index}' if self.index is not None else ''}" return f"{self.type}{f':{self.index}' if self.index is not None else ''}"
def __eq__(self, other):
"""
Compare two device objects for equality.
Args:
other: The object to compare with
Returns:
bool: True if both objects are device instances with the same type and index
"""
if not isinstance(other, device):
return False
return self.type == other.type and self.index == other.index
@staticmethod @staticmethod
def _to_infinicore_device(type, index): def _to_infinicore_device(type, index):
all_device_types = tuple(_infinicore.Device.Type.__members__.values())[:-1] all_device_types = tuple(_infinicore.Device.Type.__members__.values())[:-1]
......
...@@ -96,6 +96,9 @@ class Tensor: ...@@ -96,6 +96,9 @@ class Tensor:
def __mul__(self, other): def __mul__(self, other):
return infinicore.mul(self, other) return infinicore.mul(self, other)
def narrow(self, dim, start, length):
return infinicore.narrow(self, dim, start, length)
def empty(size, *, dtype=None, device=None, pin_memory=False): def empty(size, *, dtype=None, device=None, pin_memory=False):
return Tensor( return Tensor(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment