Unverified Commit 9588109d authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Fix shape of new quantized tensor in `make_like` (#1515)



* Fix quantized tensor shape
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* add shape to make_like; add test for chunk
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix typo from suggestion
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 97344d66
...@@ -161,6 +161,36 @@ class TestFloat8Tensor: ...@@ -161,6 +161,36 @@ class TestFloat8Tensor:
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
torch.testing.assert_close(x_fp8 + y_fp8, x_ref - y_fp8, **tols) torch.testing.assert_close(x_fp8 + y_fp8, x_ref - y_fp8, **tols)
@pytest.mark.parametrize("dims", [2, [4, 4], [8, 5, 3, 3]])
def test_chunk_op(
    self,
    dims: DimsType,
    fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
    scale: float = 3.5,
    dtype: torch.dtype = torch.float32,
) -> None:
    """Test `chunk`, an op whose output shapes differ from its input's.

    Splits an FP8 tensor in two along dim 0 and checks that each chunk
    matches the corresponding slice exactly and carries the expected
    (halved-dim-0) shape.
    """
    # Initialize random data.
    dims = _to_list(dims)
    x_ref = torch.randn(dims, dtype=dtype, device="cpu")
    # Use the parametrized scale rather than a hard-coded 1.0; the
    # comparisons below are exact slices of the same tensor, so the
    # scale value cannot affect correctness.
    x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale)

    # Get chunks.
    chunk1, chunk2 = x_fp8.chunk(2, dim=0)

    # Each chunk must equal the matching slice bit-for-bit (atol=rtol=0).
    torch.testing.assert_close(x_fp8[0 : dims[0] // 2,], chunk1, atol=0, rtol=0)
    torch.testing.assert_close(x_fp8[dims[0] // 2 :,], chunk2, atol=0, rtol=0)

    # Check shapes: dim 0 is halved, trailing dims are preserved.
    expected_shape = torch.Size([x_fp8.shape[0] // 2]) + x_fp8.shape[1:]
    assert chunk1.shape == expected_shape, "Wrong shape for chunk1"
    assert chunk2.shape == expected_shape, "Wrong shape for chunk2"
def test_inplace_ops( def test_inplace_ops(
self, self,
dims: DimsType = 23, dims: DimsType = 23,
......
...@@ -402,7 +402,10 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor): ...@@ -402,7 +402,10 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
[data] + list(args[1:]), [data] + list(args[1:]),
kwargs, kwargs,
) )
return [Float8Tensor.make_like(tensor, data=split_tensor) for split_tensor in func_out] return [
Float8Tensor.make_like(tensor, data=split_tensor, shape=split_tensor.shape)
for split_tensor in func_out
]
if func == aten.new_zeros.default: if func == aten.new_zeros.default:
tensor = args[0] tensor = args[0]
data = tensor._data data = tensor._data
...@@ -412,7 +415,7 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor): ...@@ -412,7 +415,7 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
[data] + list(args[1:]), [data] + list(args[1:]),
kwargs, kwargs,
) )
return Float8Tensor.make_like(tensor, data=func_out) return Float8Tensor.make_like(tensor, data=func_out, shape=func_out.shape)
if func == torch.ops.aten.as_strided.default: if func == torch.ops.aten.as_strided.default:
tensor = args[0] tensor = args[0]
data = tensor._data data = tensor._data
...@@ -422,7 +425,7 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor): ...@@ -422,7 +425,7 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
[data] + list(args[1:]), [data] + list(args[1:]),
kwargs, kwargs,
) )
return Float8Tensor.make_like(tensor, data=func_out) return Float8Tensor.make_like(tensor, data=func_out, shape=func_out.shape)
if func == torch.ops.aten.detach.default: if func == torch.ops.aten.detach.default:
return cls.detach(args[0]) return cls.detach(args[0])
if func == torch.ops.aten.clone.default: if func == torch.ops.aten.clone.default:
......
...@@ -433,7 +433,8 @@ class QuantizedTensor(torch.Tensor): ...@@ -433,7 +433,8 @@ class QuantizedTensor(torch.Tensor):
data. data.
""" """
shape = shape if shape is not None else tensor.shape if shape is None:
shape = data.shape if data is not None else tensor.shape
dtype = dtype if dtype is not None else tensor.dtype dtype = dtype if dtype is not None else tensor.dtype
kwargs = tensor.get_metadata() kwargs = tensor.get_metadata()
if data is not None: if data is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment