Unverified commit 4a147e0f authored by Tim Moon, committed by GitHub
Browse files

Update fp8_meta amax when copying into Float8Tensor (#567)



* Update fp8_meta amax when copying into Float8Tensor
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Avoid amax when copying between Float8Tensors with fp8_metas
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent 9020344b
...@@ -562,6 +562,22 @@ class Float8Tensor(torch.Tensor): ...@@ -562,6 +562,22 @@ class Float8Tensor(torch.Tensor):
if dst._fp8_dtype == src._fp8_dtype: if dst._fp8_dtype == src._fp8_dtype:
dst._data.copy_(src._data) dst._data.copy_(src._data)
dst._scale_inv = src._scale_inv.clone() dst._scale_inv = src._scale_inv.clone()
if dst._fp8_meta is not None:
if src._fp8_meta is None:
src_min, src_max = src.from_float8().aminmax()
src_amax = torch.maximum(-src_min, src_max)
else:
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=src._fp8_meta_forward,
)
fp8_meta_index = src._fp8_meta_index
src_amax = src._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index]
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=dst._fp8_meta_forward,
)
fp8_meta_index = dst._fp8_meta_index
dst_amax = dst._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index]
torch.maximum(src_amax, dst_amax, out=dst_amax)
else: else:
dst.copy_(src.from_float8()) dst.copy_(src.from_float8())
...@@ -582,11 +598,14 @@ class Float8Tensor(torch.Tensor): ...@@ -582,11 +598,14 @@ class Float8Tensor(torch.Tensor):
# Update scaling factor if FP8 meta tensors are available # Update scaling factor if FP8 meta tensors are available
if dst._fp8_meta is None: if dst._fp8_meta is None:
scale = dst._scale_inv.reciprocal() scale = dst._scale_inv.reciprocal()
amax = torch.empty_like(scale)
else: else:
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=dst._fp8_meta_forward, forward=dst._fp8_meta_forward,
) )
scale = dst._fp8_meta[fp8_meta_key].scale[dst._fp8_meta_index] fp8_meta_index = dst._fp8_meta_index
scale = dst._fp8_meta[fp8_meta_key].scale[fp8_meta_index]
amax = dst._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index]
dst._scale_inv = scale.detach().view(1).reciprocal() dst._scale_inv = scale.detach().view(1).reciprocal()
# Cast to FP8 # Cast to FP8
...@@ -596,7 +615,7 @@ class Float8Tensor(torch.Tensor): ...@@ -596,7 +615,7 @@ class Float8Tensor(torch.Tensor):
src.view(1,-1), src.view(1,-1),
scale, scale,
dst._data.view(1,-1), dst._data.view(1,-1),
torch.empty_like(dst._scale_inv), # amax amax,
dst._scale_inv, dst._scale_inv,
dst._fp8_dtype, dst._fp8_dtype,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment