Unverified commit 4a147e0f authored by Tim Moon, committed by GitHub
Browse files

Update fp8_meta amax when copying into Float8Tensor (#567)



* Update fp8_meta amax when copying into Float8Tensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Avoid amax when copying between Float8Tensors with fp8_metas
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent 9020344b
......@@ -562,6 +562,22 @@ class Float8Tensor(torch.Tensor):
if dst._fp8_dtype == src._fp8_dtype:
dst._data.copy_(src._data)
dst._scale_inv = src._scale_inv.clone()
if dst._fp8_meta is not None:
if src._fp8_meta is None:
src_min, src_max = src.from_float8().aminmax()
src_amax = torch.maximum(-src_min, src_max)
else:
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=src._fp8_meta_forward,
)
fp8_meta_index = src._fp8_meta_index
src_amax = src._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index]
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=dst._fp8_meta_forward,
)
fp8_meta_index = dst._fp8_meta_index
dst_amax = dst._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index]
torch.maximum(src_amax, dst_amax, out=dst_amax)
else:
dst.copy_(src.from_float8())
......@@ -582,11 +598,14 @@ class Float8Tensor(torch.Tensor):
# Update scaling factor if FP8 meta tensors are available
if dst._fp8_meta is None:
scale = dst._scale_inv.reciprocal()
amax = torch.empty_like(scale)
else:
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=dst._fp8_meta_forward,
)
scale = dst._fp8_meta[fp8_meta_key].scale[dst._fp8_meta_index]
fp8_meta_index = dst._fp8_meta_index
scale = dst._fp8_meta[fp8_meta_key].scale[fp8_meta_index]
amax = dst._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index]
dst._scale_inv = scale.detach().view(1).reciprocal()
# Cast to FP8
......@@ -596,7 +615,7 @@ class Float8Tensor(torch.Tensor):
src.view(1,-1),
scale,
dst._data.view(1,-1),
torch.empty_like(dst._scale_inv), # amax
amax,
dst._scale_inv,
dst._fp8_dtype,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.