Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
e496c9ba
Unverified
Commit
e496c9ba
authored
Jun 09, 2023
by
OlivierDehaene
Committed by
GitHub
Jun 09, 2023
Browse files
feat(server): optimize dist ops (#434)
parent
abd58ff8
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
38 additions
and
7 deletions
+38
-7
server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
...tion_server/models/custom_modeling/flash_neox_modeling.py
+2
-1
server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
...ration_server/models/custom_modeling/flash_rw_modeling.py
+4
-2
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
...erver/models/custom_modeling/flash_santacoder_modeling.py
+3
-1
server/text_generation_server/utils/layers.py
server/text_generation_server/utils/layers.py
+29
-3
No files found.
server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
View file @
e496c9ba
...
...
@@ -265,6 +265,7 @@ class FlashNeoXLayer(nn.Module):
mlp_output
=
self
.
mlp
(
ln2_hidden_states
)
intermediate
=
mlp_output
+
attn_output
if
self
.
process_group
.
size
()
>
1
:
torch
.
distributed
.
all_reduce
(
intermediate
,
group
=
self
.
process_group
)
return
intermediate
+
hidden_states
,
None
...
...
server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
View file @
e496c9ba
...
...
@@ -440,6 +440,7 @@ class FlashRWLayer(nn.Module):
mlp_output
=
self
.
mlp
(
ln_hidden_states
)
intermediate
=
mlp_output
+
attn_output
if
self
.
process_group
.
size
()
>
1
:
torch
.
distributed
.
all_reduce
(
intermediate
,
group
=
self
.
process_group
)
return
intermediate
,
residual
...
...
@@ -524,6 +525,7 @@ class FlashRWLargeLayer(nn.Module):
intermediate
=
attn_output
+
mlp_output
if
self
.
process_group
.
size
()
>
1
:
torch
.
distributed
.
all_reduce
(
intermediate
,
group
=
self
.
process_group
)
return
intermediate
,
residual
...
...
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
View file @
e496c9ba
...
...
@@ -346,6 +346,8 @@ class FlashSantacoderModel(nn.Module):
pre_allocate_past_size
:
Optional
[
int
]
=
None
,
):
hidden_states
=
self
.
wte
(
input_ids
)
+
self
.
wpe
(
position_ids
)
if
self
.
process_group
.
size
()
>
1
:
torch
.
distributed
.
all_reduce
(
hidden_states
,
group
=
self
.
process_group
)
# Prefill
...
...
server/text_generation_server/utils/layers.py
View file @
e496c9ba
...
...
@@ -158,8 +158,33 @@ class TensorParallelHead(SuperLayer):
)
def
forward
(
self
,
input
:
torch
.
Tensor
)
->
torch
.
Tensor
:
world_size
=
self
.
process_group
.
size
()
if
world_size
==
1
:
return
super
().
forward
(
input
)
if
len
(
input
.
shape
)
==
2
and
isinstance
(
self
.
linear
,
FastLinear
):
out_dim
=
self
.
linear
.
weight
.
shape
[
0
]
if
input
.
shape
[
0
]
==
1
:
world_out
=
input
.
new_empty
(
1
,
out_dim
*
world_size
)
local_out
=
input
.
new_empty
(
1
,
out_dim
)
gather_input
=
local_out
else
:
world_out
=
input
.
new_empty
(
out_dim
*
world_size
,
input
.
shape
[
0
])
gather_input
=
input
.
new_empty
(
out_dim
,
input
.
shape
[
0
])
local_out
=
gather_input
.
T
torch
.
mm
(
input
,
self
.
linear
.
weight
.
T
,
out
=
local_out
)
torch
.
distributed
.
all_gather_into_tensor
(
world_out
,
gather_input
,
group
=
self
.
process_group
)
if
input
.
shape
[
0
]
==
1
:
return
world_out
return
world_out
.
T
output
=
super
().
forward
(
input
)
# Logits are sharded, so we need to gather them
world_output
=
[
torch
.
empty_like
(
output
)
for
_
in
range
(
self
.
process_group
.
size
())
]
...
...
@@ -211,6 +236,7 @@ class TensorParallelRowLinear(SuperLayer):
def
forward
(
self
,
input
:
torch
.
Tensor
)
->
torch
.
Tensor
:
out
=
super
().
forward
(
input
)
if
self
.
process_group
.
size
()
>
1
:
torch
.
distributed
.
all_reduce
(
out
,
group
=
self
.
process_group
)
return
out
...
...
@@ -245,7 +271,7 @@ class TensorParallelEmbedding(nn.Module):
input
-
self
.
min_id
,
)
out
=
torch
.
nn
.
functional
.
embedding
(
input
,
self
.
weight
)
if
self
.
reduce
:
if
self
.
reduce
and
self
.
process_group
.
size
()
>
1
:
torch
.
distributed
.
all_reduce
(
out
,
group
=
self
.
process_group
)
return
out
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment