Unverified commit 7042d7ae, authored by Sudhakar Singh, committed by GitHub

TE Gemma tutorial, attempt #2 (#1839)



* add tutorial files and other local changes
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* remove extraneous code for easy debugging
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* make cuda graphs work with non-paged and paged attention
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* perf improvements for kv cache ops
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* add code for calibration
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* optimize kv_cache reindex and copy kernels
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* changes to make quantizers work with fp8_calibration
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* avoid reindexing from python side
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* rename variable from previous commit
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fix
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fix
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* use quantizer only if needed
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* functionality of the tutorial tested and perf checked
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* remove files and update headers/licenses
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* update header/license
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update tutorial for review
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* make weights downloadable on the fly; remove extra print statements
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix lint and update comments
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add comma back, typo
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* sequence_start_positions should be None for training
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add paged attention numbers and update requirements.txt file
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* more fixes
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* make tutorial work on blackwell
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* remove gemma FT tutorial for now
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* fix the headings placement and reword attention -> kv caching
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* fixes from comments
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the images
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* misc fixes
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* add more comments to te_gemma.py and clean up utils.py
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add more information about the hierarchy of the classes used in the tutorial
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add better cuda graphs picture
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* add updated cuda graphs pictures
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* add illustrated cuda graphs
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* fix
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* small fixes in documentation
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* add torch.no_grad() to force reduced memory usage
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* some fixes from recent comments
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* more fixes from remaining comments
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* add te_rope_emb to class desc
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* fix tutorial wording; add calibration fix to grouped_linear.py
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

---------
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent ba37529c
@@ -1767,7 +1767,7 @@ class LayerNormLinear(TransformerEngineBaseModule):

     def _get_weight_quantizers(self) -> List[Quantizer]:
         """Get the weight quantizers of the module."""
-        if not self.fp8:
+        if not self.fp8 and not self.fp8_calibration:
             return [None]
         weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
         weight_quantizer.internal = True
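The `_get_weight_quantizers` guard above now returns a real quantizer whenever either FP8 execution or FP8 calibration is active; previously a calibration-only run got `[None]` and recorded nothing. A minimal, library-free sketch of that gating pattern (`ToyQuantizer` and `ToyModule` are hypothetical stand-ins, not Transformer Engine classes):

```python
from typing import List, Optional

class ToyQuantizer:
    """Hypothetical stand-in for a quantizer: records a running amax."""
    def __init__(self) -> None:
        self.amax = 0.0
        self.internal = False

    def calibrate(self, values: List[float]) -> None:
        self.amax = max(self.amax, max(abs(v) for v in values))

class ToyModule:
    """Mirrors the `if not self.fp8 and not self.fp8_calibration` guard."""
    def __init__(self, fp8: bool, fp8_calibration: bool) -> None:
        self.fp8 = fp8
        self.fp8_calibration = fp8_calibration
        self._weight_quantizer = ToyQuantizer()

    def _get_weight_quantizers(self) -> List[Optional[ToyQuantizer]]:
        if not self.fp8 and not self.fp8_calibration:
            return [None]  # neither FP8 nor calibration: nothing to quantize
        self._weight_quantizer.internal = True
        return [self._weight_quantizer]

# Calibration-only mode now yields a usable quantizer:
m = ToyModule(fp8=False, fp8_calibration=True)
(q,) = m._get_weight_quantizers()
q.calibrate([0.5, -2.0, 1.25])
print(q.amax)  # 2.0
```

The same guard change is applied to `Linear` and `LayerNormMLP` further down in this commit.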
@@ -444,14 +444,19 @@ class _LayerNormMLP(torch.autograd.Function):
                 # tex.quantize does not support GELU fusion for blockwise.
                 act_out = activation_func(fc1_out, None)
                 act_out = tex.quantize(act_out, fc2_input_quantizer)
             else:
-                act_out = activation_func(fc1_out, fc2_input_quantizer)
+                if fp8_calibration:
+                    act_out = activation_func(fc1_out, None)
+                else:
+                    act_out = activation_func(fc1_out, fc2_input_quantizer)

         if not is_grad_enabled:
             clear_tensor_data(fc1_out)

-        if fp8_calibration:
-            fc2_input_quantizer.calibrate(act_out)
-            fc2_weight_quantizer.calibrate(fc2_weight)
+        if not fp8 and fp8_calibration:
+            if fc2_input_quantizer is not None:
+                fc2_input_quantizer.calibrate(act_out)
+            if fc2_weight_quantizer is not None:
+                fc2_weight_quantizer.calibrate(fc2_weight)

         # Configure Userbuffers reduce-scatter if needed
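The hunk above makes calibration run the activation in high precision (no fused quantize) and then record amax statistics, guarded against missing quantizers. A rough stand-alone mirror of that control flow, with stub names in place of the real Transformer Engine functions:

```python
class AmaxRecorder:
    """Stub quantizer that only records the running amax."""
    def __init__(self):
        self.amax = 0.0

    def calibrate(self, xs):
        self.amax = max(self.amax, max(abs(x) for x in xs))

def relu(xs, quantizer):
    # Stub activation; a fused implementation would also quantize
    # its output with `quantizer` when one is passed.
    return [max(0.0, x) for x in xs]

def forward_act(fc1_out, fp8, fp8_calibration, activation_func,
                fc2_input_quantizer=None, fc2_weight=None,
                fc2_weight_quantizer=None):
    """Mirrors the patched branch: calibration skips fused quantization,
    then records amax for the FC2 input and weight if quantizers exist."""
    if fp8_calibration:
        act_out = activation_func(fc1_out, None)  # no fused quantize
    else:
        act_out = activation_func(fc1_out, fc2_input_quantizer)
    if not fp8 and fp8_calibration:
        if fc2_input_quantizer is not None:
            fc2_input_quantizer.calibrate(act_out)
        if fc2_weight_quantizer is not None:
            fc2_weight_quantizer.calibrate(fc2_weight)
    return act_out

q_in, q_w = AmaxRecorder(), AmaxRecorder()
out = forward_act([-1.0, 3.0], fp8=False, fp8_calibration=True,
                  activation_func=relu, fc2_input_quantizer=q_in,
                  fc2_weight=[0.25, -4.0], fc2_weight_quantizer=q_w)
print(out, q_in.amax, q_w.amax)  # [0.0, 3.0] 3.0 4.0
```

The `not fp8 and fp8_calibration` condition keeps calibration from double-running when FP8 execution is already active.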
@@ -1897,7 +1902,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
             fc2_grad_output_quantizer,
         ) = [None] * 10
         fc1_weight_quantizer, fc2_weight_quantizer = self._get_weight_quantizers()
-        if self.fp8:
+        if self.fp8 or self.fp8_calibration:
             fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
             fc1_input_quantizer.internal = True
             fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
@@ -2114,7 +2119,7 @@ class LayerNormMLP(TransformerEngineBaseModule):

     def _get_weight_quantizers(self) -> List[Quantizer]:
         """Get the weight quantizers of the module."""
-        if not self.fp8:
+        if not self.fp8 and not self.fp8_calibration:
             return [None, None]
         fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
         fc1_weight_quantizer.internal = True
@@ -1643,7 +1643,7 @@ class Linear(TransformerEngineBaseModule):

     def _get_weight_quantizers(self) -> List[Quantizer]:
         """Get the weight quantizers of the module."""
-        if not self.fp8:
+        if not self.fp8 and not self.fp8_calibration:
             return [None]
         weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
         weight_quantizer.internal = True
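Taken together, these hunks let a high-precision run collect amax statistics that later seed FP8 scaling factors. A library-free sketch of that idea (`Calibrator` is hypothetical; the only hard fact used is that 448 is the largest finite value in the FP8 E4M3 format):

```python
FP8_E4M3_MAX = 448.0  # largest finite value representable in FP8 E4M3

class Calibrator:
    """Hypothetical stand-in: accumulate amax over calibration batches,
    then derive a per-tensor FP8 scale from it."""
    def __init__(self):
        self.amax = 0.0

    def observe(self, xs):
        self.amax = max(self.amax, max(abs(x) for x in xs))

    def scale(self):
        # Map the observed dynamic range onto the FP8 representable range.
        return FP8_E4M3_MAX / self.amax if self.amax > 0.0 else 1.0

cal = Calibrator()
for batch in ([0.1, -0.7], [2.0, 0.3]):  # pretend calibration activations
    cal.observe(batch)
print(cal.amax, cal.scale())  # 2.0 224.0
```

This is why the guards above must hand out quantizers during calibration-only runs: without a quantizer object, there is nowhere to accumulate the amax.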