OpenDAS / TransformerEngine, commit 0963020f (unverified)
Authored Oct 16, 2023 by Kirthi Shankar Sivamani; committed by GitHub, Oct 16, 2023
Parent: 2c410836

Improve documentation (#478)

Improve docs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

Showing 5 changed files with 72 additions and 14 deletions:

    docs/api/c/index.rst                        +0   -2
    docs/api/pytorch.rst                        +6   -6
    transformer_engine/pytorch/attention.py     +35  -3
    transformer_engine/pytorch/module/base.py   +9   -1
    transformer_engine/pytorch/transformer.py   +22  -2
docs/api/c/index.rst  (view file @ 0963020f)

...
@@ -6,8 +6,6 @@
 C/C++ API
 =========

-.. Caution:: This feature is not officially supported yet and may change without notice.
-
 The C/C++ API allows you to access the custom kernels defined in `libtransformer_engine.so` library
 directly from C/C++, without Python.
...
docs/api/pytorch.rst  (view file @ 0963020f)

...
@@ -7,26 +7,26 @@ pyTorch
 =======

 .. autoapiclass:: transformer_engine.pytorch.Linear(in_features, out_features, bias=True, **kwargs)
-  :members: forward
+  :members: forward, set_tensor_parallel_group

 .. autoapiclass:: transformer_engine.pytorch.LayerNorm(hidden_size, eps=1e-5, **kwargs)

 .. autoapiclass:: transformer_engine.pytorch.RMSNorm(hidden_size, eps=1e-5, **kwargs)

 .. autoapiclass:: transformer_engine.pytorch.LayerNormLinear(in_features, out_features, eps=1e-5, bias=True, **kwargs)
-  :members: forward
+  :members: forward, set_tensor_parallel_group

 .. autoapiclass:: transformer_engine.pytorch.LayerNormMLP(hidden_size, ffn_hidden_size, eps=1e-5, bias=True, **kwargs)
-  :members: forward
+  :members: forward, set_tensor_parallel_group

 .. autoapiclass:: transformer_engine.pytorch.DotProductAttention(num_attention_heads, kv_channels, **kwargs)
-  :members: forward
+  :members: forward, set_context_parallel_group

 .. autoapiclass:: transformer_engine.pytorch.MultiheadAttention(hidden_size, num_attention_heads, **kwargs)
-  :members: forward
+  :members: forward, set_context_parallel_group, set_tensor_parallel_group

 .. autoapiclass:: transformer_engine.pytorch.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs)
-  :members: forward
+  :members: forward, set_context_parallel_group, set_tensor_parallel_group

 .. autoapiclass:: transformer_engine.pytorch.InferenceParams(max_batch_size, max_sequence_length)
   :members: swap_key_value_dict
...
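The members added above are the hooks a user calls on these modules at runtime. As a point of reference, a minimal single-GPU sketch of the documented constructors; the hyperparameter values are illustrative and not from this commit:

    # Minimal usage sketch for the classes documented above.
    import torch
    import transformer_engine.pytorch as te

    layer = te.Linear(4096, 4096, bias=True).cuda()
    x = torch.randn(16, 4096, device="cuda")

    # FP8 execution needs FP8-capable hardware (e.g. Hopper); drop the
    # autocast to run in higher precision.
    with te.fp8_autocast(enabled=True):
        y = layer(x)

The parallel-group setters documented alongside `forward` only matter in multi-GPU runs; they are exercised in the sketches after the diffs below.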
transformer_engine/pytorch/attention.py  (view file @ 0963020f)

...
@@ -1920,7 +1920,19 @@ class DotProductAttention(torch.nn.Module):
         cp_global_ranks: List[int],
         cp_stream: torch.cuda.Stream,
     ) -> None:
-        """Set CP group"""
+        """
+        Set the context parallel attributes for the given
+        module before executing the forward pass.
+
+        Parameters
+        ----------
+        cp_group : ProcessGroup
+                  context parallel process group.
+        cp_global_ranks : List[int]
+                         list of global ranks in the context group.
+        cp_stream : torch.cuda.Stream
+                   cuda stream for context parallel execution.
+        """
         self.cp_group = cp_group
         self.cp_global_ranks = cp_global_ranks
         self.cp_stream = cp_stream
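The new docstring spells out the three attributes that context-parallel attention needs. A hedged wiring sketch, assuming `torch.distributed` is already initialized; putting every rank into a single context-parallel group is illustrative only:

    # Sketch: handing CP state to DotProductAttention.
    import torch
    import torch.distributed as dist
    import transformer_engine.pytorch as te

    # assumes dist.init_process_group(backend="nccl") has already run
    cp_ranks = list(range(dist.get_world_size()))  # illustrative layout
    cp_group = dist.new_group(ranks=cp_ranks)
    cp_stream = torch.cuda.Stream()  # side stream for CP communication

    attn = te.DotProductAttention(num_attention_heads=16, kv_channels=64)
    attn.set_context_parallel_group(cp_group, cp_ranks, cp_stream)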
...
@@ -2560,7 +2572,15 @@ class MultiheadAttention(torch.nn.Module):
     )

     def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
-        """Set TP group"""
+        """
+        Set the tensor parallel group for the given
+        module before executing the forward pass.
+
+        Parameters
+        ----------
+        tp_group : ProcessGroup, default = `None`
+                  tensor parallel process group.
+        """
         self.tp_group = tp_group

     def set_context_parallel_group(
...
@@ -2569,7 +2589,19 @@ class MultiheadAttention(torch.nn.Module):
         cp_global_ranks: List[int],
         cp_stream: torch.cuda.Stream,
     ) -> None:
-        """Set CP group"""
+        """
+        Set the context parallel attributes for the given
+        module before executing the forward pass.
+
+        Parameters
+        ----------
+        cp_group : ProcessGroup
+                  context parallel process group.
+        cp_global_ranks : List[int]
+                         list of global ranks in the context group.
+        cp_stream : torch.cuda.Stream
+                   cuda stream for context parallel execution.
+        """
         # Deep iterate but skip self to avoid infinite recursion.
         for index, child in enumerate(self.modules()):
             if index == 0:
...
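The `MultiheadAttention` setter above ends in a traversal worth noting: `Module.modules()` yields the module itself at index 0, so the loop skips that first entry and forwards the call only to children. A standalone sketch of the pattern with hypothetical toy modules (not TE classes):

    # Standalone illustration of "deep iterate but skip self".
    import torch

    class ToyChild(torch.nn.Module):
        def set_flag(self) -> None:
            self.flag = True

    class ToyBlock(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.a, self.b = ToyChild(), ToyChild()

        def set_flag(self) -> None:
            for index, child in enumerate(self.modules()):
                if index == 0:
                    continue  # self comes first; recursing here would never end
                if hasattr(child, "set_flag"):
                    child.set_flag()

    block = ToyBlock()
    block.set_flag()
    assert block.a.flag and block.b.flag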
transformer_engine/pytorch/module/base.py  (view file @ 0963020f)

...
@@ -467,7 +467,15 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
     )

     def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
-        """Set TP group."""
+        """
+        Set the tensor parallel group for the given
+        module before executing the forward pass.
+
+        Parameters
+        ----------
+        tp_group : ProcessGroup, default = `None`
+                  tensor parallel process group.
+        """
         self.tp_group = tp_group
         self.tp_group_initialized = True
...
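Every TE module inherits this setter from `TransformerEngineBaseModule`, and the `tp_group_initialized` flag it flips records that a group was actually provided. A minimal sketch, again assuming an initialized `torch.distributed` world with an illustrative one-group layout:

    # Sketch: attaching a tensor-parallel group after construction.
    import torch.distributed as dist
    import transformer_engine.pytorch as te

    # assumes dist.init_process_group(backend="nccl") has already run
    tp_group = dist.new_group(ranks=list(range(dist.get_world_size())))

    linear = te.Linear(4096, 4096, bias=True)
    linear.set_tensor_parallel_group(tp_group)  # also sets tp_group_initialized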
transformer_engine/pytorch/transformer.py  (view file @ 0963020f)

...
@@ -425,7 +425,15 @@ class TransformerLayer(torch.nn.Module):
     )

     def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
-        """Set TP group"""
+        """
+        Set the tensor parallel group for the given
+        module before executing the forward pass.
+
+        Parameters
+        ----------
+        tp_group : ProcessGroup, default = `None`
+                  tensor parallel process group.
+        """
         # Deep iterate but skip self to avoid infinite recursion.
         for index, child in enumerate(self.modules()):
             if index == 0:
...
@@ -439,7 +447,19 @@ class TransformerLayer(torch.nn.Module):
         cp_global_ranks: List[int],
         cp_stream: torch.cuda.Stream,
     ) -> None:
-        """Set CP group"""
+        """
+        Set the context parallel attributes for the given
+        module before executing the forward pass.
+
+        Parameters
+        ----------
+        cp_group : ProcessGroup
+                  context parallel process group.
+        cp_global_ranks : List[int]
+                         list of global ranks in the context group.
+        cp_stream : torch.cuda.Stream
+                   cuda stream for context parallel execution.
+        """
         # Deep iterate but skip self to avoid infinite recursion.
         for index, child in enumerate(self.modules()):
             if index == 0:
...
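Because `TransformerLayer` forwards both setters to its children through the same skip-self traversal, one call on the top-level layer configures every attention and linear submodule. A hedged end-to-end sketch; the group layout and sizes are illustrative, and an initialized world is assumed:

    # Sketch: configuring TP and CP once at the TransformerLayer level.
    import torch
    import torch.distributed as dist
    import transformer_engine.pytorch as te

    # assumes dist.init_process_group(backend="nccl") has already run
    world = list(range(dist.get_world_size()))
    tp_group = dist.new_group(ranks=world)  # illustrative: all ranks
    cp_group = dist.new_group(ranks=world)  # illustrative: all ranks

    layer = te.TransformerLayer(hidden_size=4096, ffn_hidden_size=16384,
                                num_attention_heads=32)
    layer.set_tensor_parallel_group(tp_group)
    layer.set_context_parallel_group(cp_group, world, torch.cuda.Stream())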