Unverified Commit c8fd97f2 authored by Tyler Michael Smith's avatar Tyler Michael Smith Committed by GitHub
Browse files

[Kernel] Use CUTLASS kernels for the FP8 layers with Bias (#6270)

parent 94b82e8c
...@@ -112,7 +112,7 @@ def apply_fp8_linear( ...@@ -112,7 +112,7 @@ def apply_fp8_linear(
# If dynamic, layer.input_scale is None and x_scale computed from x. # If dynamic, layer.input_scale is None and x_scale computed from x.
# If static, layer.input_scale is scalar and x_scale is input_scale. # If static, layer.input_scale is scalar and x_scale is input_scale.
if bias is None and cutlass_fp8_supported: if cutlass_fp8_supported:
qinput, x_scale = ops.scaled_fp8_quant(input, input_scale) qinput, x_scale = ops.scaled_fp8_quant(input, input_scale)
# Fused GEMM_DQ # Fused GEMM_DQ
...@@ -120,7 +120,8 @@ def apply_fp8_linear( ...@@ -120,7 +120,8 @@ def apply_fp8_linear(
weight, weight,
out_dtype=input.dtype, out_dtype=input.dtype,
scale_a=x_scale, scale_a=x_scale,
scale_b=weight_scale) scale_b=weight_scale,
bias=bias)
else: else:
qinput, x_scale = ops.scaled_fp8_quant(input, qinput, x_scale = ops.scaled_fp8_quant(input,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment