Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
afd04dc7
Unverified
Commit
afd04dc7
authored
Jul 28, 2023
by
OlivierDehaene
Committed by
GitHub
Jul 28, 2023
Browse files
feat(server): update vllm version (#723)
parent
f848dece
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
21 additions
and
22 deletions
+21
-22
router/src/main.rs
router/src/main.rs
+5
-1
server/Makefile-vllm
server/Makefile-vllm
+1
-1
server/text_generation_server/utils/layers.py
server/text_generation_server/utils/layers.py
+15
-20
No files found.
router/src/main.rs
View file @
afd04dc7
...
@@ -233,6 +233,10 @@ fn main() -> Result<(), RouterError> {
...
@@ -233,6 +233,10 @@ fn main() -> Result<(), RouterError> {
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
);
);
}
}
if
max_total_tokens
as
u32
>
max_supported_batch_total_tokens
{
return
Err
(
RouterError
::
ArgumentValidation
(
format!
(
"`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_supported_batch_total_tokens}"
)));
}
max_supported_batch_total_tokens
max_supported_batch_total_tokens
}
}
};
};
...
@@ -270,7 +274,7 @@ fn main() -> Result<(), RouterError> {
...
@@ -270,7 +274,7 @@ fn main() -> Result<(), RouterError> {
ngrok_authtoken
,
ngrok_authtoken
,
ngrok_edge
,
ngrok_edge
,
)
)
.await
?
;
.await
?
;
Ok
(())
Ok
(())
})
})
}
}
...
...
server/Makefile-vllm
View file @
afd04dc7
vllm_commit :=
d284b831c17f42a8ea63369a06138325f73c4cf9
vllm_commit :=
084ca75d4271f8f67be731bc58e0d41d8e0afd3a
vllm:
vllm:
# Clone vllm
# Clone vllm
...
...
server/text_generation_server/utils/layers.py
View file @
afd04dc7
...
@@ -219,36 +219,31 @@ class TensorParallelHead(SuperLayer):
...
@@ -219,36 +219,31 @@ class TensorParallelHead(SuperLayer):
)
)
def
forward
(
self
,
input
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
input
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
not
self
.
should_gather
:
return
super
().
forward
(
input
)
world_size
=
self
.
process_group
.
size
()
world_size
=
self
.
process_group
.
size
()
if
len
(
input
.
shape
)
==
2
and
isinstance
(
self
.
linear
,
FastLinear
):
# Fast branch for single requests
if
(
self
.
should_gather
and
len
(
input
.
shape
)
==
2
and
isinstance
(
self
.
linear
,
FastLinear
)
and
input
.
shape
[
0
]
==
1
):
out_dim
=
self
.
linear
.
weight
.
shape
[
0
]
out_dim
=
self
.
linear
.
weight
.
shape
[
0
]
if
input
.
shape
[
0
]
==
1
:
world_out
=
input
.
new_empty
(
1
,
out_dim
*
world_size
)
world_out
=
input
.
new_empty
(
1
,
out_dim
*
world_size
)
local_out
=
input
.
new_empty
(
1
,
out_dim
)
local_out
=
input
.
new_empty
(
1
,
out_dim
)
gather_input
=
local_out
else
:
world_out
=
input
.
new_empty
(
out_dim
*
world_size
,
input
.
shape
[
0
])
gather_input
=
input
.
new_empty
(
out_dim
,
input
.
shape
[
0
])
local_out
=
gather_input
.
T
torch
.
mm
(
input
,
self
.
linear
.
weight
.
T
,
out
=
local_out
)
torch
.
mm
(
input
,
self
.
linear
.
weight
.
T
,
out
=
local_out
)
torch
.
distributed
.
all_gather_into_tensor
(
torch
.
distributed
.
all_gather_into_tensor
(
world_out
,
gather_inp
ut
,
group
=
self
.
process_group
world_out
,
local_o
ut
,
group
=
self
.
process_group
)
)
return
world_out
if
input
.
shape
[
0
]
==
1
:
return
world_out
return
world_out
.
T
output
=
super
().
forward
(
input
)
output
=
super
().
forward
(
input
)
world_output
=
[
if
not
self
.
should_gather
:
torch
.
empty_like
(
output
)
for
_
in
range
(
self
.
process_group
.
size
())
return
output
]
world_output
=
[
torch
.
empty_like
(
output
)
for
_
in
range
(
world_size
)]
torch
.
distributed
.
all_gather
(
world_output
,
output
,
group
=
self
.
process_group
)
torch
.
distributed
.
all_gather
(
world_output
,
output
,
group
=
self
.
process_group
)
world_output
=
torch
.
cat
(
world_output
,
dim
=-
1
)
world_output
=
torch
.
cat
(
world_output
,
dim
=-
1
)
return
world_output
return
world_output
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment