Unverified commit 69636a08 authored by Victor Oliveira, committed by GitHub

ONNX: Fix FP8 quantization for the second MLP in LayerNormMLP (#2577)



Signed-off-by: Victor Oliveira <victor.oliveira@getcruise.com>
parent fe8fad59
@@ -2243,14 +2243,23 @@ class LayerNormMLP(TransformerEngineBaseModule):
         assert not TEDebugState.debug_enabled, "Debug mode is not supported in ONNX export"
         assert_warmed_up(self)
         # Get quantizers
         (
             fc1_input_quantizer,
             fc1_weight_quantizer,
+            _,
+            _,
+            _,
+            _,
             fc2_input_quantizer,
             fc2_weight_quantizer,
-            output_quantizer,
-            *_,
+            fc2_output_quantizer,
+            _,
+            _,
+            _,
         ) = self._get_quantizers(False, is_grad_enabled)
         inp_dtype = inp.dtype
         fc1_weight, fc2_weight = self._get_weight_tensors()
@@ -2324,7 +2333,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
         fc2_out = onnx_gemm(fc2_weight, act_out, fc2_bias)
-        if output_quantizer is not None:
+        if fc2_output_quantizer is not None:
             raise NotImplementedError("ONNX export of quantized output is not supported")
         if self.return_layernorm_output:
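
The root cause is that LayerNormMLP._get_quantizers returns a single flat tuple of quantizers covering both GEMMs, and the old unpacking read FC2's quantizers from the wrong positions. Below is a minimal standalone sketch of the bug and the fix, assuming a hypothetical twelve-entry ordering (six quantizers per GEMM: input, weight, output, plus three gradient quantizers); the stub function and the string names are illustrative assumptions, not TransformerEngine's actual API.

    def _get_quantizers_stub():
        # ASSUMED ordering, for illustration only: six quantizers per GEMM
        # (input, weight, output, grad_input, grad_weight, grad_output).
        return [
            "fc1_input", "fc1_weight", "fc1_output",
            "fc1_grad_input", "fc1_grad_weight", "fc1_grad_output",
            "fc2_input", "fc2_weight", "fc2_output",
            "fc2_grad_input", "fc2_grad_weight", "fc2_grad_output",
        ]

    # Before the fix: positions 2 and 3 hold fc1's remaining quantizers,
    # so fc2 silently picks up the wrong ones.
    (
        fc1_input_quantizer,
        fc1_weight_quantizer,
        fc2_input_quantizer,   # actually "fc1_output" under this ordering
        fc2_weight_quantizer,  # actually "fc1_grad_input"
        output_quantizer,      # actually "fc1_grad_weight"
        *_,
    ) = _get_quantizers_stub()
    assert fc2_input_quantizer == "fc1_output"  # the bug: wrong quantizer selected

    # After the fix: every skipped slot is spelled out, so fc2's quantizers
    # come from the correct positions.
    (
        fc1_input_quantizer,
        fc1_weight_quantizer,
        _,
        _,
        _,
        _,
        fc2_input_quantizer,
        fc2_weight_quantizer,
        fc2_output_quantizer,
        _,
        _,
        _,
    ) = _get_quantizers_stub()
    assert fc2_input_quantizer == "fc2_input"  # fc2 now uses its own quantizer

A side benefit of replacing *_ with explicit underscores: the unpacking now pins the exact tuple length, so a future change to _get_quantizers fails loudly at this line instead of silently misaligning quantizers again.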