Unverified Commit 10dd8226 authored by ver217's avatar ver217 Committed by GitHub
Browse files

add gather_output for VocabParallelClassifier1D (#1569)

parent e615cfc3
...@@ -283,11 +283,13 @@ class VocabParallelClassifier1D(ParallelLayer):
weight: Parameter = None, weight: Parameter = None,
bias: bool = True, bias: bool = True,
dtype: torch.dtype = None, dtype: torch.dtype = None,
gather_output: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__() super().__init__()
self.in_features = in_features self.in_features = in_features
self.num_classes = num_classes self.num_classes = num_classes
self.gather_output = gather_output
self.parallel_input = get_parallel_input() self.parallel_input = get_parallel_input()
# Divide the weight matrix along the last dimension. # Divide the weight matrix along the last dimension.
...@@ -382,7 +384,12 @@ class VocabParallelClassifier1D(ParallelLayer):
# Set up backprop all-reduce. # Set up backprop all-reduce.
input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D) input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
# Matrix multiply. # Matrix multiply.
output = F.linear(input_parallel, self.weight, self.bias) output_parallel = F.linear(input_parallel, self.weight, self.bias)
if self.gather_output:
# All-gather across the partitions.
output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
else:
output = output_parallel
return output return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment