Unverified commit 0963020f, authored by Kirthi Shankar Sivamani, committed by GitHub

Improve documentation (#478)



Improve docs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 2c410836
@@ -6,8 +6,6 @@
 C/C++ API
 =========
 
-.. Caution:: This feature is not officially supported yet and may change without notice.
-
 The C/C++ API allows you to access the custom kernels defined in `libtransformer_engine.so` library
 directly from C/C++, without Python.
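As a minimal illustration of the sentence above (not part of this commit), the shared library can also be opened directly from Python via `ctypes`; the loader-path handling below is an assumption for the sketch:

```python
import ctypes

# A minimal sketch: load the Transformer Engine shared library directly.
# Assumes the library is discoverable (e.g. via LD_LIBRARY_PATH); resolving
# individual kernel symbols is left to the C/C++ headers.
lib = ctypes.CDLL("libtransformer_engine.so")
```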
@@ -7,26 +7,26 @@ pyTorch
 =======
 
 .. autoapiclass:: transformer_engine.pytorch.Linear(in_features, out_features, bias=True, **kwargs)
-  :members: forward
+  :members: forward, set_tensor_parallel_group
 
 .. autoapiclass:: transformer_engine.pytorch.LayerNorm(hidden_size, eps=1e-5, **kwargs)
 
 .. autoapiclass:: transformer_engine.pytorch.RMSNorm(hidden_size, eps=1e-5, **kwargs)
 
 .. autoapiclass:: transformer_engine.pytorch.LayerNormLinear(in_features, out_features, eps=1e-5, bias=True, **kwargs)
-  :members: forward
+  :members: forward, set_tensor_parallel_group
 
 .. autoapiclass:: transformer_engine.pytorch.LayerNormMLP(hidden_size, ffn_hidden_size, eps=1e-5, bias=True, **kwargs)
-  :members: forward
+  :members: forward, set_tensor_parallel_group
 
 .. autoapiclass:: transformer_engine.pytorch.DotProductAttention(num_attention_heads, kv_channels, **kwargs)
-  :members: forward
+  :members: forward, set_context_parallel_group
 
 .. autoapiclass:: transformer_engine.pytorch.MultiheadAttention(hidden_size, num_attention_heads, **kwargs)
-  :members: forward
+  :members: forward, set_context_parallel_group, set_tensor_parallel_group
 
 .. autoapiclass:: transformer_engine.pytorch.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs)
-  :members: forward
+  :members: forward, set_context_parallel_group, set_tensor_parallel_group
 
 .. autoapiclass:: transformer_engine.pytorch.InferenceParams(max_batch_size, max_sequence_length)
   :members: swap_key_value_dict
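The members added to these entries are the distributed-setup hooks whose docstrings are expanded in the Python changes below. A minimal sketch of calling one of them (the group construction here is an assumption, not part of this commit):

```python
import torch.distributed as dist
import transformer_engine.pytorch as te

# Assumes torch.distributed is already initialized (e.g. launched via torchrun).
tp_group = dist.new_group(ranks=[0, 1])  # hypothetical 2-way tensor-parallel group

linear = te.Linear(in_features=1024, out_features=1024, bias=True)
linear.set_tensor_parallel_group(tp_group)
```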
@@ -1920,7 +1920,19 @@ class DotProductAttention(torch.nn.Module):
         cp_global_ranks: List[int],
         cp_stream: torch.cuda.Stream,
     ) -> None:
-        """Set CP group"""
+        """
+        Set the context parallel attributes for the given
+        module before executing the forward pass.
+
+        Parameters
+        ----------
+        cp_group : ProcessGroup
+                  context parallel process group.
+        cp_global_ranks : List[int]
+                         list of global ranks in the context group.
+        cp_stream : torch.cuda.Stream
+                   cuda stream for context parallel execution.
+        """
         self.cp_group = cp_group
         self.cp_global_ranks = cp_global_ranks
         self.cp_stream = cp_stream
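A sketch of how the three documented arguments fit together (the rank list and group construction below are assumptions, not part of this commit):

```python
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

attn = te.DotProductAttention(num_attention_heads=16, kv_channels=64)

# Hypothetical 2-rank context-parallel setup.
cp_ranks = [0, 1]                 # global ranks participating in context parallelism
cp_group = dist.new_group(ranks=cp_ranks)
cp_stream = torch.cuda.Stream()   # side stream for overlapping CP communication

attn.set_context_parallel_group(cp_group, cp_ranks, cp_stream)
```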
@@ -2560,7 +2572,15 @@ class MultiheadAttention(torch.nn.Module):
         )
 
     def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
-        """Set TP group"""
+        """
+        Set the tensor parallel group for the given
+        module before executing the forward pass.
+
+        Parameters
+        ----------
+        tp_group : ProcessGroup, default = `None`
+                  tensor parallel process group.
+        """
         self.tp_group = tp_group
 
     def set_context_parallel_group(
@@ -2569,7 +2589,19 @@ class MultiheadAttention(torch.nn.Module):
         cp_global_ranks: List[int],
         cp_stream: torch.cuda.Stream,
     ) -> None:
-        """Set CP group"""
+        """
+        Set the context parallel attributes for the given
+        module before executing the forward pass.
+
+        Parameters
+        ----------
+        cp_group : ProcessGroup
+                  context parallel process group.
+        cp_global_ranks : List[int]
+                         list of global ranks in the context group.
+        cp_stream : torch.cuda.Stream
+                   cuda stream for context parallel execution.
+        """
         # Deep iterate but skip self to avoid infinite recursion.
         for index, child in enumerate(self.modules()):
             if index == 0:
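The "Deep iterate" loop truncated at the end of this hunk is a propagation pattern: the parent forwards the setting to every child module that understands it. In isolation, the idea looks roughly like this (a simplified sketch, not the exact Transformer Engine code):

```python
import torch

class ParentModule(torch.nn.Module):
    """Illustrative container, not a Transformer Engine class."""

    def set_context_parallel_group(self, cp_group, cp_global_ranks, cp_stream):
        # self.modules() yields self at index 0; skipping it is what prevents
        # this method from invoking itself in an infinite recursion.
        for index, child in enumerate(self.modules()):
            if index == 0:
                continue
            if hasattr(child, "set_context_parallel_group"):
                child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream)
```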
@@ -467,7 +467,15 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         )
 
     def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
-        """Set TP group."""
+        """
+        Set the tensor parallel group for the given
+        module before executing the forward pass.
+
+        Parameters
+        ----------
+        tp_group : ProcessGroup, default = `None`
+                  tensor parallel process group.
+        """
         self.tp_group = tp_group
         self.tp_group_initialized = True
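Note that, unlike the subclass setters above, this base-module version also flips `tp_group_initialized`. A sketch of how such a flag is typically consumed downstream (illustrative only, not the exact Transformer Engine check):

```python
# Illustrative guard, not Transformer Engine code: a forward pass that
# needs tensor parallelism can verify the group was configured first.
def check_tp_configured(module):
    if not getattr(module, "tp_group_initialized", False):
        raise RuntimeError("Call set_tensor_parallel_group() before the forward pass.")
```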
@@ -425,7 +425,15 @@ class TransformerLayer(torch.nn.Module):
         )
 
     def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
-        """Set TP group"""
+        """
+        Set the tensor parallel group for the given
+        module before executing the forward pass.
+
+        Parameters
+        ----------
+        tp_group : ProcessGroup, default = `None`
+                  tensor parallel process group.
+        """
         # Deep iterate but skip self to avoid infinite recursion.
         for index, child in enumerate(self.modules()):
             if index == 0:
@@ -439,7 +447,19 @@ class TransformerLayer(torch.nn.Module):
         cp_global_ranks: List[int],
         cp_stream: torch.cuda.Stream,
     ) -> None:
-        """Set CP group"""
+        """
+        Set the context parallel attributes for the given
+        module before executing the forward pass.
+
+        Parameters
+        ----------
+        cp_group : ProcessGroup
+                  context parallel process group.
+        cp_global_ranks : List[int]
+                         list of global ranks in the context group.
+        cp_stream : torch.cuda.Stream
+                   cuda stream for context parallel execution.
+        """
         # Deep iterate but skip self to avoid infinite recursion.
         for index, child in enumerate(self.modules()):
             if index == 0:
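Taken together, the two setters documented in this last file configure a whole layer for both parallelism modes before its first forward pass; a combined sketch (all group and rank choices below are assumptions):

```python
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

layer = te.TransformerLayer(hidden_size=1024, ffn_hidden_size=4096, num_attention_heads=16)

# Hypothetical groups; in practice these come from your parallelism
# framework's initialization (e.g. Megatron-LM style).
tp_group = dist.new_group(ranks=[0, 1])
cp_ranks = [2, 3]
cp_group = dist.new_group(ranks=cp_ranks)

layer.set_tensor_parallel_group(tp_group)
layer.set_context_parallel_group(cp_group, cp_ranks, torch.cuda.Stream())
```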