Unverified commit 4f1d70fb authored by LadyRick, committed by GitHub
Browse files

fix amax -> abs max in fp8_calibration (#534)



[PyTorch] fix amax calculate during fp8 calibration
Signed-off-by: ladyrick <ladyrick@qq.com>
parent 4d444db1
...@@ -220,11 +220,13 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -220,11 +220,13 @@ class _LayerNormLinear(torch.autograd.Function):
if fp8_calibration: if fp8_calibration:
# amax of input # amax of input
amin, amax = ln_out_total.aminmax()
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = \ fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = \
torch.amax(ln_out_total).float() torch.max(-amin, amax).float()
# amax of weight # amax of weight
amin, amax = weight.aminmax()
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \ fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \
torch.amax(weight).float() torch.max(-amin, amax).float()
out, _, _ = tex.gemm( out, _, _ = tex.gemm(
weight, weight,
......
...@@ -345,11 +345,13 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -345,11 +345,13 @@ class _LayerNormMLP(torch.autograd.Function):
if fp8_calibration: if fp8_calibration:
# amax of fc1 input # amax of fc1 input
amin, amax = ln_out_total.aminmax()
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = \ fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = \
torch.amax(ln_out_total).float() torch.max(-amin, amax).float()
# amax of fc1 weight # amax of fc1 weight
amin, amax = fc1_weight.aminmax()
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \ fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \
torch.amax(fc1_weight).float() torch.max(-amin, amax).float()
fc1_outputs = tex.gemm( fc1_outputs = tex.gemm(
fc1_weight, fc1_weight,
...@@ -383,11 +385,13 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -383,11 +385,13 @@ class _LayerNormMLP(torch.autograd.Function):
if fp8_calibration: if fp8_calibration:
# amax of fc2 input # amax of fc2 input
amin, amax = gelu_out.aminmax()
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM2_INPUT] = \ fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM2_INPUT] = \
torch.amax(gelu_out).float() torch.max(-amin, amax).float()
# amax of fc2 weight # amax of fc2 weight
amin, amax = fc2_weight.aminmax()
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM2_WEIGHT] = \ fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM2_WEIGHT] = \
torch.amax(fc2_weight).float() torch.max(-amin, amax).float()
if ub_split_rs: if ub_split_rs:
ub_obj_fc2out = get_ub("fc2_fprop") ub_obj_fc2out = get_ub("fc2_fprop")
......
...@@ -226,11 +226,13 @@ class _Linear(torch.autograd.Function): ...@@ -226,11 +226,13 @@ class _Linear(torch.autograd.Function):
if fp8_calibration: if fp8_calibration:
# amax of input # amax of input
amin, amax = inputmat_total.aminmax()
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = \ fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = \
torch.amax(inputmat_total).float() torch.max(-amin, amax).float()
# amax of weight # amax of weight
amin, amax = weight.aminmax()
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \ fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \
torch.amax(weight).float() torch.max(-amin, amax).float()
if ub_split_rs: if ub_split_rs:
ub_obj_projout = get_ub("proj_fprop") ub_obj_projout = get_ub("proj_fprop")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment