{% extends '!footer.html' %}
{% block contentinfo %}
<img src="{{ pathto('_static/NVIDIA-LogoBlack.svg', 1) }}" alt="NVIDIA"/>
<p class="notices">
  <a href="https://www.nvidia.com/en-us/about-nvidia/privacy-policy/" target="_blank">Privacy Policy</a> |
  <a href="https://www.nvidia.com/en-us/about-nvidia/privacy-center/" target="_blank">Manage My Privacy</a> |
  <a href="https://www.nvidia.com/en-us/preferences/start/" target="_blank">Do Not Sell or Share My Data</a> |
  <a href="https://www.nvidia.com/en-us/about-nvidia/terms-of-service/" target="_blank">Terms of Service</a> |
  <a href="https://www.nvidia.com/en-us/about-nvidia/accessibility/" target="_blank">Accessibility</a> |
  <a href="https://www.nvidia.com/en-us/about-nvidia/company-policies/" target="_blank">Corporate Policies</a> |
  <a href="https://www.nvidia.com/en-us/product-security/" target="_blank">Product Security</a> |
  <a href="https://www.nvidia.com/en-us/contact/" target="_blank">Contact</a>
</p>
{{ super() }}
{% endblock %}
{% extends "!layout.html" %}
{% block extrahead %}
<script src="https://assets.adobedtm.com/5d4962a43b79/c1061d2c5e7b/launch-191c2462b890.min.js"></script>
{% endblock %}
{% block sidebartitle %} {{ super() }}
<style>
  /* Sidebar header (and topbar for mobile) */
  .wy-side-nav-search, .wy-nav-top {
    background: #76b900;
  }
  .wy-menu > p > span.caption-text {
    color: #76b900;
  }
  .wy-menu-vertical p {
    height: 32px;
    line-height: 32px;
    padding: 0 1.618em;
    margin: 12px 0 0;
    display: block;
    font-weight: 700;
    text-transform: uppercase;
    font-size: 85%;
    white-space: nowrap;
  }
  .wy-side-nav-search a:link, .wy-nav-top a:link {
    color: #fff;
  }
  .wy-side-nav-search a:visited, .wy-nav-top a:visited {
    color: #fff;
  }
  .wy-side-nav-search a:hover, .wy-nav-top a:hover {
    color: #fff;
  }
  .wy-menu-vertical a:link, .wy-menu-vertical a:visited {
    color: #d9d9d9;
  }
  .wy-menu-vertical a:active {
    background-color: #76b900;
  }
  .wy-side-nav-search > div.version {
    color: rgba(0, 0, 0, 0.3);
  }
  .wy-nav-content {
    max-width: 1000px;
  }
  /* Override table width restrictions. !important prevents the common CSS
     stylesheets from overriding this, as on RTD they are loaded after this
     stylesheet. */
  .wy-table-responsive table td, .wy-table-responsive table th {
    white-space: normal !important;
  }
  .wy-table-responsive {
    overflow: visible !important;
  }
</style>
<style>
  a:link, a:visited {
    color: #76b900;
  }
  a:hover {
    color: #8c0;
  }
  html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.citation):not(.glossary):not(.simple) > dt {
    background: rgba(118, 185, 0, 0.1);
    color: rgba(59, 93, 0, 1);
    border-top: solid 3px rgba(59, 93, 0, 1);
  }
  html.writer-html4 .rst-content dl:not(.docutils) .property,
  html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) .property {
    text-transform: capitalize;
    display: inline-block;
    padding-right: 8px;
  }
</style>
{% endblock %}
{% block footer %}
<script type="text/javascript">
  // Adobe Launch analytics hook. Note: typeof always returns a string, so the
  // check must compare against the string "undefined", not the value undefined.
  if (typeof _satellite !== "undefined") { _satellite.pageBottom(); }
</script>
{% endblock %}
..
    Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

activation.h
============
.. doxygenfile:: activation.h
..
    Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

cast.h
======
.. doxygenfile:: cast.h
..
    Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

fused_attn.h
============
.. doxygenfile:: fused_attn.h
..
    Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

fused_rope.h
============
.. doxygenfile:: fused_rope.h
..
    Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

gemm.h
======
.. doxygenfile:: gemm.h
..
    Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

C/C++ API
=========
The C/C++ API provides direct access to the custom kernels defined in the
`libtransformer_engine.so` library from C/C++, without going through Python.
.. toctree::
   :caption: Headers

   transformer_engine.h <transformer_engine>
   activation.h <activation>
   cast.h <cast>
   fused_attn.h <fused_attn>
   fused_rope.h <fused_rope>
   gemm.h <gemm>
   normalization.h <normalization>
   padding.h <padding>
   permutation.h <permutation>
   recipe.h <recipe>
   softmax.h <softmax>
   swizzle.h <swizzle>
   transpose.h <transpose>
..
    Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

normalization.h
===============
.. doxygenfile:: normalization.h
..
    Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

padding.h
=========
.. doxygenfile:: padding.h
..
    Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

permutation.h
=============
.. doxygenfile:: permutation.h
..
    Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

recipe.h
========
.. doxygenfile:: recipe.h
..
    Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

softmax.h
=========
.. doxygenfile:: softmax.h
..
    Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

swizzle.h
=========
.. doxygenfile:: swizzle.h
..
    Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

transformer_engine.h
====================
.. doxygenfile:: transformer_engine.h
..
    Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

transpose.h
===========
.. doxygenfile:: transpose.h
..
    Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

Common API
==========
.. autoapiclass:: transformer_engine.common.recipe.Format

.. autoapiclass:: transformer_engine.common.recipe.DelayedScaling(margin=0, fp8_format=Format.HYBRID, amax_history_len=1024, amax_compute_algo="max", scaling_factor_compute_algo=None)

.. autoapiclass:: transformer_engine.common.recipe.MXFP8BlockScaling(fp8_format=Format.E4M3)
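
For orientation, here is a minimal sketch of constructing a recipe and applying it through the PyTorch integration documented later in this guide. The module and input sizes are arbitrary, and running in FP8 assumes a GPU with FP8 support.

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common import recipe

    # Delayed-scaling FP8 recipe: Format.HYBRID uses E4M3 for forward
    # tensors and E5M2 for gradients in the backward pass.
    fp8_recipe = recipe.DelayedScaling(
        margin=0,
        fp8_format=recipe.Format.HYBRID,
        amax_history_len=1024,
        amax_compute_algo="max",
    )

    model = te.Linear(768, 768, bias=True).cuda()
    x = torch.randn(32, 768, device="cuda")

    # Every TE module called inside this context follows the recipe.
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        y = model(x)
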
..
    Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

Framework-specific API
======================
.. toctree::

   pytorch
   jax
..
    Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

JAX
===

Pre-defined Variables for Logical Axes
--------------------------------------

These variables are available in `transformer_engine.jax.sharding`; a usage sketch follows the list below.

* BATCH_AXES: The logical axis of the batch dimension. It is usually sharded along DP + FSDP on the mesh.
* SEQLEN_AXES: The logical axis of the sequence-length dimension. It is usually not sharded.
* SEQLEN_TP_AXES: The logical axis of the sequence-length dimension. It is usually sharded along TP on the mesh.
* HEAD_AXES: The logical axis of the head dimension of MHA. It is usually sharded along TP on the mesh.
* HIDDEN_AXES: The logical axis of the hidden dimension. It is usually not sharded.
* HIDDEN_TP_AXES: The logical axis of the hidden dimension. It is usually sharded along TP on the mesh.
* JOINED_AXES: The logical axis of an unspecified dimension. It is usually not sharded.
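
As referenced above, a minimal sketch of wiring these logical axes to a device mesh. The mesh-axis names ``data`` and ``model`` are this example's choice, and the ``dp_resource``/``tp_resource`` keyword names are assumed from the ``MeshResource`` definition below; treat the exact signatures as illustrative rather than authoritative.

.. code-block:: python

    import numpy as np
    import jax
    from jax.sharding import Mesh
    import transformer_engine.jax as te
    import transformer_engine.jax.flax as te_flax

    # Hypothetical 8-device layout: 2-way data parallel x 4-way tensor parallel.
    devices = np.asarray(jax.devices()).reshape(2, 4)
    with Mesh(devices, ("data", "model")):
        with te.fp8_autocast(
            enabled=True,
            mesh_resource=te.MeshResource(dp_resource="data", tp_resource="model"),
        ):
            # Extend a (here empty) base rule set with TE's rules, which map
            # logical axes such as BATCH_AXES -> 'data' and
            # HIDDEN_TP_AXES -> 'model' for Flax logical partitioning.
            axis_rules = te_flax.extend_logical_axis_rules(tuple())
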
Modules
-------

.. autoapiclass:: transformer_engine.jax.flax.TransformerLayerType

.. autoapiclass:: transformer_engine.jax.MeshResource()

.. autoapifunction:: transformer_engine.jax.fp8_autocast

.. autoapifunction:: transformer_engine.jax.update_collections

.. autoapiclass:: transformer_engine.jax.flax.LayerNorm(epsilon=1e-6, layernorm_type='layernorm', **kwargs)
  :members: __call__

.. autoapiclass:: transformer_engine.jax.flax.DenseGeneral(features, layernorm_type='layernorm', use_bias=False, **kwargs)
  :members: __call__

.. autoapiclass:: transformer_engine.jax.flax.LayerNormDenseGeneral(features, layernorm_type='layernorm', epsilon=1e-6, use_bias=False, **kwargs)
  :members: __call__

.. autoapiclass:: transformer_engine.jax.flax.LayerNormMLP(intermediate_dim=2048, layernorm_type='layernorm', epsilon=1e-6, use_bias=False, **kwargs)
  :members: __call__

.. autoapiclass:: transformer_engine.jax.flax.RelativePositionBiases(num_buckets, max_distance, num_heads, **kwargs)
  :members: __call__

.. autoapiclass:: transformer_engine.jax.flax.DotProductAttention(head_dim, num_heads, **kwargs)
  :members: __call__

.. autoapiclass:: transformer_engine.jax.flax.MultiHeadAttention(head_dim, num_heads, **kwargs)
  :members: __call__

.. autoapiclass:: transformer_engine.jax.flax.TransformerLayer(hidden_size=512, mlp_hidden_size=2048, num_attention_heads=8, **kwargs)
  :members: __call__

.. autoapifunction:: transformer_engine.jax.flax.extend_logical_axis_rules
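
The Flax modules above follow the standard ``init``/``apply`` pattern. A minimal sketch with ``DenseGeneral``; the shapes and dtype are arbitrary choices for illustration:

.. code-block:: python

    import jax
    import jax.numpy as jnp
    import transformer_engine.jax.flax as te_flax

    x = jnp.zeros((32, 128, 512), jnp.bfloat16)  # (batch, seqlen, hidden)

    # Project the hidden dimension from 512 to 1024 features.
    layer = te_flax.DenseGeneral(features=1024, use_bias=True)
    variables = layer.init(jax.random.PRNGKey(0), x)
    y = layer.apply(variables, x)  # y.shape == (32, 128, 1024)
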
..
    Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

PyTorch
=======
.. autoapiclass:: transformer_engine.pytorch.Linear(in_features, out_features, bias=True, **kwargs)
  :members: forward, set_tensor_parallel_group

.. autoapiclass:: transformer_engine.pytorch.GroupedLinear(in_features, out_features, bias=True, **kwargs)
  :members: forward, set_tensor_parallel_group

.. autoapiclass:: transformer_engine.pytorch.LayerNorm(hidden_size, eps=1e-5, **kwargs)

.. autoapiclass:: transformer_engine.pytorch.RMSNorm(hidden_size, eps=1e-5, **kwargs)

.. autoapiclass:: transformer_engine.pytorch.LayerNormLinear(in_features, out_features, eps=1e-5, bias=True, **kwargs)
  :members: forward, set_tensor_parallel_group

.. autoapiclass:: transformer_engine.pytorch.LayerNormMLP(hidden_size, ffn_hidden_size, eps=1e-5, bias=True, **kwargs)
  :members: forward, set_tensor_parallel_group

.. autoapiclass:: transformer_engine.pytorch.DotProductAttention(num_attention_heads, kv_channels, **kwargs)
  :members: forward, set_context_parallel_group

.. autoapiclass:: transformer_engine.pytorch.MultiheadAttention(hidden_size, num_attention_heads, **kwargs)
  :members: forward, set_context_parallel_group, set_tensor_parallel_group

.. autoapiclass:: transformer_engine.pytorch.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs)
  :members: forward, set_context_parallel_group, set_tensor_parallel_group

.. autoapiclass:: transformer_engine.pytorch.dot_product_attention.inference.InferenceParams(max_batch_size, max_sequence_length)

.. autoapiclass:: transformer_engine.pytorch.CudaRNGStatesTracker()
  :members: reset, get_states, set_states, add, fork

.. autoapifunction:: transformer_engine.pytorch.fp8_autocast

.. autoapifunction:: transformer_engine.pytorch.fp8_model_init

.. autoapifunction:: transformer_engine.pytorch.checkpoint

.. autoapifunction:: transformer_engine.pytorch.make_graphed_callables

.. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context

.. autoapifunction:: transformer_engine.pytorch.moe_permute

.. autoapifunction:: transformer_engine.pytorch.moe_permute_with_probs

.. autoapifunction:: transformer_engine.pytorch.moe_unpermute

.. autoapifunction:: transformer_engine.pytorch.moe_sort_chunks_by_index

.. autoapifunction:: transformer_engine.pytorch.parallel_cross_entropy

.. autoapifunction:: transformer_engine.pytorch.moe_sort_chunks_by_index_with_probs

.. autoapifunction:: transformer_engine.pytorch.initialize_ub

.. autoapifunction:: transformer_engine.pytorch.destroy_ub
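
A minimal sketch of running a single TE transformer block in FP8; the sizes are arbitrary, and FP8 execution assumes a GPU with FP8 support. Calling ``fp8_autocast`` without an explicit recipe falls back to the default recipe.

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te

    layer = te.TransformerLayer(
        hidden_size=1024,
        ffn_hidden_size=4096,
        num_attention_heads=16,
    ).cuda()

    # TE attention modules default to the (seq_len, batch, hidden) layout.
    x = torch.randn(128, 8, 1024, device="cuda", requires_grad=True)

    with te.fp8_autocast(enabled=True):
        y = layer(x)
    y.sum().backward()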