OpenDAS / text-generation-inference
Commit 2ad895a6 (unverified), authored Feb 01, 2023 by OlivierDehaene, committed by GitHub on Feb 01, 2023

feat(server): allow gpt-neox models with odd vocab sizes to be sharded (#48)
Parent: 404ed7a1

Changes: 3 changed files with 27 additions and 19 deletions

  README.md                                    +1 / -1
  server/text_generation/models/gpt_neox.py   +25 / -17
  server/text_generation/utils.py              +1 / -1
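For context: tensor-parallel embeddings split the vocabulary dimension into `world_size` equal blocks, so a vocab size that is not divisible by the number of shards would silently drop rows under naive slicing. A minimal sketch of the failure mode (the numbers are illustrative, not taken from the commit):

```python
# Illustrative arithmetic only: why "odd" vocab sizes break naive sharding.
vocab_size = 50257  # hypothetical vocab size, not divisible by the shard count
world_size = 4

block_size = vocab_size // world_size  # 12564 rows per shard
covered = block_size * world_size      # 50256: one row would never be loaded
assert covered != vocab_size           # vocab_size % world_size rows are lost
```

The commit keeps such models shardable by disabling tensor-parallel embeddings for them and falling back to the unsharded forward pass, as the gpt_neox.py diff below shows.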
README.md

```diff
@@ -26,7 +26,7 @@ to power Bloom, BloomZ and MT0-XXL api-inference widgets.
 - [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl)
 - ~~[Galactica](https://huggingface.co/facebook/galactica-120b)~~ (deactivated)
 - [SantaCoder](https://huggingface.co/bigcode/santacoder)
-- [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b): use `--revision refs/pr/13`
+- [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b): use `--revision pr/13`
 
 Other models are supported on a best effort basis using:
```
server/text_generation/models/gpt_neox.py

```diff
@@ -145,7 +145,7 @@ class GPTNeoxSharded(GPTNeox):
                     start = rank * block_size
                     stop = (rank + 1) * block_size
                     tensor = slice_[start:stop]
-                elif name == "embed_out.weight":
+                elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings:
                     size = slice_.get_shape()[0]
                     block_size = size // world_size
                     start = rank * block_size
```
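As a standalone illustration of the slicing pattern in this hunk, here is a hypothetical helper (not part of the repository); in the real loader `slice_` is a lazy safetensors slice exposing `get_shape()`, stood in for below by a plain tensor:

```python
import torch

def shard_rows(tensor: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Take this rank's block of rows, mirroring the slicing in the diff."""
    size = tensor.shape[0]               # the real code reads slice_.get_shape()[0]
    block_size = size // world_size      # only exact when size % world_size == 0
    start = rank * block_size
    stop = (rank + 1) * block_size
    return tensor[start:stop]
```

The added `and model.gpt_neox.tp_embeddings` guard means `embed_out.weight` is only sliced this way when tensor-parallel embeddings are enabled; for odd vocab sizes the full tensor is kept on every rank instead.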
The second hunk appears to be a whitespace-only reformat of the int8 quantization guard (the removed and added lines are textually identical apart from layout):

```diff
@@ -176,9 +176,9 @@ class GPTNeoxSharded(GPTNeox):
                 )
                 if (
-                    type(module) in [TensorParallelRowLinear, TensorParallelColumnLinear]
-                    and param_name == "weight"
+                    type(module) in [TensorParallelRowLinear, TensorParallelColumnLinear]
+                    and param_name == "weight"
                 ):
                     tensor = Int8Params(
                         tensor,
```
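For readability, a self-contained sketch of that guard follows; the stub classes and the `has_fp16_weights=False` argument are my assumptions, since the diff truncates the `Int8Params(...)` call:

```python
import torch
from torch import nn

# Hypothetical stand-ins for the tensor-parallel layers defined in gpt_neox.py
class TensorParallelRowLinear(nn.Linear): ...
class TensorParallelColumnLinear(nn.Linear): ...

def maybe_quantize(module: nn.Module, param_name: str, tensor: torch.Tensor):
    # Only the weights of the sharded linear layers are int8-quantized
    if (
        type(module) in [TensorParallelRowLinear, TensorParallelColumnLinear]
        and param_name == "weight"
    ):
        from bitsandbytes.nn import Int8Params  # optional dependency

        return Int8Params(tensor, has_fp16_weights=False)  # trailing args assumed
    return tensor
```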
The third hunk wraps the sharded-logits path in a `tp_embeddings` check and adds an unsharded fallback:

```diff
@@ -229,16 +229,24 @@ class GPTNeoxSharded(GPTNeox):
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ):
-        outputs = self.model.forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            past_key_values=past_key_values,
-            use_cache=True,
-        )
+        if self.model.gpt_neox.tp_embeddings:
+            outputs = self.model.forward(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                past_key_values=past_key_values,
+                use_cache=True,
+            )
 
-        # Logits are sharded, so we need to gather them
-        logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
-        torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
-        logits = torch.cat(logits, dim=2)
+            # Logits are sharded, so we need to gather them
+            logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
+            torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
+            logits = torch.cat(logits, dim=2)
 
-        return logits, outputs.past_key_values
+            return logits, outputs.past_key_values
+        # While the model itself is sharded, the embeddings might not be, as they
+        # might not be divisible by num-shard
+        else:
+            return super(GPTNeoxSharded, self).forward(
+                input_ids, attention_mask, position_ids, past_key_values
+            )
```
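In the `tp_embeddings` branch each rank produces logits over its slice of the vocabulary, so they are gathered and concatenated along `dim=2` (the vocab dimension of a `[batch, seq_len, vocab]` tensor). A minimal standalone sketch of that gather, assuming an initialized process group:

```python
import torch
import torch.distributed as dist

def gather_sharded_logits(
    shard: torch.Tensor, world_size: int, group=None
) -> torch.Tensor:
    # shard: [batch, seq_len, vocab_size // world_size] on every rank
    buffers = [torch.empty_like(shard) for _ in range(world_size)]
    dist.all_gather(buffers, shard, group=group)
    return torch.cat(buffers, dim=2)  # rebuild the full vocab dimension
```

When `tp_embeddings` is off (the odd-vocab case), `embed_out` is replicated on every rank, each rank already holds full logits, and the method simply defers to the parent class.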
server/text_generation/utils.py

```diff
@@ -91,7 +91,7 @@ class NextTokenChooser:
             top_p=pb.top_p,
             do_sample=pb.do_sample,
             seed=pb.seed,
-            device=str(device),
+            device=device,
         )
```
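The utils.py change passes the `torch.device` object through rather than its string form. Both are accepted by torch APIs, so this is a simplification; a trivial illustration (the `Generator` here is just an example consumer, not necessarily what `NextTokenChooser` does with the argument):

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Equivalent for most torch APIs; the object form avoids a parse round-trip
g1 = torch.Generator(device=device)
g2 = torch.Generator(device=str(device))
```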