Unverified Commit 4f1d70fb authored by LadyRick, committed by GitHub
Browse files

fix amax -> abs max in fp8_calibration (#534)



[PyTorch] fix amax calculate during fp8 calibration
Signed-off-by: ladyrick <ladyrick@qq.com>
parent 4d444db1
......@@ -220,11 +220,13 @@ class _LayerNormLinear(torch.autograd.Function):
if fp8_calibration:
# amax of input
amin, amax = ln_out_total.aminmax()
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = \
torch.amax(ln_out_total).float()
torch.max(-amin, amax).float()
# amax of weight
amin, amax = weight.aminmax()
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \
torch.amax(weight).float()
torch.max(-amin, amax).float()
out, _, _ = tex.gemm(
weight,
......
......@@ -345,11 +345,13 @@ class _LayerNormMLP(torch.autograd.Function):
if fp8_calibration:
# amax of fc1 input
amin, amax = ln_out_total.aminmax()
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = \
torch.amax(ln_out_total).float()
torch.max(-amin, amax).float()
# amax of fc1 weight
amin, amax = fc1_weight.aminmax()
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \
torch.amax(fc1_weight).float()
torch.max(-amin, amax).float()
fc1_outputs = tex.gemm(
fc1_weight,
......@@ -383,11 +385,13 @@ class _LayerNormMLP(torch.autograd.Function):
if fp8_calibration:
# amax of fc2 input
amin, amax = gelu_out.aminmax()
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM2_INPUT] = \
torch.amax(gelu_out).float()
torch.max(-amin, amax).float()
# amax of fc2 weight
amin, amax = fc2_weight.aminmax()
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM2_WEIGHT] = \
torch.amax(fc2_weight).float()
torch.max(-amin, amax).float()
if ub_split_rs:
ub_obj_fc2out = get_ub("fc2_fprop")
......
......@@ -226,11 +226,13 @@ class _Linear(torch.autograd.Function):
if fp8_calibration:
# amax of input
amin, amax = inputmat_total.aminmax()
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = \
torch.amax(inputmat_total).float()
torch.max(-amin, amax).float()
# amax of weight
amin, amax = weight.aminmax()
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \
torch.amax(weight).float()
torch.max(-amin, amax).float()
if ub_split_rs:
ub_obj_projout = get_ub("proj_fprop")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment