# TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation, for dbias this typically only affects bfloat16.
# TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation, for dbias this typically only affects bfloat16.
# Due to the amax dependency, current scaling is unfused. In TE we store the activation results in bf16 which reduces precision compared to JAX implementation which will implicitly promote to float32 for the intermediate results when JIT'd. This only produces a tolerance issue when using squared_relu currently.