Eric Harper authored
* use _all_gather_base
  Signed-off-by: ericharper <complex451@gmail.com>
* use _reduce_scatter_base
  Signed-off-by: ericharper <complex451@gmail.com>
* remove torch empty in backward
  Signed-off-by: ericharper <complex451@gmail.com>
* check self.attn_mask_type
  Signed-off-by: ericharper <complex451@gmail.com>
* remove extra arg
  Signed-off-by: ericharper <complex451@gmail.com>
* update get_tensor_shapes logic
  Signed-off-by: ericharper <complex451@gmail.com>
3c19f106