Unverified Commit f8693d2b authored by vthumbe1503's avatar vthumbe1503 Committed by GitHub
Browse files

Fix CI failure related to bug in MXFP8 copy implementation (#2369)



* fix ci issue
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* revert back testing changes
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 29537c96
...@@ -339,16 +339,14 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor): ...@@ -339,16 +339,14 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
if func == torch.ops.aten.copy_.default: if func == torch.ops.aten.copy_.default:
dst, src = args[0], args[1] dst, src = args[0], args[1]
if isinstance(src, MXFP8Tensor) and isinstance(dst, MXFP8Tensor):
# Booleans to check if src has all the usages that dst needs to respect dst quantizer usages. # Booleans to check if src has all the usages that dst needs to respect dst quantizer usages.
# If not, default to base class behavior. # If not, default to base class behavior.
rowwise_matches = src._rowwise_data is not None or dst._rowwise_data is None rowwise_matches = src._rowwise_data is not None or dst._rowwise_data is None
columnwise_matches = src._columnwise_data is not None or dst._columnwise_data is None columnwise_matches = (
if ( src._columnwise_data is not None or dst._columnwise_data is None
isinstance(src, MXFP8Tensor) )
and isinstance(dst, MXFP8Tensor) if rowwise_matches and columnwise_matches:
and rowwise_matches
and columnwise_matches
):
if dst._rowwise_data is not None: if dst._rowwise_data is not None:
dst._rowwise_data.copy_(src._rowwise_data.detach()) dst._rowwise_data.copy_(src._rowwise_data.detach())
dst._rowwise_scale_inv.copy_(src._rowwise_scale_inv.detach()) dst._rowwise_scale_inv.copy_(src._rowwise_scale_inv.detach())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment