Unverified Commit 0963020f authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Improve documentation (#478)



Improve docs
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 2c410836
......@@ -6,8 +6,6 @@
C/C++ API
=========
.. Caution:: This feature is not officially supported yet and may change without notice.
The C/C++ API allows you to access the custom kernels defined in `libtransformer_engine.so` library
directly from C/C++, without Python.
......
......@@ -7,26 +7,26 @@ pyTorch
=======
.. autoapiclass:: transformer_engine.pytorch.Linear(in_features, out_features, bias=True, **kwargs)
:members: forward
:members: forward, set_tensor_parallel_group
.. autoapiclass:: transformer_engine.pytorch.LayerNorm(hidden_size, eps=1e-5, **kwargs)
.. autoapiclass:: transformer_engine.pytorch.RMSNorm(hidden_size, eps=1e-5, **kwargs)
.. autoapiclass:: transformer_engine.pytorch.LayerNormLinear(in_features, out_features, eps=1e-5, bias=True, **kwargs)
:members: forward
:members: forward, set_tensor_parallel_group
.. autoapiclass:: transformer_engine.pytorch.LayerNormMLP(hidden_size, ffn_hidden_size, eps=1e-5, bias=True, **kwargs)
:members: forward
:members: forward, set_tensor_parallel_group
.. autoapiclass:: transformer_engine.pytorch.DotProductAttention(num_attention_heads, kv_channels, **kwargs)
:members: forward
:members: forward, set_context_parallel_group
.. autoapiclass:: transformer_engine.pytorch.MultiheadAttention(hidden_size, num_attention_heads, **kwargs)
:members: forward
:members: forward, set_context_parallel_group, set_tensor_parallel_group
.. autoapiclass:: transformer_engine.pytorch.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs)
:members: forward
:members: forward, set_context_parallel_group, set_tensor_parallel_group
.. autoapiclass:: transformer_engine.pytorch.InferenceParams(max_batch_size, max_sequence_length)
:members: swap_key_value_dict
......
......@@ -1920,7 +1920,19 @@ class DotProductAttention(torch.nn.Module):
cp_global_ranks: List[int],
cp_stream: torch.cuda.Stream,
) -> None:
"""Set CP group"""
"""
Set the context parallel attributes for the given
module before executing the forward pass.
Parameters
----------
cp_group : ProcessGroup
context parallel process group.
cp_global_ranks : List[int]
list of global ranks in the context group.
cp_stream : torch.cuda.Stream
cuda stream for context parallel execution.
"""
self.cp_group = cp_group
self.cp_global_ranks = cp_global_ranks
self.cp_stream = cp_stream
......@@ -2560,7 +2572,15 @@ class MultiheadAttention(torch.nn.Module):
)
def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
"""Set TP group"""
"""
Set the tensor parallel group for the given
module before executing the forward pass.
Parameters
----------
tp_group : ProcessGroup, default = `None`
tensor parallel process group.
"""
self.tp_group = tp_group
def set_context_parallel_group(
......@@ -2569,7 +2589,19 @@ class MultiheadAttention(torch.nn.Module):
cp_global_ranks: List[int],
cp_stream: torch.cuda.Stream,
) -> None:
"""Set CP group"""
"""
Set the context parallel attributes for the given
module before executing the forward pass.
Parameters
----------
cp_group : ProcessGroup
context parallel process group.
cp_global_ranks : List[int]
list of global ranks in the context group.
cp_stream : torch.cuda.Stream
cuda stream for context parallel execution.
"""
# Deep iterate but skip self to avoid infinite recursion.
for index, child in enumerate(self.modules()):
if index == 0:
......
......@@ -467,7 +467,15 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
)
def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
"""Set TP group."""
"""
Set the tensor parallel group for the given
module before executing the forward pass.
Parameters
----------
tp_group : ProcessGroup, default = `None`
tensor parallel process group.
"""
self.tp_group = tp_group
self.tp_group_initialized = True
......
......@@ -425,7 +425,15 @@ class TransformerLayer(torch.nn.Module):
)
def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
"""Set TP group"""
"""
Set the tensor parallel group for the given
module before executing the forward pass.
Parameters
----------
tp_group : ProcessGroup, default = `None`
tensor parallel process group.
"""
# Deep iterate but skip self to avoid infinite recursion.
for index, child in enumerate(self.modules()):
if index == 0:
......@@ -439,7 +447,19 @@ class TransformerLayer(torch.nn.Module):
cp_global_ranks: List[int],
cp_stream: torch.cuda.Stream,
) -> None:
"""Set CP group"""
"""
Set the context parallel attributes for the given
module before executing the forward pass.
Parameters
----------
cp_group : ProcessGroup
context parallel process group.
cp_global_ranks : List[int]
list of global ranks in the context group.
cp_stream : torch.cuda.Stream
cuda stream for context parallel execution.
"""
# Deep iterate but skip self to avoid infinite recursion.
for index, child in enumerate(self.modules()):
if index == 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment