Unverified Commit df39a7c2 authored by Paweł Gadziński, committed by GitHub

Docs fix (#2301)



* init
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* line lengths
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* subtitle fix in many files
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* cross entropy _input -> input rename
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* cross entropy _input -> input rename
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* a lot of small fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* torch_version() change
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add missing module and fix warnings
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* removed trailing whitespace
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Update docs/api/pytorch.rst
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>

* Fix import
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix more imports
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix NumPy docstring parameter spacing and indentation

- Standardize parameter documentation to use the 'param : type' format (space before and after the colon) per the NumPy style guide
- Fix inconsistent indentation in cpu_offload.py docstring
- Modified 51 Python files across transformer_engine/pytorch

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent ca468ebe
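For reference, the NumPy-style convention this PR enforces puts a space on both sides of the colon in parameter entries (`name : type`), with descriptions indented underneath. A minimal hypothetical module (not part of Transformer Engine) showing the target format:

```python
def rmsnorm_reference(x, eps=1e-5):
    """Apply RMS normalization (illustrative pure-Python reference, not the TE kernel).

    Parameters
    ----------
    x : list of float
        Input values.
    eps : float, default = 1e-5
        A value added to the denominator for numerical stability.

    Returns
    -------
    y : list of float
        Normalized values.
    """
    # Root-mean-square of the input, with eps for numerical stability
    mean_sq = sum(v * v for v in x) / len(x)
    denom = (mean_sq + eps) ** 0.5
    return [v / denom for v in x]
```

Tools such as numpydoc and Sphinx's napoleon extension parse this `name : type` layout into structured parameter tables, which is why the spacing matters.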
@@ -42,13 +42,13 @@ class RMSNorm(BasicOperation):
     Parameters
     ----------
-    normalized_shape: int or iterable of int
+    normalized_shape : int or iterable of int
         Inner dimensions of input tensor
     eps : float, default = 1e-5
         A value added to the denominator for numerical stability
-    device: torch.device, default = default CUDA device
+    device : torch.device, default = default CUDA device
         Tensor device
-    dtype: torch.dtype, default = default dtype
+    dtype : torch.dtype, default = default dtype
         Tensor datatype
     zero_centered_gamma : bool, default = 'False'
         If `True`, the :math:`\gamma` parameter is initialized to zero
@@ -57,7 +57,7 @@ class RMSNorm(BasicOperation):
     .. math::
         y = \frac{x}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma)
-    sm_margin: int, default = 0
+    sm_margin : int, default = 0
         Number of SMs to exclude when launching CUDA kernels. This
         helps overlap with other kernels, e.g. communication kernels.
         For more fine-grained control, provide a dict with the SM

@@ -90,15 +90,15 @@ def fuse_backward_activation_bias(
     Parameters
     ----------
-    ops: list of tuples
+    ops : list of tuples
         Backward pass operations and the indices of the corresponding
         basic operations.
-    recipe: Recipe, optional
+    recipe : Recipe, optional
         Used quantization recipe
     Returns
     -------
-    ops: list of tuples
+    ops : list of tuples
         Updated backward pass operations
     """

@@ -87,13 +87,13 @@ def fuse_backward_add_rmsnorm(
     Parameters
     ----------
-    ops: list of tuples
+    ops : list of tuples
         Backward pass operations and the indices of the corresponding
         basic operations.
     Returns
     -------
-    ops: list of tuples
+    ops : list of tuples
         Updated backward pass operations
     """

@@ -119,13 +119,13 @@ def fuse_backward_linear_add(
     Parameters
     ----------
-    ops: list of tuples
+    ops : list of tuples
         Backward pass operations and the indices of the corresponding
         basic operations.
     Returns
     -------
-    ops: list of tuples
+    ops : list of tuples
         Updated backward pass operations
     """

@@ -119,13 +119,13 @@ def fuse_backward_linear_scale(
     Parameters
     ----------
-    ops: list of tuples
+    ops : list of tuples
         Backward pass operations and the indices of the corresponding
         basic operations.
     Returns
     -------
-    ops: list of tuples
+    ops : list of tuples
         Updated backward pass operations
     """

@@ -142,13 +142,13 @@ def fuse_forward_linear_bias_activation(
     Parameters
     ----------
-    ops: list of tuples
+    ops : list of tuples
         Forward pass operations and the indices of the corresponding
         basic operations.
     Returns
     -------
-    ops: list of tuples
+    ops : list of tuples
         Updated forward pass operations
     """

@@ -139,13 +139,13 @@ def fuse_forward_linear_bias_add(
     Parameters
     ----------
-    ops: list of tuples
+    ops : list of tuples
         Forward pass operations and the indices of the corresponding
         basic operations.
     Returns
     -------
-    ops: list of tuples
+    ops : list of tuples
         Updated forward pass operations
     """

@@ -118,13 +118,13 @@ def fuse_forward_linear_scale_add(
     Parameters
     ----------
-    ops: list of tuples
+    ops : list of tuples
         Forward pass operations and the indices of the corresponding
         basic operations.
     Returns
     -------
-    ops: list of tuples
+    ops : list of tuples
         Updated forward pass operations
     """

@@ -589,13 +589,13 @@ def fuse_userbuffers_backward_linear(
     Parameters
     ----------
-    ops: list of tuples
+    ops : list of tuples
         Backward pass operations and the indices of the corresponding
         basic operations.
     Returns
     -------
-    ops: list of tuples
+    ops : list of tuples
         Updated backward pass operations
     """

@@ -377,13 +377,13 @@ def fuse_userbuffers_forward_linear(
     Parameters
     ----------
-    ops: list of tuples
+    ops : list of tuples
         Forward pass operations and the indices of the corresponding
         basic operations.
     Returns
     -------
-    ops: list of tuples
+    ops : list of tuples
         Updated forward pass operations
     """

@@ -310,7 +310,7 @@ class OperationFuser:
     Parameters
     ----------
-    ops: list of FusibleOperation
+    ops : list of FusibleOperation
         Pipeline of operations
     """

@@ -27,29 +27,29 @@ class Linear(FusedOperation):
     Parameters
     ----------
-    in_features: int
+    in_features : int
         Inner dimension of input tensor
-    out_features: int
+    out_features : int
         Inner dimension of output tensor
-    bias: bool, default = `True`
+    bias : bool, default = `True`
         Apply additive bias
-    device: torch.device, default = default CUDA device
+    device : torch.device, default = default CUDA device
         Tensor device
-    dtype: torch.dtype, default = default dtype
+    dtype : torch.dtype, default = default dtype
         Tensor datatype
-    tensor_parallel_mode: {`None`, "column", "row"}, default = `None`
+    tensor_parallel_mode : {`None`, "column", "row"}, default = `None`
         Mode for tensor parallelism
-    tensor_parallel_group: torch.distributed.ProcessGroup, default = world group
+    tensor_parallel_group : torch.distributed.ProcessGroup, default = world group
         Process group for tensor parallelism
-    sequence_parallel: bool, default = `False`
+    sequence_parallel : bool, default = `False`
         Whether to apply sequence parallelism together with tensor
         parallelism, i.e. distributing input or output tensors along
         outer dimension (sequence or batch dim) when not distributing
         along inner dimension (embedding dim)
-    rng_state_tracker_function: callable
+    rng_state_tracker_function : callable
         Function that returns CudaRNGStatesTracker, which is used for
         model-parallel weight initialization
-    accumulate_into_main_grad: bool, default = `False`
+    accumulate_into_main_grad : bool, default = `False`
         Whether to directly accumulate weight gradients into the
         weight's `main_grad` attribute instead of relying on PyTorch
         autograd. The weight's `main_grad` must be set externally and

@@ -684,7 +684,7 @@ class FusedOperation(FusibleOperation):
     Parameters
     ----------
-    basic_ops: iterable of FusibleOperation
+    basic_ops : iterable of FusibleOperation
         Basic ops that are interchangeable with this op
     """

@@ -514,22 +514,22 @@ def moe_permute(
     Parameters
     ----------
-    inp: torch.Tensor
+    inp : torch.Tensor
         Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
-    routing_map: torch.Tensor
+    routing_map : torch.Tensor
         The token to expert mapping tensor.
         If map_type is 'mask', routing_map is of shape [num_tokens, num_experts] and dtype 'int32'.
         The values in it: 1 means the token is routed to this expert and 0 means not.
         If map_type is 'index', routing_map is of shape [num_tokens, topK] and dtype 'int32'.
         The values in it are the routed expert indices.
-    num_out_tokens: int, default = -1
+    num_out_tokens : int, default = -1
         The effective output token count, representing the number of tokens not dropped.
         By default, set to '-1', meaning no tokens are dropped.
-    max_token_num: int, default = -1
+    max_token_num : int, default = -1
         The maximum number of tokens, used for workspace allocation.
         By default, set to '-1', meaning the calculation of the size of workspace is
         automatically taken over by the operator.
-    map_type: str, default = 'mask'
+    map_type : str, default = 'mask'
         Type of the routing map tensor.
         Options are: 'mask', 'index'.
         Refer to `routing_map` for more details.
@@ -556,16 +556,16 @@ def moe_permute_with_probs(
     Parameters
     ----------
-    inp: torch.Tensor
+    inp : torch.Tensor
         Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
-    probs: torch.Tensor
+    probs : torch.Tensor
         The tensor of probabilities corresponding to the permuted tokens and is
         of shape [num_tokens, num_experts]. It will be permuted with the tokens
         according to the routing_map.
-    routing_map: torch.Tensor
+    routing_map : torch.Tensor
         The token to expert mapping tensor of shape [num_tokens, num_experts] and dtype 'int32'.
         The values in it: 1 means the token is routed to this expert and 0 means not.
-    num_out_tokens: int, default = -1
+    num_out_tokens : int, default = -1
         The effective output token count, representing the number of tokens not dropped.
         By default, set to '-1', meaning no tokens are dropped.
     """
@@ -589,21 +589,21 @@ def moe_unpermute(
     Parameters
     ----------
-    inp: torch.Tensor
+    inp : torch.Tensor
         Input tensor with permuted tokens of shape `[num_tokens, hidden_size]` to be unpermuted.
-    row_id_map: torch.Tensor
+    row_id_map : torch.Tensor
         The tensor of a mapping table for sorted indices used to unpermute the tokens,
         which is the second output tensor of `Permute`.
-    merging_probs: torch.Tensor, default = None
+    merging_probs : torch.Tensor, default = None
         The tensor of probabilities corresponding to the permuted tokens. If provided,
         the unpermuted tokens will be merged with their respective probabilities.
         By default, set to an empty tensor, which means that the tokens are directly merged by accumulation.
-    restore_shape: torch.Size, default = None
+    restore_shape : torch.Size, default = None
         The output shape after the unpermute operation.
-    map_type: str, default = 'mask'
+    map_type : str, default = 'mask'
         Type of the routing map tensor. Should be the same as the value passed to moe_permute.
         Options are: 'mask', 'index'.
-    probs: torch.Tensor, default = None
+    probs : torch.Tensor, default = None
         Renamed to merging_probs. Keep for backward compatibility.
     """
     if probs is not None:
@@ -733,11 +733,11 @@ def moe_sort_chunks_by_index(
     Parameters
     ----------
-    inp: torch.Tensor
+    inp : torch.Tensor
         Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
-    split_sizes: torch.Tensor
+    split_sizes : torch.Tensor
         Chunk sizes of the inp tensor along the 0-th dimension.
-    sorted_indices: torch.Tensor
+    sorted_indices : torch.Tensor
         Chunk indices used to permute the chunks.
     """
     output, _ = _moe_chunk_sort.apply(inp, split_sizes, sorted_index, None)
@@ -757,15 +757,15 @@ def moe_sort_chunks_by_index_with_probs(
     Parameters
     ----------
-    inp: torch.Tensor
+    inp : torch.Tensor
         Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
-    probs: torch.Tensor
+    probs : torch.Tensor
         The tensor of probabilities corresponding to the permuted tokens and is
         of shape [num_tokens]. It will be permuted with the tokens according to
         the split_sizes and sorted_indices.
-    split_sizes: torch.Tensor
+    split_sizes : torch.Tensor
         Chunk sizes of the inp tensor along the 0-th dimension.
-    sorted_indices: torch.Tensor
+    sorted_indices : torch.Tensor
         Chunk indices used to permute the chunks.
     """
     output, permuted_probs = _moe_chunk_sort.apply(inp, split_sizes, sorted_index, probs)
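As an aside to the MoE hunks above: the 'mask'-type `routing_map` semantics documented there (a `[num_tokens, num_experts]` tensor where 1 means the token is routed to that expert) can be illustrated with a plain-Python sketch. This is not the Transformer Engine kernel; the function names and grouping order are illustrative assumptions:

```python
def permute_by_mask(tokens, routing_map):
    """Group token copies expert-by-expert, returning the permuted list and a
    row-id map recording each copy's source row (mirroring moe_permute's
    second output)."""
    permuted, row_id_map = [], []
    num_experts = len(routing_map[0])
    for e in range(num_experts):              # visit experts in order
        for t, row in enumerate(routing_map):
            if row[e] == 1:                   # token t is routed to expert e
                permuted.append(tokens[t])
                row_id_map.append(t)
    return permuted, row_id_map


def unpermute_by_map(permuted, row_id_map, num_tokens):
    """Scatter copies back to their source rows, merging by accumulation
    (the documented default when no merging_probs is given)."""
    out = [0.0] * num_tokens
    for value, t in zip(permuted, row_id_map):
        out[t] += value
    return out
```

For example, with 3 tokens and 2 experts, a token routed to both experts appears twice in the permuted output and its copies are summed back together on unpermute.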
@@ -26,8 +26,8 @@ from transformer_engine.common.recipe import (
     NVFP4BlockScaling,
     CustomRecipe,
 )
 from .constants import dist_group_type
 from .utils import get_device_compute_capability
 from .jit import jit_fuser
@@ -678,7 +678,7 @@ def fp8_model_init(
     .. warning::
         fp8_model_init is deprecated and will be removed in a future release. Use
-        quantized_model_init(enabled=..., recipe=..., preserve_high_precision_init_val=...) instead.
+        ``quantized_model_init(enabled=..., recipe=..., preserve_high_precision_init_val=...)`` instead.
     """
@@ -723,7 +723,7 @@ def quantized_model_init(
     Parameters
     ----------
-    enabled: bool, default = `True`
+    enabled : bool, default = True
         when enabled, Transformer Engine modules created inside this `quantized_model_init`
         region will hold only quantized copies of its parameters, as opposed to the default
         behavior where both higher precision and quantized copies are present. Setting this
@@ -734,9 +734,9 @@ def quantized_model_init(
         precision copies of weights are already present in the optimizer.
         * inference, where only the quantized copies of the parameters are used.
         * LoRA-like fine-tuning, where the main parameters of the model do not change.
-    recipe: transformer_engine.common.recipe.Recipe, default = `None`
+    recipe : transformer_engine.common.recipe.Recipe, default = None
         Recipe used to create the parameters. If left to None, it uses the default recipe.
-    preserve_high_precision_init_val: bool, default = `False`
+    preserve_high_precision_init_val : bool, default = False
         when enabled, store the high precision tensor used to initialize quantized parameters
         in CPU memory, and add two function attributes named `get_high_precision_init_val()`
         and `clear_high_precision_init_val()` to quantized parameters to get/clear this high
@@ -773,8 +773,8 @@ def fp8_autocast(
     """
     .. warning::
-        fp8_autocast is deprecated and will be removed in a future release.
-        Use autocast(enabled=..., calibrating=..., recipe=..., group=..., _graph=...) instead.
+        ``fp8_autocast`` is deprecated and will be removed in a future release.
+        Use ``autocast(enabled=..., calibrating=..., recipe=..., group=..., _graph=...)`` instead.
     """
@@ -828,16 +828,16 @@ def autocast(
     Parameters
     ----------
-    enabled: bool, default = `True`
+    enabled : bool, default = True
         whether or not to enable low precision quantization (FP8/FP4).
-    calibrating: bool, default = `False`
+    calibrating : bool, default = False
         calibration mode allows collecting statistics such as amax and scale
         data of quantized tensors even when executing without quantization enabled.
         This is useful for saving an inference ready checkpoint while training
         using a higher precision.
-    recipe: recipe.Recipe, default = `None`
+    recipe : recipe.Recipe, default = None
         recipe used for low precision quantization.
-    amax_reduction_group: torch._C._distributed_c10d.ProcessGroup, default = `None`
+    amax_reduction_group : torch._C._distributed_c10d.ProcessGroup, default = None
         distributed group over which amaxes for the quantized tensors
         are reduced at the end of each training step.
     """

@@ -27,7 +27,7 @@ _quantized_tensor_cpu_supported_ops = (
 class QuantizedTensorStorage:
-    r"""Base class for all *TensorStorage classes.
+    r"""Base class for all TensorStorage classes.
     This class (and its subclasses) are optimization for when
     the full QuantizedTensor is not needed (when it is fully
@@ -54,11 +54,11 @@ class QuantizedTensorStorage:
     Parameters
     ----------
-    rowwise_usage : Optional[bool[, default = `None`
+    rowwise_usage : Optional[bool[, default = None
         Whether to create or keep the data needed for using the tensor
         in rowwise fashion (e.g. as B argument in TN GEMM). Leaving it as `None`
         preserves the original value in the tensor.
-    columnwise_usage : Optional[bool], default = `None`
+    columnwise_usage : Optional[bool], default = None
         Whether to create or keep the data needed for using the tensor
         in columnwise fashion (e.g. as A argument in TN GEMM). Leaving it as
         `None` preserves the original value in the tensor.
@@ -128,7 +128,7 @@ def prepare_for_saving(
 ]:
     """Prepare tensors for saving. Needed because save_for_backward accepts only
     torch.Tensor/torch.nn.Parameter types, while we want to be able to save
-    the internal *TensorStorage types too."""
+    the internal TensorStorage types too."""
     tensor_list, tensor_objects_list = [], []
     for tensor in tensors:

@@ -92,24 +92,24 @@ def fused_topk_with_score_function(
     Fused topk with score function router.
     Parameters
     ----------
-    logits: torch.Tensor
+    logits : torch.Tensor
-    topk: int
+    topk : int
-    use_pre_softmax: bool
+    use_pre_softmax : bool
         if enabled, the computation order: softmax -> topk
-    num_groups: int
+    num_groups : int
         used in the group topk
-    group_topk: int
+    group_topk : int
         used in the group topk
-    scaling_factor: float
+    scaling_factor : float
-    score_function: str
+    score_function : str
         currently only support softmax and sigmoid
-    expert_bias: torch.Tensor
+    expert_bias : torch.Tensor
         could be used in the sigmoid
     Returns
     -------
-    probs: torch.Tensor
+    probs : torch.Tensor
-    routing_map: torch.Tensor
+    routing_map : torch.Tensor
     """
     if logits.dtype == torch.float64:
         raise ValueError("Current TE does not support float64 router type")
@@ -186,15 +186,15 @@ def fused_compute_score_for_moe_aux_loss(
     Fused compute scores for MoE aux loss, subset of the fused_topk_with_score_function.
     Parameters
     ----------
-    logits: torch.Tensor
+    logits : torch.Tensor
-    topk: int
+    topk : int
-    score_function: str
+    score_function : str
         currently only support softmax and sigmoid
     Returns
     -------
-    routing_map: torch.Tensor
+    routing_map : torch.Tensor
-    scores: torch.Tensor
+    scores : torch.Tensor
     """
     return FusedComputeScoresForMoEAuxLoss.apply(logits, topk, score_function)
@@ -258,18 +258,18 @@ def fused_moe_aux_loss(
     Fused MoE aux loss.
     Parameters
     ----------
-    probs: torch.Tensor
+    probs : torch.Tensor
-    tokens_per_expert: torch.Tensor
+    tokens_per_expert : torch.Tensor
         the number of tokens per expert
-    total_num_tokens: int
+    total_num_tokens : int
         the total number of tokens, involved in the aux loss calculation
-    num_experts: int
+    num_experts : int
-    topk: int
+    topk : int
-    coeff: float
+    coeff : float
         the coefficient of the aux loss
     Returns
     -------
-    aux_loss: torch.scalar
+    aux_loss : torch.scalar
     """
     return FusedAuxLoss.apply(probs, tokens_per_expert, total_num_tokens, num_experts, topk, coeff)

@@ -307,18 +307,18 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor):
     Parameters
     ----------
-    rowwise_data: torch.Tensor
+    rowwise_data : torch.Tensor
         FP8 data in a uint8 tensor matching shape of dequantized tensor.
-    rowwise_scale_inv: torch.Tensor
+    rowwise_scale_inv : torch.Tensor
         FP32 dequantization scales in GEMM format for dequantizing rowwise_data.
-    columnwise_data: Optional[torch.Tensor]
+    columnwise_data : Optional[torch.Tensor]
         FP8 data in a uint8 tensor matching shape of dequantized tensor transpose.
-    columnwise_scale_inv: Optional[torch.Tensor]
+    columnwise_scale_inv : Optional[torch.Tensor]
         FP32 dequantization scales in GEMM format for dequantizing columnwise_data.
-    fp8_dtype: transformer_engine_torch.DType, default = kFloat8E4M3
+    fp8_dtype : transformer_engine_torch.DType, default = kFloat8E4M3
         FP8 format.
-    quantizer: Quantizer - the Float8BlockQuantizer that quantized this tensor and
+    quantizer : Quantizer - the Float8BlockQuantizer that quantized this tensor and
         holds configuration about quantization and dequantization modes.
     """

@@ -453,23 +453,23 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
     Parameters
     ----------
-    shape: int or iterable of int
+    shape : int or iterable of int
         Tensor dimensions.
-    dtype: torch.dtype
+    dtype : torch.dtype
         Nominal tensor datatype.
-    requires_grad: bool, optional = False
+    requires_grad : bool, optional = False
         Whether to compute gradients for this tensor.
-    data: torch.Tensor
+    data : torch.Tensor
         Raw FP8 data in a uint8 tensor
-    fp8_scale_inv: torch.Tensor
+    fp8_scale_inv : torch.Tensor
         Reciprocal of the scaling factor applied when casting to FP8,
         i.e. the scaling factor that must be applied when casting from
         FP8 to higher precision.
-    fp8_dtype: transformer_engine_torch.DType
+    fp8_dtype : transformer_engine_torch.DType
         FP8 format.
-    data_transpose: torch.Tensor, optional
+    data_transpose : torch.Tensor, optional
         FP8 transpose data in a uint8 tensor
-    quantizer: Float8Quantizer, Float8CurrentScalingQuantizer, optional
+    quantizer : Float8Quantizer, Float8CurrentScalingQuantizer, optional
         Builder class for FP8 tensors
     """

@@ -204,16 +204,16 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
     Parameters
     ----------
-    data: torch.Tensor
+    data : torch.Tensor
         Raw FP8 data in a uint8 tensor
-    fp8_dtype: transformer_engine_torch.DType, default = kFloat8E4M3
+    fp8_dtype : transformer_engine_torch.DType, default = kFloat8E4M3
         FP8 format.
-    fp8_scale_inv: torch.Tensor
+    fp8_scale_inv : torch.Tensor
         Reciprocal of the scaling factor applied when
         casting to FP8, i.e. the scaling factor that must
         be applied when casting from FP8 to higher
         precision.
-    dtype: torch.dtype, default = torch.float32
+    dtype : torch.dtype, default = torch.float32
         Nominal tensor datatype.
     """
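The bulk of the diff is the same mechanical rewrite repeated across 51 files: insert a space before the colon on docstring parameter lines. A hypothetical sketch of how such a pass could be scripted (this is an illustration, not the tooling actually used in the PR):

```python
import re

# Match an indented "name: annotation" parameter line and capture the indent
# and the name; the lookahead ensures a type annotation follows the colon.
# Ordinary prose containing colons mid-sentence is left untouched.
_PARAM_LINE = re.compile(r"^(\s{4,})(\*{0,2}[A-Za-z_]\w*)\s*:\s+(?=\S)")


def normalize_param_line(line: str) -> str:
    """Rewrite 'name: type' as 'name : type' on a docstring parameter line.

    Lines that already use the NumPy-style spacing come back unchanged,
    so the pass is idempotent.
    """
    return _PARAM_LINE.sub(r"\1\2 : ", line)
```

Applying `normalize_param_line` to every line of a docstring reproduces the kind of edit seen in the hunks above, e.g. `inp: torch.Tensor` becomes `inp : torch.Tensor`.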