[JAX] Quickstart documentation (#2310)
* jax quickstart guide first commit Signed-off-by:tdophung <tdophung@nvidia.com> * edit the syntax errors and remove unnecessary comments in utils. Add some footnotes in the quick start notebook Signed-off-by:
tdophung <tdophung@nvidia.com> * Fix greptiles comments on spelling, deepcopy, vjp function signature comaptibility with speedometer Signed-off-by:
tdophung <tdophung@nvidia.com> * Add Copyright to utils and fix some more greptiles complaints Signed-off-by:
tdophung <tdophung@nvidia.com> * Add comments to alternative of layers Signed-off-by:
tdophung <tdophung@nvidia.com> * Remove weight sharing between different iterations of the transformerLayer Signed-off-by:
tdophung <tdophung@nvidia.com> [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by:
tdophung <tdophung@nvidia.com> * Add enum for attention implementations. Fix inconsistency between fuse and unfused TE impls to achieve same performance (removing extra dropout layer in fused layers. Also some minor wording changes Signed-off-by:
tdophung <tdophung@nvidia.com> [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by:
tdophung <tdophung@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix bug in TransformerLayer expected input shape being [sequence, batch, ...] instead of [batch, sequence,...] Signed-off-by:
tdophung <tdophung@nvidia.com> * Changing structure of notebook to bring fp8 ahead of fuse, to allow for fuse to take effect because quantization exist as suggested. Also make TransformerLayer perf get closer to Fused by setting hidden_dropout=0 Signed-off-by:
tdophung <tdophung@nvidia.com> * add option to choose between different attention implementation in call of BasicTETransformerLayer and demonstrated difference in runtime between using flax and using te's attetion implementation Signed-off-by:
tdophung <tdophung@nvidia.com> * Fix mistake in lacking attention_implementation in FuseTETransformerLayer Signed-off-by:
tdophung <tdophung@nvidia.com> * Removing AttentionWrapper and custom built DPA, using flax and TE's impl only, removing last mention of Pytorch Signed-off-by:
tdophung <tdophung@nvidia.com> * More changing to markdowns to remove pytorch Signed-off-by:
tdophung <tdophung@nvidia.com> * cosmetics fixes Signed-off-by:
tdophung <tdophung@nvidia.com> * changing names of all implementations Signed-off-by:
tdophung <tdophung@nvidia.com> * change fp8_autocast to autocast, make causal mask, and some wording changes Signed-off-by:
tdophung <tdophung@nvidia.com> --------- Signed-off-by:
tdophung <tdophung@nvidia.com> Co-authored-by:
tdophung <tdophung@dc2-container-xterm-034.prd.it.nvidia.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by:
jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Showing
This diff is collapsed.
Please register or sign in to comment