Unverified Commit 62f5c9ee authored by jberchtold-nvidia, committed by GitHub
Browse files

[JAX] Use 1x quantization + jax transpose for performance for tensor-scaling (#1830)



* Use 1x quantization + jax transpose on BW for performance
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Use 1x quantization on Hopper as well as it is also faster
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Undo architecture check helper function
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Lint
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent c6a9e261
...@@ -629,6 +629,13 @@ def _quantize_dbias_impl( ...@@ -629,6 +629,13 @@ def _quantize_dbias_impl(
if isinstance(quantizer, DelayedScaleQuantizer): if isinstance(quantizer, DelayedScaleQuantizer):
scale = quantizer.scale scale = quantizer.scale
# It is faster to use 1x quantization for tensor scaling
force_1x_quantization = quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x()
q_layout = quantizer.q_layout
if force_1x_quantization:
q_layout = QuantizeLayout.ROWWISE
( (
rowwise_casted_output, rowwise_casted_output,
colwise_casted_output, colwise_casted_output,
...@@ -641,7 +648,7 @@ def _quantize_dbias_impl( ...@@ -641,7 +648,7 @@ def _quantize_dbias_impl(
scale, scale,
out_dtype=quantizer.q_dtype, out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode.value, scaling_mode=quantizer.scaling_mode.value,
q_layout=quantizer.q_layout.value, q_layout=q_layout.value,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
scale_dtype=quantizer.get_scale_dtype(), scale_dtype=quantizer.get_scale_dtype(),
is_dbias=is_dbias, is_dbias=is_dbias,
...@@ -651,6 +658,15 @@ def _quantize_dbias_impl( ...@@ -651,6 +658,15 @@ def _quantize_dbias_impl(
if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x(): if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x():
colwise_scale_inv = rowwise_scale_inv colwise_scale_inv = rowwise_scale_inv
if q_layout == QuantizeLayout.ROWWISE:
# Quantizer requires 2x quantization, but we are using 1x quantization
# for performance reasons, so we need to generate the colwise data in JAX
if flatten_axis < 0:
flatten_axis += x.ndim
colwise_casted_output = jnp.transpose(
rowwise_casted_output, (*range(flatten_axis, x.ndim), *range(flatten_axis))
)
quantizer.update(updated_amax) quantizer.update(updated_amax)
out = ScaledTensorFactory.create( out = ScaledTensorFactory.create(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment