..
    Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

PyTorch
=======

.. autoapiclass:: transformer_engine.pytorch.Linear(in_features, out_features, bias=True, **kwargs)
  :members: forward, set_tensor_parallel_group
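
A minimal usage sketch (sizes are illustrative; see ``fp8_autocast`` below for running the GEMM in FP8):

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te

    # Drop-in replacement for torch.nn.Linear with FP8-capable GEMMs.
    linear = te.Linear(in_features=768, out_features=3072, bias=True).cuda()

    inp = torch.randn(2048, 768, device="cuda")
    out = linear(inp)  # shape: [2048, 3072]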

.. autoapiclass:: transformer_engine.pytorch.GroupedLinear(in_features, out_features, bias=True, **kwargs)
  :members: forward, set_tensor_parallel_group
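
A rough sketch of running several independent GEMMs in one call. This assumes the module takes the number of GEMMs as its first constructor argument and that ``forward`` accepts per-GEMM row counts via ``m_splits``:

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te

    # Three independent 768 -> 3072 GEMMs fused into one module
    # (assumed constructor: number of GEMMs, then feature sizes).
    grouped = te.GroupedLinear(3, 768, 3072, bias=True).cuda()

    inp = torch.randn(600, 768, device="cuda")
    # m_splits maps contiguous row blocks of `inp` to each GEMM.
    out = grouped(inp, m_splits=[100, 200, 300])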

.. autoapiclass:: transformer_engine.pytorch.LayerNorm(hidden_size, eps=1e-5, **kwargs)

.. autoapiclass:: transformer_engine.pytorch.RMSNorm(hidden_size, eps=1e-5, **kwargs)
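
Both normalization modules follow the familiar ``torch.nn`` calling convention; a short sketch:

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te

    hidden_size = 768
    x = torch.randn(2048, hidden_size, device="cuda")

    ln = te.LayerNorm(hidden_size, eps=1e-5).cuda()
    rms = te.RMSNorm(hidden_size, eps=1e-5).cuda()

    y = ln(x)   # mean/variance normalization with learnable affine params
    z = rms(x)  # root-mean-square normalization (no mean subtraction)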

.. autoapiclass:: transformer_engine.pytorch.LayerNormLinear(in_features, out_features, eps=1e-5, bias=True, **kwargs)
  :members: forward, set_tensor_parallel_group

.. autoapiclass:: transformer_engine.pytorch.LayerNormMLP(hidden_size, ffn_hidden_size, eps=1e-5, bias=True, **kwargs)
  :members: forward, set_tensor_parallel_group
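
The fused modules combine the normalization with the GEMM(s) that follow it; a short sketch (sizes illustrative):

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te

    x = torch.randn(2048, 768, device="cuda")

    # LayerNorm fused with the projection that follows it.
    ln_linear = te.LayerNormLinear(768, 3072, eps=1e-5, bias=True).cuda()
    y = ln_linear(x)  # shape: [2048, 3072]

    # LayerNorm -> FC1 -> activation -> FC2, back to hidden_size.
    ln_mlp = te.LayerNormMLP(hidden_size=768, ffn_hidden_size=3072).cuda()
    z = ln_mlp(x)  # shape: [2048, 768]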

.. autoapiclass:: transformer_engine.pytorch.DotProductAttention(num_attention_heads, kv_channels, **kwargs)
  :members: forward, set_context_parallel_group
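
A short sketch, assuming the default ``sbhd`` (sequence, batch, head, head-dim) input layout:

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te

    # 16 heads of 64 channels each; fused backends expect FP16/BF16 inputs.
    dpa = te.DotProductAttention(num_attention_heads=16, kv_channels=64)

    seq, batch, heads, dim = 512, 4, 16, 64
    q = torch.randn(seq, batch, heads, dim, dtype=torch.bfloat16, device="cuda")
    k = torch.randn(seq, batch, heads, dim, dtype=torch.bfloat16, device="cuda")
    v = torch.randn(seq, batch, heads, dim, dtype=torch.bfloat16, device="cuda")

    out = dpa(q, k, v)  # shape: [512, 4, 1024]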

.. autoapiclass:: transformer_engine.pytorch.MultiheadAttention(hidden_size, num_attention_heads, **kwargs)
  :members: forward, set_context_parallel_group, set_tensor_parallel_group
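
A short sketch of the full attention block; ``params_dtype`` is set here only so the parameters match the BF16 activations:

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te

    mha = te.MultiheadAttention(
        hidden_size=1024,
        num_attention_heads=16,
        params_dtype=torch.bfloat16,
    ).cuda()

    # Default input layout is (sequence, batch, hidden).
    x = torch.randn(512, 4, 1024, dtype=torch.bfloat16, device="cuda")
    out = mha(x)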

.. autoapiclass:: transformer_engine.pytorch.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs)
  :members: forward, set_context_parallel_group, set_tensor_parallel_group
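
A minimal sketch of a full Transformer block (sizes illustrative):

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te

    layer = te.TransformerLayer(
        hidden_size=1024,
        ffn_hidden_size=4096,
        num_attention_heads=16,
    ).cuda()

    # Default input layout is (sequence, batch, hidden).
    x = torch.randn(512, 4, 1024, device="cuda")
    y = layer(x)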

.. autoapiclass:: transformer_engine.pytorch.dot_product_attention.inference.InferenceParams(max_batch_size, max_sequence_length)
  :members: reset, allocate_memory, pre_step, get_seqlens_pre_step, convert_paged_to_nonpaged, step
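
A rough sketch of a decode loop, assuming ``pre_step`` takes an ordered mapping of sequence ids to the number of new tokens and that the cache is threaded through ``TransformerLayer.forward`` via its ``inference_params`` argument:

.. code-block:: python

    from collections import OrderedDict

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams

    layer = te.TransformerLayer(
        hidden_size=1024, ffn_hidden_size=4096, num_attention_heads=16
    ).cuda()

    # KV-cache sized for the whole generation loop.
    inference_params = InferenceParams(max_batch_size=4, max_sequence_length=512)

    for _ in range(3):  # decode steps
        # Each of the 4 sequences advances by one token this step.
        inference_params.pre_step(OrderedDict((i, 1) for i in range(4)))
        x = torch.randn(1, 4, 1024, device="cuda")
        y = layer(x, inference_params=inference_params)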

.. autoapiclass:: transformer_engine.pytorch.CudaRNGStatesTracker()
  :members: reset, get_states, set_states, add, fork
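
A short sketch; the tracker (or a callback returning it) is typically handed to modules via their ``get_rng_state_tracker`` argument so dropout stays reproducible under tensor parallelism:

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te

    tracker = te.CudaRNGStatesTracker()
    tracker.add("model-parallel-rng", 1234)  # register a named RNG state

    # Inside the fork, CUDA RNG draws from (and advances) the tracked state;
    # the default generator state is restored on exit.
    with tracker.fork("model-parallel-rng"):
        mask = torch.rand(8, device="cuda")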

.. autoapifunction:: transformer_engine.pytorch.fp8_autocast
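
A minimal sketch of FP8 training with a delayed-scaling recipe:

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import DelayedScaling, Format

    model = te.Linear(768, 768).cuda()

    # E4M3 in the forward pass, E5M2 for gradients in the backward pass.
    fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16)

    inp = torch.randn(2048, 768, device="cuda")
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        out = model(inp)

    # The backward pass runs outside the context.
    out.sum().backward()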

.. autoapifunction:: transformer_engine.pytorch.fp8_model_init
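
A minimal sketch; under this context the modules keep their parameters directly in FP8 instead of holding a higher-precision copy:

.. code-block:: python

    import transformer_engine.pytorch as te

    # Parameters are created directly in FP8 rather than as higher-precision
    # tensors that get cast at every forward, saving weight memory.
    with te.fp8_model_init(enabled=True):
        model = te.Linear(768, 768)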

.. autoapifunction:: transformer_engine.pytorch.autocast

.. autoapifunction:: transformer_engine.pytorch.quantized_model_init

.. autoapifunction:: transformer_engine.pytorch.checkpoint
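
Usage mirrors ``torch.utils.checkpoint.checkpoint``, with the addition that Transformer Engine's FP8 scaling state is saved and restored around the recomputation; a short sketch:

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te

    layer = te.TransformerLayer(
        hidden_size=1024, ffn_hidden_size=4096, num_attention_heads=16
    ).cuda()

    x = torch.randn(512, 4, 1024, device="cuda", requires_grad=True)

    # Activations are recomputed during backward instead of being stored.
    y = te.checkpoint(layer, x)
    y.sum().backward()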

.. autoapifunction:: transformer_engine.pytorch.is_fp8_available

.. autoapifunction:: transformer_engine.pytorch.is_mxfp8_available

.. autoapifunction:: transformer_engine.pytorch.is_fp8_block_scaling_available

.. autoapifunction:: transformer_engine.pytorch.is_nvfp4_available

.. autoapifunction:: transformer_engine.pytorch.is_bf16_available

.. autoapifunction:: transformer_engine.pytorch.get_cudnn_version

.. autoapifunction:: transformer_engine.pytorch.get_device_compute_capability

.. autoapifunction:: transformer_engine.pytorch.get_default_recipe
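
These helpers are convenient for guarding optional features at runtime; for example:

.. code-block:: python

    import transformer_engine.pytorch as te

    print("cuDNN version:      ", te.get_cudnn_version())
    print("compute capability: ", te.get_device_compute_capability())
    print("FP8 available:      ", te.is_fp8_available())
    print("MXFP8 available:    ", te.is_mxfp8_available())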

.. autoapifunction:: transformer_engine.pytorch.make_graphed_callables
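
A rough sketch of capturing a module into a CUDA graph, assuming the ``fp8_enabled`` flag for graphing the FP8 path:

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te

    layer = te.TransformerLayer(
        hidden_size=1024, ffn_hidden_size=4096, num_attention_heads=16
    ).cuda()
    sample = torch.randn(512, 4, 1024, device="cuda", requires_grad=True)

    # Warms up, then captures forward and backward into CUDA graphs.
    graphed_layer = te.make_graphed_callables(layer, (sample,), fp8_enabled=True)

    out = graphed_layer(torch.randn_like(sample))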

.. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context

.. autoapifunction:: transformer_engine.pytorch.moe_permute

.. autoapifunction:: transformer_engine.pytorch.moe_permute_with_probs

.. autoapifunction:: transformer_engine.pytorch.moe_unpermute

.. autoapifunction:: transformer_engine.pytorch.moe_sort_chunks_by_index

.. autoapifunction:: transformer_engine.pytorch.parallel_cross_entropy

.. autoapifunction:: transformer_engine.pytorch.moe_sort_chunks_by_index_with_probs

.. autoapifunction:: transformer_engine.pytorch.initialize_ub

.. autoapifunction:: transformer_engine.pytorch.destroy_ub

.. autoapiclass:: transformer_engine.pytorch.UserBufferQuantizationMode
  :members: FP8, NONE

Quantized tensors
-----------------

.. autoapiclass:: transformer_engine.pytorch.QuantizedTensorStorage
   :members: update_usage, prepare_for_saving, restore_from_saved

.. autoapiclass:: transformer_engine.pytorch.QuantizedTensor(shape, dtype, *, requires_grad=False, device=None)
   :members: dequantize, quantize_

.. autoapiclass:: transformer_engine.pytorch.Float8TensorStorage(data, fp8_scale_inv, fp8_dtype, data_transpose=None, quantizer=None)

.. autoapiclass:: transformer_engine.pytorch.MXFP8TensorStorage(rowwise_data, rowwise_scale_inv, columnwise_data, columnwise_scale_inv, fp8_dtype, quantizer)

.. autoapiclass:: transformer_engine.pytorch.Float8BlockwiseQTensorStorage(rowwise_data, rowwise_scale_inv, columnwise_data, columnwise_scale_inv, fp8_dtype, quantizer, is_2D_scaled, data_format)

.. autoapiclass:: transformer_engine.pytorch.NVFP4TensorStorage(rowwise_data, rowwise_scale_inv, columnwise_data, columnwise_scale_inv, amax_rowwise, amax_columnwise, fp4_dtype, quantizer)

.. autoapiclass:: transformer_engine.pytorch.Float8Tensor(shape, dtype, data, fp8_scale_inv, fp8_dtype, requires_grad=False, data_transpose=None, quantizer=None)

.. autoapiclass:: transformer_engine.pytorch.MXFP8Tensor(rowwise_data, rowwise_scale_inv, columnwise_data, columnwise_scale_inv, fp8_dtype, quantizer)

.. autoapiclass:: transformer_engine.pytorch.Float8BlockwiseQTensor(rowwise_data, rowwise_scale_inv, columnwise_data, columnwise_scale_inv, fp8_dtype, quantizer, is_2D_scaled, data_format)

.. autoapiclass:: transformer_engine.pytorch.NVFP4Tensor(rowwise_data, rowwise_scale_inv, columnwise_data, columnwise_scale_inv, amax_rowwise, amax_columnwise, fp4_dtype, quantizer)

Quantizers
----------

.. autoapiclass:: transformer_engine.pytorch.Quantizer(rowwise, columnwise)
   :members: update_quantized, quantize

.. autoapiclass:: transformer_engine.pytorch.Float8Quantizer(scale, amax, fp8_dtype, *, rowwise=True, columnwise=True)
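
A rough sketch of quantizing a tensor by hand with the delayed-scaling quantizer; this assumes the FP8 dtype enum from the internal ``transformer_engine_torch`` extension:

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te
    import transformer_engine_torch as tex

    # Delayed-scaling quantizer: scale and amax are caller-managed buffers.
    quantizer = te.Float8Quantizer(
        scale=torch.ones(1, device="cuda"),
        amax=torch.zeros(1, device="cuda"),
        fp8_dtype=tex.DType.kFloat8E4M3,
    )

    x = torch.randn(128, 128, device="cuda")
    x_fp8 = quantizer.quantize(x)   # -> Float8Tensor
    x_hp = x_fp8.dequantize()       # back to high precision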

.. autoapiclass:: transformer_engine.pytorch.Float8CurrentScalingQuantizer(fp8_dtype, device, *, rowwise=True, columnwise=True, **kwargs)

.. autoapiclass:: transformer_engine.pytorch.MXFP8Quantizer(fp8_dtype, *, rowwise=True, columnwise=True)

.. autoapiclass:: transformer_engine.pytorch.Float8BlockQuantizer(fp8_dtype, *, rowwise, columnwise, **kwargs)

.. autoapiclass:: transformer_engine.pytorch.NVFP4Quantizer(fp4_dtype, *, rowwise=True, columnwise=True, **kwargs)

Tensor saving and restoring functions
-------------------------------------

.. autoapifunction:: transformer_engine.pytorch.prepare_for_saving

.. autoapifunction:: transformer_engine.pytorch.restore_from_saved