Return layernorm output in the gathered form (#697)

* first draft of return_layernorm_output_gathered

Signed-off-by: Chen Cui <chcui@nvidia.com>

* explain use case more thoroughly in docstring

Signed-off-by: Chen Cui <chcui@nvidia.com>

* add same option in `LayerNormMLP`

Signed-off-by: Chen Cui <chcui@nvidia.com>

* Update transformer_engine/pytorch/module/layernorm_linear.py

Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Chen Cui <cxcui@alumni.cmu.edu>

* Update transformer_engine/pytorch/module/layernorm_linear.py

Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Chen Cui <cxcui@alumni.cmu.edu>

* address comments

Signed-off-by: Chen Cui <chcui@nvidia.com>

* add same option in LayerNormMLP

Signed-off-by: Chen Cui <chcui@nvidia.com>

* address linter errors

Signed-off-by: Chen Cui <chcui@nvidia.com>

---------

Signed-off-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Chen Cui <cxcui@alumni.cmu.edu>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>