Unverified commit 3fc8a204, authored by アマデウス and committed by GitHub
Browse files

Corrected 3D vocab parallel embedding (#707)

parent ee112fe1
......@@ -525,7 +525,7 @@ class VocabParallelClassifier3D(ParallelLayer):
def _set_tensor_parallel_attributes(self) -> None:
if self.has_weight:
-set_tensor_parallel_attribute_by_partition(self.weight, self.depth**2)
+set_tensor_parallel_attribute_by_partition(self.weight, self.depth**3)
if self.bias is not None:
set_tensor_parallel_attribute_by_partition(self.bias, self.depth)
......@@ -1048,7 +1048,7 @@ class VocabParallelEmbedding3D(torch.nn.Module):
env.vocab_parallel = True
def _set_tensor_parallel_attributes(self):
-set_tensor_parallel_attribute_by_partition(self.weight, self.depth**2)
+set_tensor_parallel_attribute_by_partition(self.weight, self.depth**3)
def reset_parameters(self, weight_initializer) -> None:
with seed(ParallelMode.TENSOR):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment