..
    Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

pyTorch
=======

.. autoapiclass:: transformer_engine.pytorch.Linear(in_features, out_features, bias=True, **kwargs)
  :members: forward, set_tensor_parallel_group
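
A minimal usage sketch (shapes and sizes here are illustrative, not part of the API):

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te

    # Drop-in replacement for torch.nn.Linear; parameters live on the GPU.
    layer = te.Linear(in_features=768, out_features=3072, bias=True)

    x = torch.randn(128, 768, device="cuda")
    y = layer(x)  # behaves like a regular linear layer outside of fp8_autocast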

.. autoapiclass:: transformer_engine.pytorch.GroupedLinear(in_features, out_features, bias=True, **kwargs)
  :members: forward, set_tensor_parallel_group

.. autoapiclass:: transformer_engine.pytorch.LayerNorm(hidden_size, eps=1e-5, **kwargs)

.. autoapiclass:: transformer_engine.pytorch.RMSNorm(hidden_size, eps=1e-5, **kwargs)
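
Both normalization modules take the hidden size at construction and normalize over the last dimension; a short sketch with illustrative sizes:

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te

    ln = te.LayerNorm(hidden_size=768, eps=1e-5)
    rms = te.RMSNorm(hidden_size=768, eps=1e-5)

    x = torch.randn(128, 768, device="cuda")
    y_ln = ln(x)    # LayerNorm: subtracts the mean, divides by the std
    y_rms = rms(x)  # RMSNorm: divides by the root-mean-square only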

.. autoapiclass:: transformer_engine.pytorch.LayerNormLinear(in_features, out_features, eps=1e-5, bias=True, **kwargs)
  :members: forward, set_tensor_parallel_group
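
A sketch of the fused normalization plus projection (sizes illustrative); the module applies LayerNorm to the input and then the linear transform in a single module:

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te

    # LayerNorm(768) followed by Linear(768 -> 2304), e.g. a QKV projection.
    ln_linear = te.LayerNormLinear(in_features=768, out_features=3 * 768, bias=True)

    x = torch.randn(128, 768, device="cuda")
    qkv = ln_linear(x)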

.. autoapiclass:: transformer_engine.pytorch.LayerNormMLP(hidden_size, ffn_hidden_size, eps=1e-5, bias=True, **kwargs)
  :members: forward, set_tensor_parallel_group
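
A sketch of the fused LayerNorm plus two-layer MLP block (sizes illustrative):

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te

    # LayerNorm -> Linear(768 -> 3072) -> activation (GELU by default)
    # -> Linear(3072 -> 768)
    mlp = te.LayerNormMLP(hidden_size=768, ffn_hidden_size=3072)

    x = torch.randn(128, 768, device="cuda")
    y = mlp(x)  # same hidden size on the way out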

.. autoapiclass:: transformer_engine.pytorch.DotProductAttention(num_attention_heads, kv_channels, **kwargs)
  :members: forward, set_context_parallel_group
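
A sketch using the default ``sbhd`` layout, i.e. ``[seq, batch, heads, head_dim]`` inputs; the fused attention backends generally expect half-precision tensors, and all sizes here are illustrative:

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te

    attn = te.DotProductAttention(num_attention_heads=16, kv_channels=64)

    # qkv_format defaults to "sbhd": [seq, batch, heads, head_dim]
    q = torch.randn(512, 2, 16, 64, dtype=torch.float16, device="cuda")
    k = torch.randn(512, 2, 16, 64, dtype=torch.float16, device="cuda")
    v = torch.randn(512, 2, 16, 64, dtype=torch.float16, device="cuda")

    out = attn(q, k, v)  # [seq, batch, heads * head_dim]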

.. autoapiclass:: transformer_engine.pytorch.MultiheadAttention(hidden_size, num_attention_heads, **kwargs)
  :members: forward, set_context_parallel_group, set_tensor_parallel_group
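
A self-attention sketch on ``[seq, batch, hidden]`` inputs; ``params_dtype`` is set to half precision here to match the fused attention backends, and sizes are illustrative:

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te

    mha = te.MultiheadAttention(
        hidden_size=1024,
        num_attention_heads=16,
        params_dtype=torch.float16,
    )

    x = torch.randn(512, 2, 1024, dtype=torch.float16, device="cuda")
    out = mha(x)  # QKV projection, attention, and output projection in one call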

.. autoapiclass:: transformer_engine.pytorch.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs)
  :members: forward, set_context_parallel_group, set_tensor_parallel_group
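
A full-layer sketch (sizes illustrative); one call runs self-attention plus the MLP block:

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te

    layer = te.TransformerLayer(
        hidden_size=1024,
        ffn_hidden_size=4096,
        num_attention_heads=16,
        params_dtype=torch.bfloat16,
    )

    x = torch.randn(512, 2, 1024, dtype=torch.bfloat16, device="cuda")
    y = layer(x)  # [seq, batch, hidden], same shape as the input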

.. autoapiclass:: transformer_engine.pytorch.dot_product_attention.inference.InferenceParams(max_batch_size, max_sequence_length)
  :members: reset, allocate_memory, pre_step, get_seqlens_pre_step, convert_paged_to_nonpaged, step

.. autoapiclass:: transformer_engine.pytorch.CudaRNGStatesTracker()
  :members: reset, get_states, set_states, add, fork
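
A sketch of keeping a named RNG stream separate from the default CUDA generator (the tracker name is arbitrary):

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te

    tracker = te.CudaRNGStatesTracker()
    tracker.add("model-parallel-rng", seed=1234)

    # Inside the fork, CUDA RNG draws come from (and advance) the named state,
    # leaving the default generator untouched.
    with tracker.fork("model-parallel-rng"):
        mask = torch.rand(128, 768, device="cuda")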

.. autoapifunction:: transformer_engine.pytorch.fp8_autocast
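
A typical pattern is to build the model normally and wrap only the forward pass; the recipe below (delayed scaling with the E4M3 format) is just one choice:

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common import recipe

    layer = te.Linear(768, 768)
    x = torch.randn(128, 768, device="cuda")

    fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3)
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        y = layer(x)

    # The backward pass runs outside the context and picks up the FP8 state.
    y.sum().backward()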

.. autoapifunction:: transformer_engine.pytorch.fp8_model_init
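
Whereas ``fp8_autocast`` affects compute, ``fp8_model_init`` changes how parameters are stored; a sketch:

.. code-block:: python

    import transformer_engine.pytorch as te

    # Modules created in this context keep their weights in FP8, roughly
    # halving parameter memory versus FP16/BF16 weights.
    with te.fp8_model_init(enabled=True):
        layer = te.Linear(768, 768)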

.. autoapifunction:: transformer_engine.pytorch.checkpoint
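
The calling convention follows ``torch.utils.checkpoint.checkpoint``: pass the module (or function) and its inputs, and activations are recomputed during backward; a sketch:

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te

    layer = te.LayerNormMLP(hidden_size=768, ffn_hidden_size=3072)
    x = torch.randn(128, 768, device="cuda", requires_grad=True)

    # Activations inside `layer` are not kept; they are recomputed in backward.
    y = te.checkpoint(layer, x)
    y.sum().backward()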

.. autoapifunction:: transformer_engine.pytorch.make_graphed_callables
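
A sketch of graph-capturing a module with static sample inputs; later calls must use the same shapes as the samples:

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te

    layer = te.Linear(768, 768)
    sample_x = torch.randn(128, 768, device="cuda", requires_grad=True)

    # Warms up and captures forward/backward into CUDA graphs.
    graphed_layer = te.make_graphed_callables(layer, (sample_x,))

    x = torch.randn(128, 768, device="cuda", requires_grad=True)
    y = graphed_layer(x)  # replays the captured graph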

.. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context
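
A sketch of the intended pattern, assuming the function returns a (context, synchronization function) pair as in the Transformer Engine examples; arguments beyond ``enabled`` are omitted here, so consult the rendered signature above for the full list:

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te

    offload_ctx, sync_fn = te.get_cpu_offload_context(enabled=True)

    layers = [te.Linear(768, 768) for _ in range(4)]
    x = torch.randn(128, 768, device="cuda", requires_grad=True)

    for layer in layers:
        with offload_ctx:   # activations saved here are offloaded to CPU
            x = layer(x)
        x = sync_fn(x)      # synchronizes the offload before the next layer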

.. autoapifunction:: transformer_engine.pytorch.moe_permute

.. autoapifunction:: transformer_engine.pytorch.moe_permute_with_probs

.. autoapifunction:: transformer_engine.pytorch.moe_unpermute
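
The exact signatures are best taken from the rendered docstrings above; the following is a pure-PyTorch sketch of the round trip these helpers implement (gather tokens so each expert's inputs are contiguous, process, then scatter back), not a call into the Transformer Engine API:

.. code-block:: python

    import torch

    tokens = torch.randn(8, 16)             # [num_tokens, hidden]
    expert_idx = torch.randint(0, 4, (8,))  # one expert per token

    # Permute: sort tokens so that each expert's tokens are contiguous.
    order = torch.argsort(expert_idx)
    permuted = tokens[order]

    # ... per-expert computation on `permuted` would go here ...

    # Unpermute: scatter results back to the original token order.
    restored = torch.empty_like(permuted)
    restored[order] = permuted
    assert torch.equal(restored, tokens)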

.. autoapifunction:: transformer_engine.pytorch.moe_sort_chunks_by_index

.. autoapifunction:: transformer_engine.pytorch.parallel_cross_entropy

.. autoapifunction:: transformer_engine.pytorch.moe_sort_chunks_by_index_with_probs

.. autoapifunction:: transformer_engine.pytorch.initialize_ub

.. autoapifunction:: transformer_engine.pytorch.destroy_ub
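
A bootstrap sketch, assuming the communication buffer is sized as [sequence length x micro-batch size, hidden size] and that the process group has already been initialized via ``torch.distributed``; consult the rendered signatures above for the full argument lists:

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te

    seq_len, micro_batch, hidden = 2048, 4, 4096
    tp_size = torch.distributed.get_world_size()  # assumes distributed init

    # Allocate the userbuffers used for communication/GEMM overlap ...
    te.initialize_ub([seq_len * micro_batch, hidden], tp_size)

    # ... run training with tensor-parallel TE modules ...

    # ... and release them when done.
    te.destroy_ub()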

.. autoapiclass:: transformer_engine.pytorch.UserBufferQuantizationMode
  :members: FP8, NONE