Unverified Commit 81c363bf authored by xiaoxi-wangfj's avatar xiaoxi-wangfj Committed by GitHub
Browse files

[PyTorch] Add record_stream and untyped_storage func op in QuantizedTensor (#2144)



* [PyTorch] Add record_stream and untyped_storage func op in QuantizedTensor
Signed-off-by: xiaoxi-wangfj <690912414@qq.com>

* Update transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: xiaoxi-wangfj <690912414@qq.com>

* Update transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: xiaoxi-wangfj <690912414@qq.com>

---------
Signed-off-by: xiaoxi-wangfj <690912414@qq.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 452c7374
...@@ -403,6 +403,21 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor): ...@@ -403,6 +403,21 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
return _ReshapeFunc.apply(self, shape) return _ReshapeFunc.apply(self, shape)
def untyped_storage(self) -> torch.UntypedStorage:
    """Return the UntypedStorage that backs this tensor's FP8 data.

    An FP8 block-scaled tensor may be made up of several buffers:
    row-wise FP8 data, row-wise scales, column-wise FP8 data, and
    column-wise scales. Only one storage can be returned, so the
    row-wise FP8 data takes priority and the column-wise FP8 data
    is used when the row-wise buffer is absent. When neither data
    buffer is present, an empty storage on this tensor's device is
    returned.
    """
    # Prefer the row-wise buffer; only look at the column-wise one
    # when the row-wise buffer is missing.
    fp8_buffer = self._rowwise_data
    if fp8_buffer is None:
        fp8_buffer = self._columnwise_data

    # No FP8 data at all: hand back a zero-sized placeholder storage.
    if fp8_buffer is None:
        return torch.UntypedStorage(0, device=self.device)

    return fp8_buffer.untyped_storage()
@classmethod @classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None): def __torch_dispatch__(cls, func, types, args, kwargs=None):
...@@ -427,6 +442,19 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor): ...@@ -427,6 +442,19 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
) )
return Float8BlockwiseQTensor.make_like(tensor) return Float8BlockwiseQTensor.make_like(tensor)
# record stream op
if func == torch.ops.aten.record_stream.default:
qt, stream = args
for t in (
qt._rowwise_data,
qt._columnwise_data,
qt._rowwise_scale_inv,
qt._columnwise_scale_inv,
):
if t is not None and t.is_cuda:
t.record_stream(stream)
return None
# Default case # Default case
return super().__torch_dispatch__(func, types, args, kwargs) return super().__torch_dispatch__(func, types, args, kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment