Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
0d97a93c
Unverified
Commit
0d97a93c
authored
Jul 01, 2024
by
drbh
Committed by
GitHub
Jul 01, 2024
Browse files
feat: download lora adapter weights from launcher (#2140)
parent
25f57e2e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
37 additions
and
9 deletions
+37
-9
launcher/src/main.rs
launcher/src/main.rs
+37
-9
No files found.
launcher/src/main.rs
View file @
0d97a93c
...
@@ -898,13 +898,20 @@ enum LauncherError {
...
@@ -898,13 +898,20 @@ enum LauncherError {
WebserverCannotStart
,
WebserverCannotStart
,
}
}
fn
download_convert_model
(
args
:
&
Args
,
running
:
Arc
<
AtomicBool
>
)
->
Result
<
(),
LauncherError
>
{
fn
download_convert_model
(
model_id
:
&
str
,
revision
:
Option
<&
str
>
,
trust_remote_code
:
bool
,
huggingface_hub_cache
:
Option
<&
str
>
,
weights_cache_override
:
Option
<&
str
>
,
running
:
Arc
<
AtomicBool
>
,
)
->
Result
<
(),
LauncherError
>
{
// Enter download tracing span
// Enter download tracing span
let
_
span
=
tracing
::
span!
(
tracing
::
Level
::
INFO
,
"download"
)
.entered
();
let
_
span
=
tracing
::
span!
(
tracing
::
Level
::
INFO
,
"download"
)
.entered
();
let
mut
download_args
=
vec!
[
let
mut
download_args
=
vec!
[
"download-weights"
.to_string
(),
"download-weights"
.to_string
(),
args
.
model_id
.to_string
(),
model_id
.to_string
(),
"--extension"
.to_string
(),
"--extension"
.to_string
(),
".safetensors"
.to_string
(),
".safetensors"
.to_string
(),
"--logger-level"
.to_string
(),
"--logger-level"
.to_string
(),
...
@@ -913,13 +920,13 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
...
@@ -913,13 +920,13 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
];
];
// Model optional revision
// Model optional revision
if
let
Some
(
revision
)
=
&
args
.
revision
{
if
let
Some
(
revision
)
=
&
revision
{
download_args
.push
(
"--revision"
.to_string
());
download_args
.push
(
"--revision"
.to_string
());
download_args
.push
(
revision
.to_string
())
download_args
.push
(
revision
.to_string
())
}
}
// Trust remote code for automatic peft fusion
// Trust remote code for automatic peft fusion
if
args
.
trust_remote_code
{
if
trust_remote_code
{
download_args
.push
(
"--trust-remote-code"
.to_string
());
download_args
.push
(
"--trust-remote-code"
.to_string
());
}
}
...
@@ -934,7 +941,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
...
@@ -934,7 +941,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
// If huggingface_hub_cache is set, pass it to the download process
// If huggingface_hub_cache is set, pass it to the download process
// Useful when running inside a docker container
// Useful when running inside a docker container
if
let
Some
(
ref
huggingface_hub_cache
)
=
args
.
huggingface_hub_cache
{
if
let
Some
(
ref
huggingface_hub_cache
)
=
huggingface_hub_cache
{
envs
.push
((
"HUGGINGFACE_HUB_CACHE"
.into
(),
huggingface_hub_cache
.into
()));
envs
.push
((
"HUGGINGFACE_HUB_CACHE"
.into
(),
huggingface_hub_cache
.into
()));
};
};
...
@@ -952,7 +959,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
...
@@ -952,7 +959,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
// If args.weights_cache_override is some, pass it to the download process
// If args.weights_cache_override is some, pass it to the download process
// Useful when running inside a HuggingFace Inference Endpoint
// Useful when running inside a HuggingFace Inference Endpoint
if
let
Some
(
weights_cache_override
)
=
&
args
.
weights_cache_override
{
if
let
Some
(
weights_cache_override
)
=
&
weights_cache_override
{
envs
.push
((
envs
.push
((
"WEIGHTS_CACHE_OVERRIDE"
.into
(),
"WEIGHTS_CACHE_OVERRIDE"
.into
(),
weights_cache_override
.into
(),
weights_cache_override
.into
(),
...
@@ -960,7 +967,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
...
@@ -960,7 +967,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
};
};
// Start process
// Start process
tracing
::
info!
(
"Starting download process
.
"
);
tracing
::
info!
(
"Starting
check and
download process
for {model_id}
"
);
let
mut
download_process
=
match
Command
::
new
(
"text-generation-server"
)
let
mut
download_process
=
match
Command
::
new
(
"text-generation-server"
)
.args
(
download_args
)
.args
(
download_args
)
.env_clear
()
.env_clear
()
...
@@ -1002,7 +1009,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
...
@@ -1002,7 +1009,7 @@ fn download_convert_model(args: &Args, running: Arc<AtomicBool>) -> Result<(), L
loop
{
loop
{
if
let
Some
(
status
)
=
download_process
.try_wait
()
.unwrap
()
{
if
let
Some
(
status
)
=
download_process
.try_wait
()
.unwrap
()
{
if
status
.success
()
{
if
status
.success
()
{
tracing
::
info!
(
"Successfully downloaded weights
.
"
);
tracing
::
info!
(
"Successfully downloaded weights
for {model_id}
"
);
break
;
break
;
}
}
...
@@ -1557,7 +1564,28 @@ fn main() -> Result<(), LauncherError> {
...
@@ -1557,7 +1564,28 @@ fn main() -> Result<(), LauncherError> {
.expect
(
"Error setting Ctrl-C handler"
);
.expect
(
"Error setting Ctrl-C handler"
);
// Download and convert model weights
// Download and convert model weights
download_convert_model
(
&
args
,
running
.clone
())
?
;
download_convert_model
(
&
args
.model_id
,
args
.revision
.as_deref
(),
args
.trust_remote_code
,
args
.huggingface_hub_cache
.as_deref
(),
args
.weights_cache_override
.as_deref
(),
running
.clone
(),
)
?
;
// Download and convert lora adapters if any
if
let
Some
(
lora_adapters
)
=
&
args
.lora_adapters
{
for
adapter
in
lora_adapters
.split
(
','
)
{
download_convert_model
(
adapter
,
None
,
args
.trust_remote_code
,
args
.huggingface_hub_cache
.as_deref
(),
args
.weights_cache_override
.as_deref
(),
running
.clone
(),
)
?
;
}
}
if
!
running
.load
(
Ordering
::
SeqCst
)
{
if
!
running
.load
(
Ordering
::
SeqCst
)
{
// Launcher was asked to stop
// Launcher was asked to stop
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment