Unverified commit 2ad895a6, authored by OlivierDehaene, committed by GitHub

feat(server): allow gpt-neox models with odd vocab sizes to be sharded (#48)

parent 404ed7a1
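The diff below gates tensor-parallel sharding of the token embeddings and lm-head behind a `tp_embeddings` flag. As a minimal sketch of the underlying decision (not the repository's code; the helper name and numbers are illustrative), the flag can be thought of as a simple divisibility check on the vocabulary size:

    # Hypothetical sketch: shard the embeddings / lm-head only when the
    # vocabulary can be split into equal row blocks, one per shard.
    def use_tp_embeddings(vocab_size: int, world_size: int) -> bool:
        # An "odd" vocab size (not a multiple of world_size) cannot be cut into
        # equal blocks, so the embeddings stay replicated and only the
        # transformer layers are sharded.
        return vocab_size % world_size == 0


    use_tp_embeddings(50257, 4)  # False: 50257 % 4 != 0, keep embeddings unsharded
    use_tp_embeddings(50432, 4)  # True: 50432 rows split into four equal blocks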
@@ -26,7 +26,7 @@ to power Bloom, BloomZ and MT0-XXL api-inference widgets.
 - [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl)
 - ~~[Galactica](https://huggingface.co/facebook/galactica-120b)~~ (deactivated)
 - [SantaCoder](https://huggingface.co/bigcode/santacoder)
-- [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b): use `--revision refs/pr/13`
+- [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b): use `--revision pr/13`
 
 Other models are supported on a best effort basis using:
@@ -145,7 +145,7 @@ class GPTNeoxSharded(GPTNeox):
                         start = rank * block_size
                         stop = (rank + 1) * block_size
                         tensor = slice_[start:stop]
-                    elif name == "embed_out.weight":
+                    elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings:
                         size = slice_.get_shape()[0]
                         block_size = size // world_size
                         start = rank * block_size
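In the `embed_out.weight` branch above, each rank keeps a contiguous block of vocabulary rows. A small illustration of that slicing with a plain tensor (a sketch only; the actual loader reads slices from safetensors files):

    import torch

    # Sketch of the row sharding applied when tp_embeddings is enabled: rank r
    # keeps rows [r * block_size, (r + 1) * block_size) of the weight matrix.
    def shard_rows(weight: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
        size = weight.shape[0]
        block_size = size // world_size  # an even split only when size % world_size == 0
        start = rank * block_size
        stop = (rank + 1) * block_size
        return weight[start:stop]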
@@ -229,6 +229,7 @@ class GPTNeoxSharded(GPTNeox):
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ):
+        if self.model.gpt_neox.tp_embeddings:
             outputs = self.model.forward(
                 input_ids=input_ids,
                 attention_mask=attention_mask,
@@ -238,7 +239,14 @@ class GPTNeoxSharded(GPTNeox):
 
             # Logits are sharded, so we need to gather them
             logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
-            torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
+            torch.distributed.all_gather(
+                logits, outputs.logits, group=self.process_group
+            )
             logits = torch.cat(logits, dim=2)
 
             return logits, outputs.past_key_values
+        # While the model itself is sharded, the embeddings might not as they might not be dividable by num-shard
+        else:
+            return super(GPTNeoxSharded, self).forward(
+                input_ids, attention_mask, position_ids, past_key_values
+            )
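When the lm-head is sharded, each rank only produces logits for its slice of the vocabulary, so the full distribution has to be reassembled before sampling. A minimal sketch of that gather, assuming an already-initialized process group and per-rank logits of shape [batch, seq, vocab // world_size]:

    import torch
    import torch.distributed as dist

    # Sketch (not the library's code): collect every rank's logit shard and
    # stitch the shards back together along the vocabulary dimension (dim=2).
    def gather_logits(shard: torch.Tensor, world_size: int, group=None) -> torch.Tensor:
        buffers = [torch.empty_like(shard) for _ in range(world_size)]
        dist.all_gather(buffers, shard, group=group)
        return torch.cat(buffers, dim=2)

When the flag is off, the logits already cover the full vocabulary, so the added `else` branch simply defers to the unsharded base-class forward.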
@@ -91,7 +91,7 @@ class NextTokenChooser:
             top_p=pb.top_p,
             do_sample=pb.do_sample,
             seed=pb.seed,
-            device=str(device),
+            device=device,
         )
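The last hunk passes the `torch.device` object through directly instead of its string form. A tiny illustration of why that is sufficient (variable names are illustrative): tensor factories and `torch.Generator` accept a device object as-is.

    import torch

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # torch.Generator and tensor constructors accept a torch.device directly,
    # so round-tripping through str(device) is unnecessary.
    generator = torch.Generator(device=device)
    probs = torch.softmax(torch.randn(1, 8, device=device), dim=-1)
    next_token = torch.multinomial(probs, num_samples=1, generator=generator)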