[PyTorch] Refactor parameter splitting in Linear and LayerNormLinear (#590)
* Refactor parameter split in Linear module

  Remove module state from noop_cat. Support arbitrary names in parameter split. Handle tensor parallelism.

  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Make noop_cat a standalone operation

  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update parameter splits in LayerNormLinear

  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug case without bias

  Fix pylint complaints.

  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Remove unused import

  Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
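The commits above make `noop_cat` a standalone operation that holds no module state. The idea behind a "no-op concatenation" is that when a parameter has been split into contiguous views of one buffer, concatenating the views can return a view over the original storage instead of allocating and copying. A minimal sketch of that technique, assuming PyTorch and using hypothetical names (this is not the actual Transformer Engine implementation):

```python
import torch

def noop_cat(tensors, dim=0):
    """Concatenate along dim 0, skipping the copy when the inputs are
    contiguous, back-to-back views of a single shared buffer (as produced
    by torch.chunk on a contiguous tensor). Hypothetical sketch."""
    base = tensors[0]
    # All inputs must live in the same underlying storage.
    shares_storage = all(
        t.untyped_storage().data_ptr() == base.untyped_storage().data_ptr()
        for t in tensors
    )
    # Each input must start exactly where the previous one ended.
    contiguous_run = True
    expected_offset = base.storage_offset()
    for t in tensors:
        if t.storage_offset() != expected_offset or not t.is_contiguous():
            contiguous_run = False
            break
        expected_offset += t.numel()
    if not (shares_storage and contiguous_run and dim == 0):
        # Fall back to a real concatenation (with a copy).
        return torch.cat(tensors, dim=dim)
    # Rebuild a single view over the shared storage: no copy performed.
    out_shape = (sum(t.shape[0] for t in tensors),) + tuple(base.shape[1:])
    return base.as_strided(out_shape, base.stride(), base.storage_offset())
```

Because the function inspects only its arguments, it needs no reference back to the owning module, which is what removing module state from `noop_cat` amounts to.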