Unverified commit f8693d2b, authored by vthumbe1503 and committed by GitHub
Browse files

Fix CI failure related to bug in MXFP8 copy implementation (#2369)



* fix ci issue
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* revert back testing changes
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 29537c96
...@@ -339,23 +339,21 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor): ...@@ -339,23 +339,21 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
if func == torch.ops.aten.copy_.default: if func == torch.ops.aten.copy_.default:
dst, src = args[0], args[1] dst, src = args[0], args[1]
# Booleans to check if src has all the usages that dst needs to respect dst quantizer usages. if isinstance(src, MXFP8Tensor) and isinstance(dst, MXFP8Tensor):
# If not, default to base class behavior. # Booleans to check if src has all the usages that dst needs to respect dst quantizer usages.
rowwise_matches = src._rowwise_data is not None or dst._rowwise_data is None # If not, default to base class behavior.
columnwise_matches = src._columnwise_data is not None or dst._columnwise_data is None rowwise_matches = src._rowwise_data is not None or dst._rowwise_data is None
if ( columnwise_matches = (
isinstance(src, MXFP8Tensor) src._columnwise_data is not None or dst._columnwise_data is None
and isinstance(dst, MXFP8Tensor) )
and rowwise_matches if rowwise_matches and columnwise_matches:
and columnwise_matches if dst._rowwise_data is not None:
): dst._rowwise_data.copy_(src._rowwise_data.detach())
if dst._rowwise_data is not None: dst._rowwise_scale_inv.copy_(src._rowwise_scale_inv.detach())
dst._rowwise_data.copy_(src._rowwise_data.detach()) if dst._columnwise_data is not None:
dst._rowwise_scale_inv.copy_(src._rowwise_scale_inv.detach()) dst._columnwise_data.copy_(src._columnwise_data.detach())
if dst._columnwise_data is not None: dst._columnwise_scale_inv.copy_(src._columnwise_scale_inv.detach())
dst._columnwise_data.copy_(src._columnwise_data.detach()) return dst
dst._columnwise_scale_inv.copy_(src._columnwise_scale_inv.detach())
return dst
# FSDP2 related functions. # FSDP2 related functions.
if func == aten.split.Tensor: if func == aten.split.Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment