Unverified Commit 6806998b authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Bugfix] Fix embedding to support 2D inputs (#5829)

parent 515080ad
......@@ -310,7 +310,7 @@ class VocabParallelEmbedding(torch.nn.Module):
output_parallel = F.embedding(masked_input.long(), self.weight)
# Mask the output embedding.
if self.tp_size > 1:
output_parallel.masked_fill_(input_mask.unsqueeze(1), 0)
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
# Reduce across all the model parallel GPUs.
output = tensor_model_parallel_all_reduce(output_parallel)
return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment