Unverified commit a8f0fe03, authored by kwyss-nvidia, committed by GitHub

Blockwise scaling linear quantization recipe (#1559)



* Add GEMM logic for blockwise quantized tensors.

GEMM test cases included in pytorch integration.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update NVTE_BLOCK_SCALING for GEMM.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Gate feature on CUDA 12.9.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Fix GEMM typo.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Remove unnecessary type converter change.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Reflect epilogue availability and test supported epilogues.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* GEMM simplifications from recipe branch.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Format py code.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update GEMM DGelu tests to match support depending on output dtype.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Force pow2Scales in GEMM.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Add GEMM test to pytorch test suite.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Add copyright to GEMM test.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update import for GEMM test.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Add license.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update test gemm supported predicate.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Use sgemm-like interfaces and naming.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Rewrite GEMM comment.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* MR feedback.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Recipe setup for Linear modules.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Use 12.9 feature test.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Run against tensor dumps from internal library.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update FIXME to TODO with linked issue.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update full recompute feature to save recipe.

The recompute context uses the same recipe
and fp8 settings as the original fwd pass.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* MR feedback: avoid reusing quantizer objects.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update logic in module.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Format py.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update for PP bug.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update test numerics.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update force_power_of_2 scales in the recipe.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update usage method to satisfy upstream changes.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Fix subchannel recipe in distributed test with bf16 gather.
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* Edit and clean up BF16 gather code.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update test import.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Support columnwise-only mode in 1D quantize kernel.
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* Format and move enum.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Skip alloc.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Try async bf16 gather.
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* Format python code.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Document and type code.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update pytorch lint errors.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Don't set high precision dtype.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Add test for sanity and CG; fix CG for sequential.
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Keep make_quantizers API stable.

Update num_quantizers instead to pass cuda_graph tests.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Fix import name.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Rename recipe method.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Skip grouped linear sanity test.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Set usage before BF16 gather.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Refactor for nvte_quantize_v2.
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* Format code.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Clean up nvte_quantize_v2.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Test fp32 scales.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Disable CUDA graph.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Simplify layernorm linear.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Clean up layernorm linear.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* LayerNorm linear bwd gather logic.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Communication updates.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update transformer_engine/pytorch/ops/op.py

Apply MR comment change.
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: kwyss-nvidia <kwyss@nvidia.com>

* Lint fix.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* MR feedback.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Enable cuda graph tests.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Reduce chance of spurious failure and reword.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Review suggestions from @timmoon10.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update CPP tests.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update common.h
Signed-off-by: Xin Yao <yaox12@outlook.com>

* Update test_float8blockwisetensor.py
Signed-off-by: Xin Yao <yaox12@outlook.com>

---------
Signed-off-by: Keith Wyss <kwyss@nvidia.com>
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: kwyss-nvidia <kwyss@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Xin Yao <yaox12@outlook.com>
Co-authored-by: zhongboz <zhongboz@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Xin Yao <yaox12@outlook.com>
parent 0da60449
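The "Force pow2Scales" and "force_pow_2_scales" bullets refer to rounding each block's dequantization scale to a power of two. As a hedged illustration only (not Transformer Engine's actual kernel; `FP8_E4M3_MAX` and `blockwise_scales` are stand-in names), a 1D blockwise scale computation with the 128-element block length from `Float8BlockQuantizer` might look like:

```python
import math

FP8_E4M3_MAX = 448.0  # max representable magnitude in FP8 E4M3
BLOCK_LEN = 128       # block length used by Float8BlockQuantizer

def blockwise_scales(values, force_pow_2_scales=False):
    """Compute one scale per 128-element block (illustrative sketch)."""
    scales = []
    for start in range(0, len(values), BLOCK_LEN):
        block = values[start:start + BLOCK_LEN]
        amax = max(abs(v) for v in block)
        scale = amax / FP8_E4M3_MAX if amax > 0 else 1.0
        if force_pow_2_scales and scale > 0:
            # Round the scale up to a power of two so that scale
            # arithmetic in the downstream GEMM stays exact.
            scale = 2.0 ** math.ceil(math.log2(scale))
        scales.append(scale)
    return scales
```

Power-of-two scales trade a little quantization headroom for exact scale multiplication and division in the GEMM epilogue.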
@@ -44,7 +44,6 @@ class Float8BlockQuantizer(Quantizer):
         block_scaling_dim: int = 2,
     ) -> None:
         super().__init__(rowwise=rowwise, columnwise=columnwise)
-        assert rowwise
         self.dtype = fp8_dtype
         self.block_len = 128
         self.force_pow_2_scales = force_pow_2_scales
@@ -168,6 +167,11 @@ class Float8BlockQuantizer(Quantizer):
             colwise_shape.append(shape[i])
         return tuple(colwise_shape)

+    # TODO(kwyss): With FP8 gather support, we need to implement a
+    # shape/layout/swizzle check to know whether FP8 gather works
+    # cleanly by stacking data without aliasing tiles and whether
+    # the scales also stack on the proper dimensions.
     def make_empty(
         self,
         shape: Iterable[int],
@@ -181,13 +185,16 @@ class Float8BlockQuantizer(Quantizer):
         device = torch.device("cuda")

         # Allocate FP8 data
-        data = torch.empty(shape, dtype=torch.uint8, device=device)
-        scale_shape = self.get_scale_shape(shape, columnwise=False)
-        scale_inv = torch.empty(
-            scale_shape,
-            dtype=torch.float32,
-            device=device,
-        )
+        data = None
+        scale_inv = None
+        if self.rowwise_usage:
+            data = torch.empty(shape, dtype=torch.uint8, device=device)
+            scale_shape = self.get_scale_shape(shape, columnwise=False)
+            scale_inv = torch.empty(
+                scale_shape,
+                dtype=torch.float32,
+                device=device,
+            )

         # Allocate FP8 data transpose if needed
         columnwise_data = None
@@ -489,7 +496,6 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
         dst._fp8_dtype = src._fp8_dtype
         dst._rowwise_scale_inv = src._rowwise_scale_inv
         dst._columnwise_scale_inv = src._columnwise_scale_inv
-        dst.dtype = src.dtype

         # Check that tensor dimensions match
         if (
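The `make_empty` diff above gates rowwise buffer allocation on quantizer usage, so a columnwise-only quantizer skips the rowwise data and scale tensors entirely (the "Skip alloc." bullet). A minimal pure-Python sketch of that pattern, with stand-in names and `bytearray` in place of CUDA tensors so it runs without torch:

```python
class BlockQuantizerSketch:
    """Illustrative stand-in for Float8BlockQuantizer's make_empty logic.

    Not the real class: names and shapes here are assumptions made for
    the sketch; only the rowwise/columnwise gating mirrors the PR.
    """

    block_len = 128

    def __init__(self, rowwise=True, columnwise=False):
        self.rowwise_usage = rowwise
        self.columnwise_usage = columnwise

    def scale_shape(self, shape, columnwise):
        # One scale per 128-long block along the quantized dimension.
        rows, cols = shape
        if columnwise:
            return ((rows + self.block_len - 1) // self.block_len, cols)
        return (rows, (cols + self.block_len - 1) // self.block_len)

    def make_empty(self, shape):
        # Only allocate buffers for the usages actually requested,
        # mirroring the skip-alloc change in this PR.
        data = scale_inv = None
        if self.rowwise_usage:
            data = bytearray(shape[0] * shape[1])  # stand-in for uint8 tensor
            scale_inv = self.scale_shape(shape, columnwise=False)
        columnwise_data = columnwise_scale_inv = None
        if self.columnwise_usage:
            columnwise_data = bytearray(shape[0] * shape[1])
            columnwise_scale_inv = self.scale_shape(shape, columnwise=True)
        return data, scale_inv, columnwise_data, columnwise_scale_inv
```

This is why the `assert rowwise` in the constructor had to go: after the columnwise-only 1D quantize kernel landed, a quantizer may legitimately carry no rowwise buffers at all.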