Unverified Commit 226688bd authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[Bugfix][VLM] Make apply_fp8_linear work with >2D input (#9812)

parent 64cb1cdc
...@@ -96,21 +96,26 @@ def apply_fp8_linear( ...@@ -96,21 +96,26 @@ def apply_fp8_linear(
# If dynamic, layer.input_scale is None and x_scale computed from x. # If dynamic, layer.input_scale is None and x_scale computed from x.
# If static, layer.input_scale is scalar and x_scale is input_scale. # If static, layer.input_scale is scalar and x_scale is input_scale.
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[1]]
# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
if cutlass_fp8_supported: if cutlass_fp8_supported:
qinput, x_scale = ops.scaled_fp8_quant( qinput, x_scale = ops.scaled_fp8_quant(
input, input_2d,
input_scale, input_scale,
scale_ub=input_scale_ub, scale_ub=input_scale_ub,
use_per_token_if_dynamic=use_per_token_if_dynamic) use_per_token_if_dynamic=use_per_token_if_dynamic)
# Fused GEMM_DQ # Fused GEMM_DQ
return ops.cutlass_scaled_mm(qinput, output = ops.cutlass_scaled_mm(qinput,
weight, weight,
out_dtype=input.dtype, out_dtype=input.dtype,
scale_a=x_scale, scale_a=x_scale,
scale_b=weight_scale, scale_b=weight_scale,
bias=bias) bias=bias)
return output.view(*output_shape)
# torch.scaled_mm supports per tensor weights + activations only # torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token # so fallback to naive if per channel or per token
...@@ -119,7 +124,7 @@ def apply_fp8_linear( ...@@ -119,7 +124,7 @@ def apply_fp8_linear(
# for matrices with batch dimension > 16. # for matrices with batch dimension > 16.
# This could change in the future. # This could change in the future.
qinput, x_scale = ops.scaled_fp8_quant( qinput, x_scale = ops.scaled_fp8_quant(
input, input_2d,
input_scale, input_scale,
num_token_padding=17, num_token_padding=17,
use_per_token_if_dynamic=use_per_token_if_dynamic) use_per_token_if_dynamic=use_per_token_if_dynamic)
...@@ -138,8 +143,10 @@ def apply_fp8_linear( ...@@ -138,8 +143,10 @@ def apply_fp8_linear(
# A fix for discrepancy in scaled_mm which returns tuple # A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5 # for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2: if type(output) is tuple and len(output) == 2:
return torch.narrow(output[0], 0, 0, input.shape[0]) output = output[0]
return torch.narrow(output, 0, 0, input.shape[0])
return torch.narrow(output, 0, 0,
input_2d.shape[0]).view(*output_shape)
else: else:
# Fallback for channelwise case, where we use unfused DQ # Fallback for channelwise case, where we use unfused DQ
...@@ -176,15 +183,15 @@ def apply_fp8_linear( ...@@ -176,15 +183,15 @@ def apply_fp8_linear(
if type(output) is tuple and len(output) == 2: if type(output) is tuple and len(output) == 2:
output = output[0] output = output[0]
# Unpad (undo num_token_padding) # Unpad (undo num_token_padding)
output = torch.narrow(output, 0, 0, input.shape[0]) output = torch.narrow(output, 0, 0, input_2d.shape[0])
x_scale = torch.narrow(x_scale, 0, 0, input.shape[0]) x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
# DQ # DQ
# C = sw * sx * (X * W) + bias # C = sw * sx * (X * W) + bias
output = output * x_scale * weight_scale.t() output = output * x_scale * weight_scale.t()
if bias is not None: if bias is not None:
output = output + bias output = output + bias
return output.to(dtype=input.dtype) return output.to(dtype=input.dtype).view(*output_shape)
def apply_int8_linear( def apply_int8_linear(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment