Unverified Commit 2ad5da95 authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

[PyTorch] Fix incorrect docstrings in tensor saving functions (#1549)



Fix incorrect docstrings in tensor saving functions
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
parent 44c8fd0f
...@@ -98,12 +98,7 @@ class Float8TensorBase: ...@@ -98,12 +98,7 @@ class Float8TensorBase:
} }
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8TensorBase]: def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8TensorBase]:
"""Prepare the tensor base for saving for backward """Prepare the tensor base for saving for backward"""
After calling this, the tensor instance does not hold any
data.
"""
tensors = [self._data, self._transpose] tensors = [self._data, self._transpose]
return tensors, self return tensors, self
......
...@@ -93,12 +93,7 @@ class MXFP8TensorBase: ...@@ -93,12 +93,7 @@ class MXFP8TensorBase:
} }
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorBase]: def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorBase]:
"""Prepare the tensor base for saving for backward """Prepare the tensor base for saving for backward"""
After calling this, the tensor instance does not hold any
data.
"""
tensors = [self._rowwise_data, self._columnwise_data] tensors = [self._rowwise_data, self._columnwise_data]
return tensors, self return tensors, self
......
...@@ -349,12 +349,7 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor): ...@@ -349,12 +349,7 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
self._transpose_invalid = True self._transpose_invalid = True
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8TensorBase]: def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8TensorBase]:
"""Prepare the tensor base for saving for backward """Prepare the tensor base for saving for backward"""
After calling this, the tensor instance does not hold any
data.
"""
return [self], None return [self], None
@classmethod @classmethod
......
...@@ -286,12 +286,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor): ...@@ -286,12 +286,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
self._columnwise_data = torch.Tensor() if self._columnwise_data is not None else None self._columnwise_data = torch.Tensor() if self._columnwise_data is not None else None
def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorBase]: def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorBase]:
"""Prepare the tensor base for saving for backward """Prepare the tensor base for saving for backward"""
After calling this, the tensor instance does not hold any
data.
"""
return [self], None return [self], None
@classmethod @classmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment