Commit 666cfd09 authored by Wesley's avatar Wesley Committed by アマデウス
Browse files

fix parallel_input flag for Linear1D_Col gather_output

parent a9f778f1
...@@ -302,7 +302,10 @@ class Linear1D_Col(ParallelLayer): ...@@ -302,7 +302,10 @@ class Linear1D_Col(ParallelLayer):
with seed(ParallelMode.TENSOR): with seed(ParallelMode.TENSOR):
self.reset_parameters(weight_initializer, bias_initializer) self.reset_parameters(weight_initializer, bias_initializer)
self._set_tensor_parallel_attributes() self._set_tensor_parallel_attributes()
set_parallel_input(True) if self.gather_output:
set_parallel_input(False)
else:
set_parallel_input(True)
def reset_parameters(self, weight_initializer, bias_initializer) -> None: def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.out_features fan_in, fan_out = self.in_features, self.out_features
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment