Unverified Commit df39a7c2 authored by Paweł Gadziński, committed by GitHub

Docs fix (#2301)



* init
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* line lengths
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* subtitle --- fix in many files
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* cross entropy _input -> input rename
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* cross entropy _input -> input rename
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* a lot of small fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* torch_version() change
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add missing module and fix warnings
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* removed trailing whitespace
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Update docs/api/pytorch.rst
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>

* Fix import
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix more imports
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix NumPy docstring parameter spacing and indentation

- Standardize parameter documentation to use the 'param : type' format (space before and after the colon) per the NumPy style guide, as illustrated below
- Fix inconsistent indentation in the cpu_offload.py docstring
- Modified 51 Python files across transformer_engine/pytorch

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
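For illustration, a hypothetical docstring in the convention this commit enforces (the function and parameter names below are made up, not from the changed files):

    def scale_tensor(x, factor=1.0):
        """Scale a tensor.

        Parameters
        ----------
        factor : float, default = 1.0
            Multiplier applied to ``x``. Before this commit the line above
            would have read ``factor: float, default = 1.0``, with no space
            before the colon.
        """
        return x * factor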

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent ca468ebe
@@ -382,26 +382,26 @@ class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
Parameters
----------
-rowwise_data: torch.Tensor
+rowwise_data : torch.Tensor
Raw FP4 data in a uint8 tensor (rowwise layout).
-rowwise_scale_inv: torch.Tensor
+rowwise_scale_inv : torch.Tensor
Reciprocal of the scaling factor applied when
casting to FP4, i.e. the scaling factor that must
be applied when casting from FP4 to higher
precision (rowwise).
-columnwise_data: torch.Tensor, optional
+columnwise_data : torch.Tensor, optional
Raw FP4 data in a uint8 tensor (columnwise layout).
-columnwise_scale_inv: torch.Tensor, optional
+columnwise_scale_inv : torch.Tensor, optional
Reciprocal of the scaling factor for columnwise FP4 data.
-amax_rowwise: torch.Tensor, optional
+amax_rowwise : torch.Tensor, optional
Rowwise amax tracking tensor.
-amax_columnwise: torch.Tensor, optional
+amax_columnwise : torch.Tensor, optional
Columnwise amax tracking tensor.
-fp4_dtype: TE_DType
+fp4_dtype : TE_DType
The FP4 data type used for quantization.
-quantizer: Quantizer
+quantizer : Quantizer
The quantizer instance used for this tensor.
-dtype: torch.dtype, default = torch.float32
+dtype : torch.dtype, default = torch.float32
Nominal tensor datatype, used in dequantize.
"""
...
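As an aside, a toy sketch of the scale_inv bookkeeping this docstring describes. Real NVFP4 packs two FP4 values per uint8 byte and uses hardware-specific block shapes, so the block size of 16 and the E2M1 maximum magnitude of 6.0 below are only illustrative assumptions:

    import torch

    x = torch.randn(4, 32)
    block = 16                                          # assumed block size
    xb = x.reshape(4, -1, block)
    scale = 6.0 / xb.abs().amax(dim=-1, keepdim=True)   # 6.0 ~ max |FP4 E2M1| value
    q = xb * scale                                      # would be cast to FP4 here
    scale_inv = 1.0 / scale                             # stored as *_scale_inv
    x_deq = (q * scale_inv).reshape(4, 32)              # dequantize applies scale_inv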
@@ -74,7 +74,7 @@ def cast_master_weights_to_fp8(
fsdp_shard_model_weights : list of FSDP shard model weights. If None, it means that the model weights are
not sharded. Otherwise, it means that the model weights are sharded and we get
target model weights data storage using the FSDP shard model weights.
-manual_post_all_gather_processing: bool, default = `False`.
+manual_post_all_gather_processing : bool, default = `False`.
If False, post processing will be automatically triggered during next forward.
If True, the timing of calling post_all_gather_processing is left to the user.
Note that users must call `post_all_gather_processing` if it's set to True,
...
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""PyTorch version utilities"""
from __future__ import annotations

import functools

import torch
from packaging.version import Version as PkgVersion


@functools.lru_cache(maxsize=None)
def torch_version() -> tuple[int, ...]:
    """Get PyTorch version"""
    return PkgVersion(str(torch.__version__)).release
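A minimal usage sketch of this new helper (the import path follows the updated import in the next hunk; the version threshold is only an example):

    from transformer_engine.pytorch.torch_version import torch_version

    # torch_version() returns the release tuple of the installed PyTorch,
    # e.g. (2, 4, 1), and caches it via lru_cache, so plain tuple
    # comparison gives cheap version gating:
    if torch_version() >= (2, 0):
        print("Running on PyTorch 2.x:", torch_version())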
@@ -10,7 +10,7 @@ from typing import Callable, List, Optional, Tuple, Union
import torch

-from transformer_engine.pytorch import torch_version
+from transformer_engine.pytorch.torch_version import torch_version
from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm
from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
@@ -75,8 +75,8 @@ class TransformerLayer(torch.nn.Module):
.. note::
-Argument :attr:`attention_mask` in the `forward` call is only used when
-:attr:`self_attn_mask_type` includes `"padding"` or `"arbitrary"`.
+Argument :attr:`attention_mask` in the :meth:`forward` call is only used when
+:attr:`self_attn_mask_type` includes ``"padding"`` or ``"arbitrary"``.

Parameters
----------
@@ -86,76 +86,76 @@ class TransformerLayer(torch.nn.Module):
intermediate size to which input samples are projected.
num_attention_heads : int
number of attention heads in the transformer layer.
-num_gqa_groups : int, default = `None`
+num_gqa_groups : int, default = None
number of GQA groups in the transformer layer.
Grouped Query Attention is described in
`this paper <https://arxiv.org/pdf/2305.13245.pdf>`_.
This only affects the keys and values, not the querys.
GQA-1 is equivalent to Multi-Query Attention
(`MQA <https://arxiv.org/pdf/1911.02150.pdf>`_), while GQA-H
-is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`.
+is equivalent to MHA, i.e. ``num_gqa_groups = num_attention_heads``.
layernorm_epsilon : float, default = 1e-5
a value added to the denominator of layer normalization
for numerical stability.
-hidden_dropout: float, default = 0.1
+hidden_dropout : float, default = 0.1
dropout probability for the dropout op after FC2 layer.
-attention_dropout: float, default = 0.1
+attention_dropout : float, default = 0.1
dropout probability for the dropout op during multi-head attention.
-init_method : Callable, default = `None`
+init_method : Callable, default = None
used for initializing weights of QKV and FC1 weights in the following way:
-`init_method(weight)`. When set to `None`, defaults to
-`torch.nn.init.normal_(mean=0.0, std=0.023)`.
-output_layer_init_method : Callable, default = `None`
+``init_method(weight)``. When set to ``None``, defaults to
+``torch.nn.init.normal_(mean=0.0, std=0.023)``.
+output_layer_init_method : Callable, default = None
used for initializing weights of PROJ and FC2 in the following way:
-`output_layer_init_method(weight)`. When set to `None`, defaults to
-`torch.nn.init.normal_(mean=0.0, std=0.023)`.
-apply_residual_connection_post_layernorm : bool, default = `False`
-if set to `True`, residual connections are taken
+``output_layer_init_method(weight)``. When set to ``None``, defaults to
+``torch.nn.init.normal_(mean=0.0, std=0.023)``.
+apply_residual_connection_post_layernorm : bool, default = False
+if set to ``True``, residual connections are taken
from the output of layer norm (default is taken
from input of layer norm)
-layer_number: int, default = `None`
-layer number of the current `TransformerLayer` when multiple such modules are
+layer_number : int, default = None
+layer number of the current :class:`TransformerLayer` when multiple such modules are
concatenated to form a transformer block.
-output_layernorm: bool, default = `False`
-if set to `True`, layer normalization is applied on the output side,
+output_layernorm : bool, default = False
+if set to ``True``, layer normalization is applied on the output side,
after the final dropout-add. default behavior is to apply layer
normalization on the input side, before the QKV transformation.
-parallel_attention_mlp: bool, default = `False`
-if set to `True`, self-attention and feedforward network are computed
+parallel_attention_mlp : bool, default = False
+if set to ``True``, self-attention and feedforward network are computed
based on the same input (in parallel) instead of sequentially.
Both blocks have an independent normalization.
This architecture is used in `Falcon` models.
-layer_type: {'encoder', 'decoder'}, default = `encoder`
-if set to `decoder`, an additional cross-attn block is added after self-attn.
+layer_type : {'encoder', 'decoder'}, default = "encoder"
+if set to ``"decoder"``, an additional cross-attn block is added after self-attn.
This can be used for structures like `T5` Transformer in conjunction with the
-`encoder` option.
-kv_channels: int, default = `None`
+``"encoder"`` option.
+kv_channels : int, default = None
number of query-key-value channels per attention head. defaults to
-:attr:`hidden_size` / :attr:`num_attention_heads` if `None`.
-self_attn_mask_type: {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
+:attr:`hidden_size` / :attr:`num_attention_heads` if ``None``.
+self_attn_mask_type : {'no_mask', 'padding', 'causal', 'padding_causal', 'causal_bottom_right',
'padding_causal_bottom_right', 'arbitrary'},
-default = `causal`
+default = "causal"
type of attention mask passed into softmax operation for encoder.
-Overridden by :attr:`self_attn_mask_type` in the `forward` method.
-The forward arg is useful for dynamically changing mask types, e.g.
-a different mask for training and inference. The init arg is useful
+Overridden by :attr:`self_attn_mask_type` in the :meth:`forward` method.
+The :meth:`forward` arg is useful for dynamically changing mask types, e.g.
+a different mask for training and inference. The :meth:`__init__` arg is useful
for cases involving compilation/tracing, e.g. ONNX export.
-window_size: Optional[Tuple[int, int]], default = `None`
+window_size : Optional[Tuple[int, int]], default = None
sliding window size for local attention in encoder, where query at position i
-attends to keys in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k
-- seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean
-no sliding window and causal mask specifically. Both `causal` and
-`causal_bottom_right` masks map to `window_size = (-1, 0)` and Transformer Engine
-distinguishes them based on `self_attn_mask_type` or `enc_dec_attn_mask_type`.
-Similar to :attr:`self_attn_mask_type`, `window_size` can be overridden by
-:attr:`window_size` in `forward` as well.
-enc_dec_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
-default = `no_mask`
+attends to keys in ``[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k
+- seqlen_q + window_size[1]]`` inclusive. Special cases ``(-1, -1)`` and ``(-1, 0)`` mean
+no sliding window and causal mask specifically. Both ``"causal"`` and
+``"causal_bottom_right"`` masks map to :attr:`window_size` = ``(-1, 0)`` and Transformer Engine
+distinguishes them based on :attr:`self_attn_mask_type` or :attr:`enc_dec_attn_mask_type`.
+Similar to :attr:`self_attn_mask_type`, :attr:`window_size` can be overridden by
+:attr:`window_size` in :meth:`forward` as well.
+enc_dec_attn_mask_type : {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
+default = "no_mask"
type of attention mask passed into softmax operation for decoder.
-enc_dec_window_size: Optional[Tuple[int, int]], default = `None`
+enc_dec_window_size : Optional[Tuple[int, int]], default = None
sliding window size for local attention in decoder.
-zero_centered_gamma : bool, default = 'False'
-if set to 'True', gamma parameter in LayerNorm is initialized to 0 and
+zero_centered_gamma : bool, default = False
+if set to ``True``, gamma parameter in LayerNorm is initialized to 0 and
the LayerNorm formula changes to

.. math::
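To make the window_size semantics in the hunk above concrete, here is a small self-contained sketch that builds the corresponding boolean mask (True means masked out, matching the attention_mask convention documented later in this diff). It is an editor's reference illustration, not Transformer Engine's internal implementation:

    import torch

    def sliding_window_mask(seqlen_q, seqlen_k, window_size):
        left, right = window_size
        i = torch.arange(seqlen_q).unsqueeze(1)   # query positions
        j = torch.arange(seqlen_k).unsqueeze(0)   # key positions
        offset = seqlen_k - seqlen_q              # bottom-right alignment shift
        allowed = torch.ones(seqlen_q, seqlen_k, dtype=torch.bool)
        if left >= 0:                             # -1 means unbounded on that side
            allowed &= j >= i + offset - left
        if right >= 0:
            allowed &= j <= i + offset + right
        return ~allowed                           # True = masked out

    print(sliding_window_mask(4, 4, (-1, 0)))     # reproduces a causal mask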
@@ -163,111 +163,126 @@ class TransformerLayer(torch.nn.Module):
(1 + \gamma) + \beta
normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
type of normalization applied.
-qkv_weight_interleaved : bool, default = `True`
-if set to `False`, the QKV weight is interpreted as a concatenation of
-query, key, and value weights along the `0th` dimension. The default
-interpretation is that the individual `q`, `k`, and `v` weights for each
-attention head are interleaved. This parameter is set to `False` when
+qkv_weight_interleaved : bool, default = True
+if set to ``False``, the QKV weight is interpreted as a concatenation of
+query, key, and value weights along the ``0th`` dimension. The default
+interpretation is that the individual ``q``, ``k``, and ``v`` weights for each
+attention head are interleaved. This parameter is set to ``False`` when
using :attr:`fuse_qkv_params=False`.
-rotary_pos_interleaved : bool, default = `False`
+rotary_pos_interleaved : bool, default = False
whether to use interleaved rotary position embeddings.
-bias : bool, default = `True`
-if set to `False`, the transformer layer will not learn any additive biases.
+bias : bool, default = True
+if set to ``False``, the transformer layer will not learn any additive biases.
activation : str, default = 'gelu'
Type of activation used in MLP block.
-Options are: 'gelu', 'geglu', 'qgelu', 'qgeglu', 'relu', 'reglu', 'srelu', 'sreglu',
-'silu', 'swiglu', and 'clamped_swiglu'.
-activation_params : Optional[dict], default = `None`
+Options are: ``'gelu'``, ``'geglu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, ``'srelu'``, ``'sreglu'``,
+``'silu'``, ``'swiglu'``, and ``'clamped_swiglu'``.
+activation_params : Optional[dict], default = None
Additional parameters for the activation function.
-At the moment, only used for 'clamped_swiglu' activation which
-supports 'limit' and 'alpha' parameters. You can set these as
-`activation_params={'limit': 7.0, 'alpha': 1.702}`.
+At the moment, only used for ``'clamped_swiglu'`` activation which
+supports ``'limit'`` and ``'alpha'`` parameters. You can set these as
+``activation_params={'limit': 7.0, 'alpha': 1.702}``.
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
-attn_input_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
+attn_input_format : {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
This controls whether the dimensions of the
-intermediate hidden states is 'sequence first' ('sbhd'), 'batch first' ('bshd'),
-or 'token first' ('thd'). `s` stands for the sequence length, `b` batch size,
-`t` the total number of tokens, `h` the number of heads, `d` head size.
+intermediate hidden states is 'sequence first' (``'sbhd'``), 'batch first' (``'bshd'``),
+or 'token first' (``'thd'``). ``s`` stands for the sequence length, ``b`` batch size,
+``t`` the total number of tokens, ``h`` the number of heads, ``d`` head size.
Note that these formats are very closely
-related to the `qkv_format` in the `MultiHeadAttention`
-and `DotProductAttention` modules.
-name: str, default = `None`
+related to the :attr:`qkv_format` parameter in the :class:`MultiHeadAttention`
+and :class:`DotProductAttention` modules.
+name : str, default = None
name of the module, currently used for debugging purposes.
-softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
-softmax type as described in this paper:
+softmax_type : str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla'
+Softmax type as described in the paper
`Efficient Streaming Language Models with Attention Sinks
<https://arxiv.org/pdf/2309.17453v3>`_.
-For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv],
-'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
-'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
-'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
-where alpha is a learnable parameter in shape [h].
-'off-by-one' and 'learnable' softmax types are also called sink attention
-('zero sink' and 'learnable sink').
+
+For a given attention score :math:`S = Q \cdot K^T`, of shape ``[b, h, s_q, s_kv]``:
+
+* ``'vanilla'``:
+
+  .. math::
+     Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{\sum_j \exp(S_{:,:,:,j})}
+
+* ``'off-by-one'``:
+
+  .. math::
+     Softmax(S)_{:,:,:,i} = \frac{\exp(S_{:,:,:,i})}{1 + \sum_j \exp(S_{:,:,:,j})}
+
+* ``'learnable'``:
+
+  .. math::
+     Softmax(S)_{:,h,:,i} = \frac{\exp(S_{:,h,:,i})}{\exp(\alpha_h) + \sum_j \exp(S_{:,h,:,j})}
+
+  where :math:`\\alpha` is a learnable parameter of shape ``[h]``.
+
+``'off-by-one'`` and ``'learnable'`` softmax types are also called sink attention
+(``'zero sink'`` and ``'learnable sink'``).

Parallelism parameters
----------------------
-set_parallel_mode : bool, default = `False`
-if set to `True`, QKV and FC1 layers are used as Column Parallel
+set_parallel_mode : bool, default = False
+if set to ``True``, QKV and FC1 layers are used as Column Parallel
whereas PROJ and FC2 is used as Row Parallel as described
`here <https://arxiv.org/pdf/1909.08053.pdf>`_.
-sequence_parallel : bool, default = `False`
-if set to `True`, uses sequence parallelism.
-tp_group : ProcessGroup, default = `None`
+sequence_parallel : bool, default = False
+if set to ``True``, uses sequence parallelism.
+tp_group : ProcessGroup, default = None
tensor parallel process group.
tp_size : int, default = 1
used as TP (tensor parallel) world size when TP groups are not formed during
initialization. In this case, users must call the
-`set_tensor_parallel_group(tp_group)` method on the initialized module before the
+:meth:`set_tensor_parallel_group` method on the initialized module before the
forward pass to supply the tensor parallel group needed for tensor and sequence
parallel collectives.

Optimization parameters
-----------------------
-fuse_wgrad_accumulation : bool, default = 'False'
-if set to `True`, enables fusing of creation and accumulation of
+fuse_wgrad_accumulation : bool, default = False
+if set to ``True``, enables fusing of creation and accumulation of
the weight gradient. When enabled, it is assumed that the weights
-have an additional `main_grad` attribute (used instead of the
-regular `grad`) which is a pre-allocated buffer of the correct
+have an additional :attr:`main_grad` attribute (used instead of the
+regular :attr:`grad`) which is a pre-allocated buffer of the correct
size to accumulate gradients in.
-params_dtype : torch.dtype, default = `torch.get_default_dtype()`
+params_dtype : torch.dtype, default = torch.get_default_dtype()
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
-seq_length: int
+seq_length : int
sequence length of input samples. Needed for JIT Warmup, a technique where jit
fused functions are warmed up before training to ensure same kernels are used for
forward propogation and activation recompute phase.
-micro_batch_size: int
+micro_batch_size : int
batch size per training step. Needed for JIT Warmup, a technique where jit
fused functions are warmed up before training to ensure same kernels are
used for forward propogation and activation recompute phase.
-drop_path_rate: float, default = 0.0
+drop_path_rate : float, default = 0.0
when > 0.0, applies stochastic depth per sample in
the main path of the residual block.
-fuse_qkv_params: bool, default = 'False'
-if set to `True`, `TransformerLayer` module exposes a single fused
+fuse_qkv_params : bool, default = False
+if set to ``True``, :class:`TransformerLayer` module exposes a single fused
parameter for query-key-value. This enables optimizations such as QKV
fusion without concatentations/splits and also enables the argument
-`fuse_wgrad_accumulation`.
-qk_norm_type: Optional[str], default = None
+:attr:`fuse_wgrad_accumulation`.
+qk_norm_type : Optional[str], default = None
type of normalization to apply to query and key tensors.
-Options: None, 'L2Normalization', 'RMSNorm', 'LayerNorm'. When None, no normalization is applied.
-When 'L2Normalization', L2 normalization is applied to query and key tensors.
-When 'RMSNorm', RMS normalization is applied to query and key tensors.
-When 'LayerNorm', layer normalization is applied to query and key tensors.
+Options: ``None``, ``'L2Normalization'``, ``'RMSNorm'``, ``'LayerNorm'``. When ``None``, no normalization is applied.
+When ``'L2Normalization'``, L2 normalization is applied to query and key tensors.
+When ``'RMSNorm'``, RMS normalization is applied to query and key tensors.
+When ``'LayerNorm'``, layer normalization is applied to query and key tensors.
Normalization is applied after RoPE (if applicable) but before attention computation
-when `qk_norm_before_rope` is False. This follows the e.g. Llama4 approach for
+when ``qk_norm_before_rope`` is ``False``. This follows the e.g. Llama4 approach for
QK normalization to improve training stability and model performance.
-qk_norm_eps: float, default = 1e-6
+qk_norm_eps : float, default = 1e-6
epsilon value for normalization of query and key tensors.
-Only used when `qk_norm_type` is not None.
-qk_norm_before_rope: bool, default = `False`
-if set to `True`, query and key normalization is applied before rotary position
-embedding. When `False` (default), normalization is applied after RoPE.
+Only used when ``qk_norm_type`` is not ``None``.
+qk_norm_before_rope : bool, default = False
+if set to ``True``, query and key normalization is applied before rotary position
+embedding. When ``False`` (default), normalization is applied after RoPE.
This parameter allows supporting different architectural variants that apply
QK normalization at different points.
"""
@@ -523,7 +538,7 @@ class TransformerLayer(torch.nn.Module):
Parameters
----------
-tp_group : ProcessGroup, default = `None`
+tp_group : ProcessGroup, default = None
tensor parallel process group.
"""
# Deep iterate but skip self to avoid infinite recursion.
@@ -549,7 +564,7 @@ class TransformerLayer(torch.nn.Module):
cp_stream: torch.cuda.Stream,
cp_comm_type: str = "p2p",
) -> None:
-"""
+r"""
Set the context parallel attributes for the given
module before executing the forward pass.
@@ -557,25 +572,26 @@ class TransformerLayer(torch.nn.Module):
----------
cp_group : Union[ProcessGroup, List[ProcessGroup]]
context parallel process group.
-ProcessGroup is for cp_comm_type of "p2p", "all_gather", and "a2a".
-List[ProcessGroup] is for cp_comm_type of "a2a+p2p", where cp_group[0]
-and cp_group[1] are for a2a and p2p communications respectively.
+ProcessGroup is for cp_comm_type of ``"p2p"``, ``"all_gather"``, and ``"a2a"``.
+List[ProcessGroup] is for cp_comm_type of ``"a2a+p2p"``, where ``cp_group[0]``
+and ``cp_group[1]`` are for a2a and p2p communications respectively.
cp_global_ranks : List[int]
list of global ranks in the context group.
cp_stream : torch.cuda.Stream
cuda stream for context parallel execution.
-cp_comm_type : str, default = `p2p`
+cp_comm_type : str, default = "p2p"
inter-gpu communication type for context parallelism.
-Can be "p2p" or "all_gather" or "a2a", or "a2a+p2p".
-"p2p": Exchange KV chunks with P2P communications in ring topology.
-P2P is async and can be overlapped with attention compute.
-"all_gather": All-gather to get full sequence of KV before attention.
-The all-gather is not async, and cannot be overlapped.
-"a2a": Like DeepSpeed Ulysses, scatter attention heads across the CP
-group, and gather to get full sequence of QKV.
-"a2a+p2p": hierarchical CP implementation. First applying a2a to QKV
-across each CP sub-group (e.g., via NVLink), then exchanging KV with
-p2p between sub-groups (e.g., via IBLink).
+Can be ``"p2p"`` or ``"all_gather"`` or ``"a2a"`` or ``"a2a+p2p"``.
+
+- ``"p2p"``: Exchange KV chunks with P2P communications in ring topology.
+  P2P is async and can be overlapped with attention compute.
+- ``"all_gather"``: All-gather to get full sequence of KV before attention.
+  The all-gather is not async, and cannot be overlapped.
+- ``"a2a"``: Like DeepSpeed Ulysses, scatter attention heads across the CP
+  group, and gather to get full sequence of QKV.
+- ``"a2a+p2p"``: hierarchical CP implementation. First applying a2a to QKV
+  across each CP sub-group (e.g., via NVLink), then exchanging KV with
+  p2p between sub-groups (e.g., via IBLink).
"""
# Deep iterate but skip self to avoid infinite recursion.
for index, child in enumerate(self.modules()):
@@ -610,49 +626,49 @@ class TransformerLayer(torch.nn.Module):
fast_zero_fill: bool = True,
pad_between_seqs: Optional[bool] = None,
) -> torch.Tensor:
-"""
+r"""
Transformer Layer: attention block and a feedforward network (MLP)

.. note::
Argument :attr:`attention_mask` is only used when :attr:`self_attn_mask_type`
-includes `"padding"` or `"arbitrary"`.
+includes ``"padding"`` or ``"arbitrary"``.

Parameters
----------
hidden_states : torch.Tensor
Input tensor.
-attention_mask : Optional[torch.Tensor], default = `None`
+attention_mask : Optional[torch.Tensor], default = None
Boolean tensor used to mask out self-attention softmax input. It should be
-in [batch_size, 1, 1, seqlen_q] for padding masks, and broadcastable
-to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] for "`arbitrary`"
-mask. It should be `None` for causal masks and "`no_mask`" type.
-A `True` value means the corresponding position is masked out and
-a `False` means that position is allowed to participate in attention.
+in ``[batch_size, 1, 1, seqlen_q]`` for padding masks, and broadcastable
+to ``[batch_size, num_heads, max_seqlen_q, max_seqlen_kv]`` for ``"arbitrary"``
+mask. It should be ``None`` for causal masks and ``"no_mask"`` type.
+A ``True`` value means the corresponding position is masked out and
+a ``False`` means that position is allowed to participate in attention.
self_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal',
'causal_bottom_right', 'padding_causal_bottom_right','arbitrary'},
-default = `causal`
+default = "causal"
Type of attention mask passed into softmax operation for encoder.
By default, causal masks are aligned to the top left corner of
-the softmax matrix. When "`bottom_right`" is specified in the mask type,
+the softmax matrix. When ``"bottom_right"`` is specified in the mask type,
causal masks are aligned to the bottom right corner.
-window_size: Optional[Tuple[int, int]], default = `None`
+window_size: Optional[Tuple[int, int]], default = None
Sliding window size for local attention in encoder.
-encoder_output : Optional[torch.Tensor], default = `None`
+encoder_output : Optional[torch.Tensor], default = None
Output of the encoder block to be fed into the decoder block if using
-`layer_type="decoder"`.
+:attr:`layer_type` = ``"decoder"``.
enc_dec_attn_mask : Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
-default = `None`. Boolean tensors used to mask out inter-attention softmax input if
-using `layer_type="decoder"`. It should be a tuple of two masks in
-[batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv] for padding masks.
-It should be broadcastable to [batch_size, num_heads, max_seqlen_q, max_seqlen_kv]
-for "`arbitrary`" mask. It should be `None` for causal masks and "`no_mask`".
-A `True` value means the corresponding position is masked out and a `False`
+default = None. Boolean tensors used to mask out inter-attention softmax input if
+using :attr:`layer_type` = ``"decoder"``. It should be a tuple of two masks in
+``[batch_size, 1, 1, seqlen_q]`` and ``[batch_size, 1, 1, seqlen_kv]`` for padding masks.
+It should be broadcastable to ``[batch_size, num_heads, max_seqlen_q, max_seqlen_kv]``
+for ``"arbitrary"`` mask. It should be ``None`` for causal masks and ``"no_mask"``.
+A ``True`` value means the corresponding position is masked out and a ``False``
means that position is allowed to participate in attention.
enc_dec_attn_mask_type: {'no_mask', 'causal', 'padding', 'padding_causal', 'arbitrary'},
-default = `None`
+default = None
Type of attention mask passed into softmax operation for decoder.
-enc_dec_window_size: Optional[Tuple[int, int]], default = `None`
+enc_dec_window_size: Optional[Tuple[int, int]], default = None
Sliding window size for local attention in decoder.
is_first_microbatch : {True, False, None}, default = None
During training using either gradient accumulation or
@@ -667,53 +683,53 @@ class TransformerLayer(torch.nn.Module):
* it also allows skipping gradient accumulation during the
first microbatch (since it is the first gradient being
produced)
-checkpoint_core_attention: bool, default = `False`
-If true, forward activations for core attention are recomputed
+checkpoint_core_attention: bool, default = False
+If ``True``, forward activations for core attention are recomputed
during the backward pass in order to save memory that would
otherwise be occupied to store the forward activations until
backprop.
-rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = `None`
+rotary_pos_emb: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], default = None
Embeddings for query and key tensors for applying rotary position
embedding. By default no input embedding is applied.
-core_attention_bias_type: str, default = `no_bias`
-Bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}
-core_attention_bias: Optional[torch.Tensor], default = `None`
-Bias tensor for Q * K.T
-alibi_slopes: Optional[torch.Tensor], default = `None`
-ALiBi slopes in FP32 and shape [nheads] or [batch_size, nheads].
-It adds a bias of (-alibi_slope * (i + seqlen_k - seqlen_q - j))
+core_attention_bias_type: str, default = "no_bias"
+Bias type, {``"no_bias"``, ``"pre_scale_bias"``, ``"post_scale_bias"``, ``"alibi"``}
+core_attention_bias: Optional[torch.Tensor], default = None
+Bias tensor for :math:`Q \cdot K^T`
+alibi_slopes: Optional[torch.Tensor], default = None
+ALiBi slopes in FP32 and shape ``[nheads]`` or ``[batch_size, nheads]``.
+It adds a bias of :math:`(-\text{alibi_slope} \cdot (i + \text{seqlen_k} - \text{seqlen_q} - j))`
to the attention score of query i and key j.
-cu_seqlens_q: Optional[torch.Tensor], default = `None`
-Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`,
-with shape [batch_size + 1] and dtype torch.int32.
+cu_seqlens_q: Optional[torch.Tensor], default = None
+Cumulative sum of sequence lengths (without offset) in a batch for query layer,
+with shape ``[batch_size + 1]`` and dtype torch.int32.
Used by encoders, or decoders' self-attention.
-cu_seqlens_kv: Optional[torch.Tensor], default = `None`
-Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
-and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
+cu_seqlens_kv: Optional[torch.Tensor], default = None
+Cumulative sum of sequence lengths (without offset) in a batch for key layer
+and value layer, with shape ``[batch_size + 1]`` and dtype torch.int32.
Used by decoders' cross-attention.
-cu_seqlens_q_padded: Optional[torch.Tensor], default = `None`
-Cumulative sum of sequence lengths (with offset) in a batch for `query_layer`,
-with shape [batch_size + 1] and dtype torch.int32. Set to `cu_seqlens_q` if None.
+cu_seqlens_q_padded: Optional[torch.Tensor], default = None
+Cumulative sum of sequence lengths (with offset) in a batch for query layer,
+with shape ``[batch_size + 1]`` and dtype torch.int32. Set to :attr:`cu_seqlens_q` if ``None``.
Used by encoders, or decoders' self-attention.
-cu_seqlens_kv_padded: Optional[torch.Tensor], default = `None`
-Cumulative sum of sequence lengths (with offset) in a batch for `key_layer`
-and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
-Set to `cu_seqlens_kv` if None. Used by decoders' cross-attention.
-max_seqlen_q: Optional[int], default = `None`
-Maximum sequence length in `query_layer`.
-Calculated from `cu_seqlens_q_padded` if not provided.
-max_seqlen_kv: Optional[int], default = `None`
-Maximum sequence length in `key_layer` and `value_layer`.
-Calculated from `cu_seqlens_kv_padded` if not provided.
-fast_zero_fill: bool, default = `True`
+cu_seqlens_kv_padded: Optional[torch.Tensor], default = None
+Cumulative sum of sequence lengths (with offset) in a batch for key layer
+and value layer, with shape ``[batch_size + 1]`` and dtype torch.int32.
+Set to :attr:`cu_seqlens_kv` if ``None``. Used by decoders' cross-attention.
+max_seqlen_q: Optional[int], default = None
+Maximum sequence length in query layer.
+Calculated from :attr:`cu_seqlens_q_padded` if not provided.
+max_seqlen_kv: Optional[int], default = None
+Maximum sequence length in key layer and value layer.
+Calculated from :attr:`cu_seqlens_kv_padded` if not provided.
+fast_zero_fill: bool, default = True
Whether to set output tensors to 0 or not before use.
inference_params: InferenceParams, default = None
Inference parameters that are passed to the main model in order
to efficiently calculate and store the context during inference.
-pad_between_seqs: Optional[bool], default = `None`
-If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded.
-If true, there are padding tokens between individual sequences in a packed batch,
-i.e. qkv_format = 'thd'.
+pad_between_seqs: Optional[bool], default = None
+If ``None``, inferred from :attr:`qkv_format`, cu_seqlens and cu_seqlens_padded.
+If ``True``, there are padding tokens between individual sequences in a packed batch,
+i.e. :attr:`qkv_format` = ``'thd'``.
"""
if self_attn_mask_type is None:
...
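As a small aid to the cu_seqlens_* parameters documented above, here is how such a tensor is typically constructed for a packed batch (the sequence lengths are illustrative):

    import torch

    # Three packed sequences of lengths 5, 3 and 7 (qkv_format='thd'):
    seqlens = torch.tensor([5, 3, 7])
    cu_seqlens_q = torch.nn.functional.pad(
        torch.cumsum(seqlens, dim=0), (1, 0)      # prepend the leading 0
    ).to(torch.int32)                             # shape [batch_size + 1]
    print(cu_seqlens_q)                           # tensor([ 0,  5,  8, 15], dtype=torch.int32)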
@@ -31,18 +31,18 @@ def make_row_id_map(
Parameters
----------
-routing_map: torch.Tensor
+routing_map : torch.Tensor
Input tensor of shape `[num_tokens, num_experts]`. It is a mask tensor that indicates
which experts are routed to which tokens. The values in it: 1 means the token is routed to
this expert and 0 means not.
-num_tokens: int
+num_tokens : int
Number of tokens in the input tensor.
-num_experts: int
+num_experts : int
Number of experts in the input tensor.

Returns
-------
-row_id_map: torch.Tensor
+row_id_map : torch.Tensor
The row_id_map for the permutation of shape `[num_tokens, num_experts * 2 + 1]`.
For each token, the last item is the number of experts that are routed (n_routed).
The first n_routed items are the destination row indices in the permuted tokens.
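To ground the routing_map and row_id_map shapes described above, a tiny hypothetical example (all values are made up):

    import torch

    # 4 tokens, 3 experts; 1 = token is routed to that expert:
    routing_map = torch.tensor([[1, 0, 1],
                                [0, 1, 0],
                                [1, 1, 0],
                                [0, 0, 1]], dtype=torch.int32)
    n_routed = routing_map.sum(dim=-1)   # tensor([2, 1, 2, 1])
    # make_row_id_map would return a tensor of shape [4, 3 * 2 + 1]; per
    # token, the last entry is n_routed and the first n_routed entries are
    # destination row indices in the permuted tensor.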
@@ -134,23 +134,23 @@ def permute_with_mask_map(
Parameters
----------
-inp: torch.Tensor
+inp : torch.Tensor
Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
-row_id_map: torch.Tensor
+row_id_map : torch.Tensor
The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
-probs: torch.Tensor
+probs : torch.Tensor
The probabilities of the input tensor. If it is not None, it will be permuted.
-scale: torch.Tensor
+scale : torch.Tensor
The scale of the input tensor. If it is not None, it will be permuted.
-num_tokens: int
+num_tokens : int
Number of tokens in the input tensor.
-num_experts: int
+num_experts : int
Number of experts in the input tensor.
-num_out_tokens: int
+num_out_tokens : int
Number of tokens in the permuted tensor.
-hidden_size: int
+hidden_size : int
Hidden size of the input tensor.
-scale_hidden_dim: int
+scale_hidden_dim : int
Hidden size of the scale tensor.
"""
output = torch.empty((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda")
@@ -211,20 +211,20 @@ def unpermute_with_mask_map(
Parameters
----------
-inp: torch.Tensor
+inp : torch.Tensor
Input tensor of shape `[num_out_tokens, hidden_size]`.
-row_id_map: torch.Tensor
+row_id_map : torch.Tensor
The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
-merging_probs: torch.Tensor
+merging_probs : torch.Tensor
The merging probabilities of the input tensor. If it is not None, it will be used as weights
to reduce the unpermuted tokens.
-permuted_probs: torch.Tensor
+permuted_probs : torch.Tensor
The permuted probabilities of the input tensor. If it is not None, it will be unpermuted.
-num_tokens: int
+num_tokens : int
Number of tokens in the permuted tensor.
-num_experts: int
+num_experts : int
Number of experts in the permuted tensor.
-hidden_size: int
+hidden_size : int
Hidden size of the permuted tensor.
"""
output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda")
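The merging_probs behavior above amounts to a weighted reduction over each token's expert copies. A dense reference sketch follows; the real kernel walks row_id_map instead of a dense mask, and the ordering of permuted rows here is an assumption:

    import torch

    num_tokens, num_experts, hidden_size = 4, 3, 8
    routing_map = torch.rand(num_tokens, num_experts) > 0.5
    merging_probs = torch.rand(num_tokens, num_experts)
    token_idx, expert_idx = routing_map.nonzero(as_tuple=True)
    permuted = torch.randn(token_idx.numel(), hidden_size)  # one row per routed pair
    out = torch.zeros(num_tokens, hidden_size)
    # Accumulate each expert copy back into its token, weighted by its prob:
    out.index_add_(0, token_idx,
                   permuted * merging_probs[token_idx, expert_idx].unsqueeze(1))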
@@ -278,21 +278,21 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
Parameters
----------
-fwd_output_grad: torch.Tensor
+fwd_output_grad : torch.Tensor
The gradient of the output tensor of shape `[num_tokens, hidden_size]`.
-row_id_map: torch.Tensor
+row_id_map : torch.Tensor
The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
-fwd_input: torch.Tensor
+fwd_input : torch.Tensor
The input tensor of the forward pass of shape `[num_out_tokens, hidden_size]`.
-merging_probs: torch.Tensor
+merging_probs : torch.Tensor
The merging probabilities of the input tensor of shape `[num_tokens, num_experts]`.
-num_tokens: int
+num_tokens : int
Number of tokens in the permuted tensor.
-num_experts: int
+num_experts : int
Number of experts in the permuted tensor.
-num_out_tokens: int
+num_out_tokens : int
Number of tokens in the output tensor.
-hidden_size: int
+hidden_size : int
Hidden size of the output tensor.
"""
act_grad = torch.empty(
@@ -339,13 +339,13 @@ def make_chunk_sort_map(
Parameters
----------
-split_sizes: torch.Tensor
+split_sizes : torch.Tensor
The sizes of the chunks of shape `[num_splits,]`.
-sorted_indices: torch.Tensor
+sorted_indices : torch.Tensor
The indices of the sorted chunks of shape `[num_splits,]`.
-num_tokens: int
+num_tokens : int
Number of tokens in the input tensor.
-num_splits: int
+num_splits : int
Number of splits of split_sizes and sorted_indices.
"""
row_id_map = torch.empty((num_tokens,), dtype=torch.int32, device="cuda")
@@ -373,17 +373,17 @@ def sort_chunks_by_map(
Parameters
----------
-inp: torch.Tensor
+inp : torch.Tensor
Input tensor of shape `[num_tokens, hidden_size]`.
-row_id_map: torch.Tensor
+row_id_map : torch.Tensor
The token to expert mapping tensor of shape `[num_tokens,]`.
-probs: torch.Tensor
+probs : torch.Tensor
The probabilities of the input tensor. If it is not None, it will be permuted.
-num_tokens: int
+num_tokens : int
Number of tokens in the input tensor.
-hidden_size: int
+hidden_size : int
Hidden size of the input tensor.
-is_forward: bool
+is_forward : bool
Whether the sort is for forward or backward.
"""
output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda")
...
@@ -12,8 +12,8 @@ from contextlib import nullcontext
import numpy as np
import torch

-from . import torch_version
from .quantized_tensor import Quantizer
+from .torch_version import torch_version

from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
@@ -601,7 +601,7 @@ def get_nvtx_range_context(msg: str):
Parameters
----------
-msg: str
+msg : str
Message to associate with profiling context.
"""
@@ -619,7 +619,7 @@ def nvtx_range_push(msg: str) -> None:
Parameters
----------
-msg: str
+msg : str
Message to associate with range
"""
@@ -637,7 +637,7 @@ def nvtx_range_pop(msg: Optional[str] = None) -> None:
Parameters
----------
-msg: str, optional
+msg : str, optional
Message associated with range
"""
...
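Finally, a hedged usage sketch of the NVTX helpers from this last file; the import path is an assumption based on the hunk's relative imports, and the range name is made up:

    from transformer_engine.pytorch.utils import nvtx_range_push, nvtx_range_pop

    nvtx_range_push("my_block")   # open a profiling range in the NVTX timeline
    # ... work to profile ...
    nvtx_range_pop("my_block")    # msg is optional, per the docstring above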