Commit a9e430c0 authored by yuguo
parents 9ab6cd98 a397dcb7
@@ -801,7 +801,7 @@ class TestFP8CurrentScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBase
     use_bias,
     seed=torch.initial_seed(),
     dtype=dtype,
-    y_error=0.5,
+    y_error=0.98 if int8_simulation_fp8 else 0.5,
     ln_out_error=0.5,
     dgrad_error=1,
     wgrad_error=1,
......
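Note on the tolerance change above: it only loosens the output check when `int8_simulation_fp8` is enabled, since requantizing the FP8 values to per-token INT8 adds a second rounding step; `y_error` goes from 0.5 to 0.98 while `ln_out_error`, `dgrad_error`, and `wgrad_error` stay unchanged. A minimal sketch of the kind of check these thresholds presumably feed (the real comparison lives in `TestFP8RecipeLayerNormLinearBase` and may differ; `assert_rel_error` is a hypothetical name, not from the repo):

```python
import torch

def assert_rel_error(y_test: torch.Tensor, y_ref: torch.Tensor, bound: float) -> None:
    # Relative L2 error of the quantized run against the reference run.
    rel = torch.linalg.vector_norm(y_test - y_ref) / torch.linalg.vector_norm(y_ref).clamp_min(1e-12)
    assert rel <= bound, f"relative error {rel.item():.4f} exceeds {bound}"
```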
@@ -190,13 +190,20 @@ use_split_accumulator = False
 # fp8 to int8
-quantizer = Float8CurrentScalingQuantizer(
+quantizer_e5m2 = Float8CurrentScalingQuantizer(
     fp8_dtype=tex.DType.kFloat8E5M2,
     device="cuda",
     force_pow_2_scales=False,
     amax_epsilon=0.0,
 )
+quantizer_e4m3 = Float8CurrentScalingQuantizer(
+    fp8_dtype=tex.DType.kFloat8E4M3,
+    device="cuda",
+    force_pow_2_scales=False,
+    amax_epsilon=0.0,
+)

 # current scaling
 def to_float8_CS(
     tensor: torch.Tensor,
@@ -206,6 +213,7 @@ def to_float8_CS(
     amax_epsilon: float = 0.0,
 ) -> Float8Tensor:
     """Cast tensor to FP8"""
+    quantizer = quantizer_e5m2 if fp8_dtype == tex.DType.kFloat8E5M2 else quantizer_e4m3
     if return_transpose:
         quantizer.set_usage(rowwise=True, columnwise=True)
     else:
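For reference, current scaling with `force_pow_2_scales=False` and `amax_epsilon=0.0` amounts to a per-tensor scale derived from the tensor's current amax, with no power-of-two rounding. A pure-PyTorch sketch of the arithmetic, assuming this mirrors what `Float8CurrentScalingQuantizer` computes (the `FP8_MAX` values are the largest finite values of the two formats; the real quantizer also guards the zero-amax case):

```python
import torch

FP8_MAX = {"e4m3": 448.0, "e5m2": 57344.0}  # largest finite value per format

def current_scaling(t: torch.Tensor, fmt: str, amax_epsilon: float = 0.0):
    amax = t.abs().amax().float().clamp_min(amax_epsilon)
    scale = FP8_MAX[fmt] / amax       # multiply by `scale` before the FP8 cast
    return scale, scale.reciprocal()  # TE stores scale_inv for dequantization
```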
@@ -235,13 +243,13 @@ end = time.time()
 # print("w_int8: ", w_int8)

 # Cast to FP8 and back
-x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
-w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
-# print("x_fp8: ", x_fp8._data.view(dtype=torch.float8_e5m2))
-# print("w_fp8: ", w_fp8._data.view(dtype=torch.float8_e5m2))
+x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
+w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
+# print("x_fp8: ", x_fp8._data.view(dtype=torch.float8_e4m3fn))
+# print("w_fp8: ", w_fp8._data.view(dtype=torch.float8_e4m3fn))
-x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
-w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
+x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._data.view(dtype=torch.float8_e4m3fn), x_fp8._scale_inv, False)
+w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._data.view(dtype=torch.float8_e4m3fn), w_fp8._scale_inv, False)
 # print("x_int8: ", x_int8)
 # print("w_int8: ", w_int8)
@@ -279,10 +287,10 @@ print("output: ", output)
 torch.cuda.synchronize()
 start = time.time()
 for i in range(20):
-    x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
-    w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
-    x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
-    w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
+    x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
+    # w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
+    x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._data.view(dtype=torch.float8_e4m3fn), x_fp8._scale_inv, False)
+    w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._data.view(dtype=torch.float8_e4m3fn), w_fp8._scale_inv, False)
     y_int32 = tex.generic_gemm(
         w_int8,
         transa,
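The loop above stops re-casting the weights each iteration and reuses `w_fp8` from the earlier warm-up section. The rescaling convention this simulation relies on can be checked end to end in plain PyTorch, assuming `tex.generic_gemm` accumulates the INT8 products in INT32 and the per-token scales multiply back as an outer product:

```python
import torch

x, w = torch.randn(4, 8), torch.randn(16, 8)
sx = x.abs().amax(dim=-1, keepdim=True) / 127.0   # per-token scales for x
sw = w.abs().amax(dim=-1, keepdim=True) / 127.0   # per-row scales for w
xq = torch.round(x / sx).clamp(-127, 127)
wq = torch.round(w / sw).clamp(-127, 127)
y_int32 = xq @ wq.T                # what the INT8 GEMM would accumulate
y = y_int32 * sx * sw.T            # dequantize: outer product of scales
print((y - x @ w.T).abs().max())   # small relative to |x @ w.T|
```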
@@ -376,7 +384,7 @@ torch.cuda.synchronize()
 start = time.time()
 for i in range(20):
     dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
-    w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
+    # w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
     # w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
     dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
     w_int8, w_scales = per_token_quant_fp8_to_int8_opt(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
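One note on the timing pattern used throughout: `time.time()` around a launch loop also counts host-side launch overhead. Since the script already calls `torch.cuda.synchronize()`, the numbers are usable, but CUDA events measure the device-side time directly. A small helper sketch (names are mine, not from the script):

```python
import torch

def time_cuda_ms(fn, iters: int = 20) -> float:
    start_evt = torch.cuda.Event(enable_timing=True)
    end_evt = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start_evt.record()
    for _ in range(iters):
        fn()
    end_evt.record()
    torch.cuda.synchronize()
    return start_evt.elapsed_time(end_evt) / iters  # average ms per iteration
```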
@@ -473,8 +481,8 @@ start = time.time()
 for i in range(20):
     # dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
     # x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
-    dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
-    x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
+    # dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
+    # x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
     # dy_int8, dy_scales = per_token_quant_fp8_to_int8_v2(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
     # x_int8, x_scales = per_token_quant_fp8_to_int8_v2(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
......
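The pattern across the three benchmark sections is consistent: the forward GEMM moves to E4M3 while `dy` stays in E5M2, matching the usual split of E4M3 for activations and weights (more mantissa bits) and E5M2 for gradients (more exponent range). An illustrative cast in plain PyTorch, using the formats' finite maxima of 448 (E4M3) and 57344 (E5M2):

```python
import torch

g = torch.tensor([1e-5, 1.0, 3.0e4])
print(g.to(torch.float8_e5m2).float())    # 3e4 is representable in E5M2
print(g.to(torch.float8_e4m3fn).float())  # 3e4 exceeds E4M3's max of 448
```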