Commit 73a4144b authored by ver217's avatar ver217 Committed by Hongxin Liu
Browse files

[shardformer] fix embedding

parent 92230226
...@@ -214,6 +214,9 @@ class VocabParallelEmbedding1D(ParallelModule): ...@@ -214,6 +214,9 @@ class VocabParallelEmbedding1D(ParallelModule):
self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
# padding index
self.padding_idx = self._select_padding_idx(padding_idx)
# offset the seed with randomizer index and rank # offset the seed with randomizer index and rank
seed = torch.random.initial_seed() seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment