Unverified Commit 0426feb6 authored by Kirthi Shankar Sivamani, committed by GitHub

Consistent docs for fuse_wgrad_accumulation (#289)


Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 918a9ad7
@@ -593,7 +593,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
-----------------------
fuse_wgrad_accumulation : bool, default = `False`
if set to `True`, enables fusing of creation and accumulation of
- the weight gradient.
+ the weight gradient. When enabled, it is assumed that the weights
+ have an additional `main_grad` attribute (used instead of the
+ regular `grad`) which is a pre-allocated buffer of the correct
+ size to accumulate gradients in.
return_bias : bool, default = `False`
when set to `True`, this module will not apply the additive bias itself, but
instead return the bias value during the forward pass together with the
@@ -906,7 +906,10 @@ class LayerNormMLP(TransformerEngineBaseModule):
-----------------------
fuse_wgrad_accumulation : bool, default = `False`
if set to `True`, enables fusing of creation and accumulation of
- the weight gradient.
+ the weight gradient. When enabled, it is assumed that the weights
+ have an additional `main_grad` attribute (used instead of the
+ regular `grad`) which is a pre-allocated buffer of the correct
+ size to accumulate gradients in.
return_bias : bool, default = `False`
when set to `True`, this module will not apply the additive bias for FC2, but
instead return the bias value during the forward pass together with the
@@ -161,7 +161,10 @@ class TransformerLayer(torch.nn.Module):
-----------------------
fuse_wgrad_accumulation : bool, default = `False`
if set to `True`, enables fusing of creation and accumulation of
- the weight gradient.
+ the weight gradient. When enabled, it is assumed that the weights
+ have an additional `main_grad` attribute (used instead of the
+ regular `grad`) which is a pre-allocated buffer of the correct
+ size to accumulate gradients in.
params_dtype : torch.dtype, default = `torch.float32`
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
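All three hunks document the same contract, so a single illustration suffices. Below is a minimal sketch (not part of this commit) of what the added docstring text describes: with `fuse_wgrad_accumulation=True`, each weight is expected to carry a pre-allocated `main_grad` buffer that weight-gradient accumulation targets instead of the regular `.grad`. The layer sizes, dtype, and toy training step are illustrative, and a CUDA device is assumed.

```python
# Sketch of the fuse_wgrad_accumulation contract described in the docstrings
# above (not taken from this commit); shapes and dtype are illustrative.
import torch
import transformer_engine.pytorch as te

layer = te.LayerNormLinear(1024, 1024, fuse_wgrad_accumulation=True).cuda()

# Pre-allocate the `main_grad` accumulation buffers the docstring asks for.
for param in layer.parameters():
    param.main_grad = torch.zeros_like(param, dtype=torch.float32)

inp = torch.randn(32, 1024, device="cuda", requires_grad=True)
layer(inp).sum().backward()

# Inspect where gradients ended up: the fused weight gradient is accumulated
# into `main_grad` rather than `.grad`.
for name, param in layer.named_parameters():
    print(name, float(param.main_grad.abs().sum()))
```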