OpenDAS / text-generation-inference

Commit 439fcaf8 (unverified), authored Feb 16, 2023 by OlivierDehaene, committed by GitHub on Feb 16, 2023

feat(router): add prometheus metrics scrape endpoint (#71)

Parent: 7b3d460d
Showing 8 changed files with 239 additions and 33 deletions (+239, -33).
Cargo.lock                                  +120   -0
router/Cargo.toml                             +2   -0
router/src/infer.rs                          +43  -11
router/src/queue.rs                           +3   -0
router/src/server.rs                         +58  -15
router/src/validation.rs                      +8   -4
server/text_generation/utils/convert.py       +2   -0
server/text_generation/utils/hub.py           +3   -3
Cargo.lock

@@ -8,6 +8,17 @@ version = "1.0.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"

+[[package]]
+name = "ahash"
+version = "0.7.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47"
+dependencies = [
+ "getrandom",
+ "once_cell",
+ "version_check",
+]
+
 [[package]]
 name = "aho-corasick"
 version = "0.7.20"

@@ -806,6 +817,9 @@ name = "hashbrown"
 version = "0.12.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
+dependencies = [
+ "ahash",
+]

 [[package]]
 name = "heck"

@@ -1093,6 +1107,15 @@ dependencies = [
  "cfg-if",
 ]

+[[package]]
+name = "mach"
+version = "0.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b823e83b2affd8f40a9ee8c29dbc56404c1e34cd2710921f2801e2cf29527afa"
+dependencies = [
+ "libc",
+]
+
 [[package]]
 name = "macro_rules_attribute"
 version = "0.1.3"

@@ -1139,6 +1162,64 @@ dependencies = [
  "autocfg",
 ]

+[[package]]
+name = "metrics"
+version = "0.20.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7b9b8653cec6897f73b519a43fba5ee3d50f62fe9af80b428accdcc093b4a849"
+dependencies = [
+ "ahash",
+ "metrics-macros",
+ "portable-atomic",
+]
+
+[[package]]
+name = "metrics-exporter-prometheus"
+version = "0.11.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8603921e1f54ef386189335f288441af761e0fc61bcb552168d9cedfe63ebc70"
+dependencies = [
+ "hyper",
+ "indexmap",
+ "ipnet",
+ "metrics",
+ "metrics-util",
+ "parking_lot",
+ "portable-atomic",
+ "quanta",
+ "thiserror",
+ "tokio",
+ "tracing",
+]
+
+[[package]]
+name = "metrics-macros"
+version = "0.6.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "731f8ecebd9f3a4aa847dfe75455e4757a45da40a7793d2f0b1f9b6ed18b23f3"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "metrics-util"
+version = "0.14.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f7d24dc2dbae22bff6f1f9326ffce828c9f07ef9cc1e8002e5279f845432a30a"
+dependencies = [
+ "crossbeam-epoch",
+ "crossbeam-utils",
+ "hashbrown",
+ "metrics",
+ "num_cpus",
+ "parking_lot",
+ "portable-atomic",
+ "quanta",
+ "sketches-ddsketch",
+]
+
 [[package]]
 name = "mime"
 version = "0.3.16"

@@ -1514,6 +1595,12 @@ version = "0.3.26"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160"

+[[package]]
+name = "portable-atomic"
+version = "0.3.19"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "26f6a7b87c2e435a3241addceeeff740ff8b7e76b74c13bf9acb17fa454ea00b"
+
 [[package]]
 name = "ppv-lite86"
 version = "0.2.17"

@@ -1618,6 +1705,22 @@ dependencies = [
  "prost",
 ]

+[[package]]
+name = "quanta"
+version = "0.10.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b7e31331286705f455e56cca62e0e717158474ff02b7936c1fa596d983f4ae27"
+dependencies = [
+ "crossbeam-utils",
+ "libc",
+ "mach",
+ "once_cell",
+ "raw-cpuid",
+ "wasi 0.10.0+wasi-snapshot-preview1",
+ "web-sys",
+ "winapi",
+]
+
 [[package]]
 name = "quote"
 version = "1.0.23"

@@ -1657,6 +1760,15 @@ dependencies = [
  "getrandom",
 ]

+[[package]]
+name = "raw-cpuid"
+version = "10.6.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c307f7aacdbab3f0adee67d52739a1d71112cc068d6fab169ddeb18e48877fad"
+dependencies = [
+ "bitflags",
+]
+
 [[package]]
 name = "rayon"
 version = "1.6.1"

@@ -1980,6 +2092,12 @@ dependencies = [
  "libc",
 ]

+[[package]]
+name = "sketches-ddsketch"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ceb945e54128e09c43d8e4f1277851bd5044c6fc540bbaa2ad888f60b3da9ae7"
+
 [[package]]
 name = "slab"
 version = "0.4.7"

@@ -2143,6 +2261,8 @@ dependencies = [
  "axum-tracing-opentelemetry",
  "clap 4.1.4",
  "futures",
+ "metrics",
+ "metrics-exporter-prometheus",
  "nohash-hasher",
  "opentelemetry",
  "opentelemetry-otlp",
router/Cargo.toml

@@ -19,6 +19,8 @@ axum-tracing-opentelemetry = "0.9.0"
 text-generation-client = { path = "client" }
 clap = { version = "4.1.4", features = ["derive", "env"] }
 futures = "0.3.26"
+metrics = "0.20.1"
+metrics-exporter-prometheus = { version = "0.11.0", features = [] }
 nohash-hasher = "0.2.0"
 opentelemetry = { version = "0.18.0", features = ["rt-tokio"] }
 opentelemetry-otlp = "0.11.0"
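The router now depends on the `metrics` facade crate plus the `metrics-exporter-prometheus` exporter. As a rough, minimal sketch (not code from this commit), the facade macros used throughout the rest of this diff look like the following; without a recorder installed they are no-ops, and the Prometheus recorder installed in router/src/server.rs is what turns them into scrapeable series:

// Illustrative only: the metrics 0.20 facade macros this diff relies on.
fn record_example(queue_len: usize, latency_seconds: f64) {
    // Monotonic counter with a label attached.
    metrics::increment_counter!("tgi_request_failure", "err" => "overloaded");
    // Gauge set to an absolute value.
    metrics::gauge!("tgi_queue_size", queue_len as f64);
    // Histogram observation.
    metrics::histogram!("tgi_request_duration", latency_seconds);
}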
router/src/infer.rs

@@ -3,7 +3,6 @@ use crate::validation::{Validation, ValidationError};
 use crate::GenerateRequest;
 use crate::{Entry, Queue, Token};
 use nohash_hasher::IntMap;
-use std::future::Future;
 use std::sync::Arc;
 use text_generation_client::{
     Batch, ClientError, GeneratedText, Generation, PrefillTokens, ShardedClient,

@@ -81,6 +80,7 @@ impl Infer {
             .limit_concurrent_requests
             .try_acquire_owned()
             .map_err(|err| {
+                metrics::increment_counter!("tgi_request_failure", "err" => "overloaded");
                 tracing::error!("{err}");
                 err
             })?;

@@ -172,6 +172,7 @@ impl Infer {
             })
         } else {
             let err = InferError::IncompleteGeneration;
+            metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
             tracing::error!("{err}");
             Err(err)
         }

@@ -201,7 +202,7 @@ async fn batching_task(
        // This batch might be smaller than the maximum batch size if there are not enough requests
        // waiting in the queue
        while let Some((mut entries, batch, span)) = queue.next_batch(None, max_batch_size).await {
-           let mut cached_batch = wrap_future(client.prefill(batch), &mut entries)
+           let mut cached_batch = prefill(&mut client, batch, &mut entries)
                .instrument(span)
                .await;
            let mut waiting_tokens = 1;

@@ -212,6 +213,7 @@ async fn batching_task(
            // Get current batch info
            let batch_size = batch.size;
            let mut batches = vec![batch];
+           metrics::gauge!("tgi_batch_current_size", batch_size as f64);

            // If the current batch is too small, we try to add more requests to it
            if batch_size <= limit_min_batch_size {

@@ -241,10 +243,9 @@ async fn batching_task(
                    });

                    // Generate one token for this new batch to have the attention past in cache
-                   let new_cached_batch = wrap_future(client.prefill(new_batch), &mut new_entries)
-                       .instrument(span)
-                       .await;
+                   let new_cached_batch = prefill(&mut client, new_batch, &mut new_entries)
+                       .instrument(span)
+                       .await;
                    // Reset waiting counter
                    waiting_tokens = 1;
                    // Extend current batch with the new batch

@@ -268,29 +269,59 @@ async fn batching_task(
                    entry.temp_span = Some(entry_batch_span);
                });

-           cached_batch = wrap_future(client.decode(batches), &mut entries)
+           cached_batch = decode(&mut client, batches, &mut entries)
                .instrument(next_batch_span)
                .await;
            waiting_tokens += 1;
        }
+       metrics::gauge!("tgi_batch_current_size", 0.0);
    }
 }

 /// Wrap a future inside a match statement to handle errors and send the responses to Infer
 #[instrument(skip_all)]
-async fn wrap_future(
-    future: impl Future<Output = Result<(Vec<Generation>, Option<Batch>), ClientError>>,
+async fn prefill(
+    client: &mut ShardedClient,
+    batch: Batch,
     entries: &mut IntMap<u64, Entry>,
 ) -> Option<Batch> {
-    match future.await {
+    let start_time = Instant::now();
+
+    match client.prefill(batch).await {
         Ok((generations, next_batch)) => {
             send_generations(generations, entries);
+            metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed(), "method" => "prefill");
+            metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
+            next_batch
+        }
+        // If we have an error, we discard the whole batch
+        Err(err) => {
+            send_errors(err, entries);
+            metrics::increment_counter!("tgi_batch_inference_failure", "method" => "prefill");
+            None
+        }
+    }
+}
+
+#[instrument(skip_all)]
+async fn decode(
+    client: &mut ShardedClient,
+    batches: Vec<Batch>,
+    entries: &mut IntMap<u64, Entry>,
+) -> Option<Batch> {
+    let start_time = Instant::now();
+
+    match client.decode(batches).await {
+        Ok((generations, next_batch)) => {
+            send_generations(generations, entries);
+            metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed(), "method" => "decode");
+            metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
             next_batch
         }
         // If we have an error, we discard the whole batch
         Err(err) => {
             send_errors(err, entries);
+            metrics::increment_counter!("tgi_batch_inference_failure", "method" => "decode");
             None
         }
     }

@@ -303,6 +334,7 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
         // Create and enter a span to link this function back to the entry
         let _send_error_span = info_span!(parent: entry.temp_span.as_ref().expect("batch_span is None. This is a bug."), "send_error").entered();
         let err = InferError::GenerationError(error.to_string());
+        metrics::increment_counter!("tgi_request_failure", "err" => "generation");
         tracing::error!("{err}");
         // unwrap_or is valid here as we don't care if the receiver is gone.
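The generic wrap_future helper is replaced by two concrete functions, prefill and decode, so that each phase can be timed and labelled separately in Prometheus. The bookkeeping they share is roughly the following (a condensed illustration under assumed generic types, not the code added above):

use std::time::Instant;

// Time one shard call and record its outcome under a "method" label, so that
// prefill and decode latencies end up as separate Prometheus series.
// `run_step` stands in for client.prefill(..) or client.decode(..).
async fn timed_step<T, E>(
    method: &'static str,
    run_step: impl std::future::Future<Output = Result<T, E>>,
) -> Option<T> {
    let start_time = Instant::now();
    match run_step.await {
        Ok(value) => {
            metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed(), "method" => method);
            metrics::increment_counter!("tgi_batch_inference_success", "method" => method);
            Some(value)
        }
        Err(_err) => {
            metrics::increment_counter!("tgi_batch_inference_failure", "method" => method);
            None
        }
    }
}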
router/src/queue.rs

@@ -132,6 +132,7 @@ impl State {
        // Push entry in the queue
        self.entries.push((self.next_id, entry));
        self.next_id += 1;
+       metrics::increment_gauge!("tgi_queue_size", 1.0);
    }

    // Get the next batch

@@ -190,6 +191,8 @@ impl State {
        // Increment batch id
        self.next_batch_id += 1;
+       metrics::gauge!("tgi_queue_size", self.entries.len() as f64);
+       metrics::histogram!("tgi_batch_next_size", batch.size as f64);

        Some((batch_entries, batch, next_batch_span))
    }
 }
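Note the two gauge styles: when an entry is pushed the queue size is bumped by a delta, and when the next batch is cut the remaining length is republished as an absolute value. A small sketch of the distinction (simplified types, not the router's actual State):

// Simplified sketch of the two gauge updates used above.
struct QueueState {
    entries: Vec<u64>,
}

impl QueueState {
    fn push(&mut self, id: u64) {
        self.entries.push(id);
        // Relative update: one more request is waiting.
        metrics::increment_gauge!("tgi_queue_size", 1.0);
    }

    fn take_batch(&mut self, batch_size: usize) -> Vec<u64> {
        let n = batch_size.min(self.entries.len());
        let batch: Vec<u64> = self.entries.drain(..n).collect();
        // Absolute update: publish the exact remaining queue length,
        // plus the size of the batch that was just formed.
        metrics::gauge!("tgi_queue_size", self.entries.len() as f64);
        metrics::histogram!("tgi_batch_next_size", batch.len() as f64);
        batch
    }
}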
router/src/server.rs

@@ -12,6 +12,7 @@ use axum::routing::{get, post};
 use axum::{Json, Router};
 use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
 use futures::Stream;
+use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
 use std::convert::Infallible;
 use std::net::SocketAddr;
 use text_generation_client::ShardedClient;

@@ -57,14 +58,14 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
    path = "/generate",
    request_body = GenerateRequest,
    responses(
-       (status = 200, description = "Generated Text", body = [GenerateResponse]),
-       (status = 424, description = "Generation Error", body = [ErrorResponse],
+       (status = 200, description = "Generated Text", body = GenerateResponse),
+       (status = 424, description = "Generation Error", body = ErrorResponse,
            example = json!({"error": "Request failed during generation"})),
-       (status = 429, description = "Model is overloaded", body = [ErrorResponse],
+       (status = 429, description = "Model is overloaded", body = ErrorResponse,
            example = json!({"error": "Model is overloaded"})),
-       (status = 422, description = "Input validation error", body = [ErrorResponse],
+       (status = 422, description = "Input validation error", body = ErrorResponse,
            example = json!({"error": "Input validation error"})),
-       (status = 500, description = "Incomplete generation", body = [ErrorResponse],
+       (status = 500, description = "Incomplete generation", body = ErrorResponse,
            example = json!({"error": "Incomplete generation"})),
    )
 )]

@@ -141,6 +142,18 @@ async fn generate(
    span.record("seed", format!("{:?}", response.generated_text.seed));
    tracing::info!("Output: {}", response.generated_text.text);

+   // Metrics
+   metrics::increment_counter!("tgi_request_success");
+   metrics::histogram!("tgi_request_duration", total_time);
+   metrics::histogram!("tgi_request_validation_duration", validation_time);
+   metrics::histogram!("tgi_request_queue_duration", queue_time);
+   metrics::histogram!("tgi_request_inference_duration", inference_time);
+   metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token);
+   metrics::histogram!("tgi_request_generated_tokens", response.generated_text.generated_tokens as f64);
+
    // Send response
    let response = GenerateResponse {
        generated_text: response.generated_text.text,

@@ -156,20 +169,20 @@ async fn generate(
    path = "/generate_stream",
    request_body = GenerateRequest,
    responses(
-       (status = 200, description = "Generated Text", body = [StreamResponse], content_type = "text/event-stream"),
-       (status = 424, description = "Generation Error", body = [ErrorResponse], content_type = "text/event-stream"),
+       (status = 200, description = "Generated Text", body = StreamResponse, content_type = "text/event-stream"),
+       (status = 424, description = "Generation Error", body = ErrorResponse,
+           example = json!({"error": "Request failed during generation"}), content_type = "text/event-stream"),
-       (status = 429, description = "Model is overloaded", body = [ErrorResponse], content_type = "text/event-stream"),
+       (status = 429, description = "Model is overloaded", body = ErrorResponse,
+           example = json!({"error": "Model is overloaded"}), content_type = "text/event-stream"),
-       (status = 422, description = "Input validation error", body = [ErrorResponse], content_type = "text/event-stream"),
+       (status = 422, description = "Input validation error", body = ErrorResponse,
+           example = json!({"error": "Input validation error"}), content_type = "text/event-stream"),
-       (status = 500, description = "Incomplete generation", body = [ErrorResponse], content_type = "text/event-stream"),
+       (status = 500, description = "Incomplete generation", body = ErrorResponse,
+           example = json!({"error": "Incomplete generation"}), content_type = "text/event-stream"),
    )
 )]
 #[instrument(

@@ -249,6 +262,15 @@ async fn generate_stream(
                            span.record("seed", format!("{:?}", generated_text.seed));
                            tracing::info!(parent: &span, "Output: {}", generated_text.text);

+                           // Metrics
+                           metrics::increment_counter!("tgi_request_success");
+                           metrics::histogram!("tgi_request_duration", total_time);
+                           metrics::histogram!("tgi_request_validation_duration", validation_time);
+                           metrics::histogram!("tgi_request_queue_duration", queue_time);
+                           metrics::histogram!("tgi_request_inference_duration", inference_time);
+                           metrics::histogram!("tgi_request_mean_time_per_token_duration", time_per_token);
+                           metrics::histogram!("tgi_request_generated_tokens", generated_text.generated_tokens as f64);
+
                            // StreamResponse
                            end_reached = true;
                            let stream_token = StreamResponse {

@@ -279,6 +301,7 @@ async fn generate_stream(
        // Skip if we already sent an error
        if !end_reached && !error {
            let err = InferError::IncompleteGeneration;
+           metrics::increment_counter!("tgi_request_failure", "err" => "incomplete");
            tracing::error!("{err}");
            yield Ok(Event::from(err))
        }

@@ -287,6 +310,17 @@ async fn generate_stream(
    Sse::new(stream).keep_alive(KeepAlive::default())
 }

+/// Prometheus metrics scrape endpoint
+#[utoipa::path(
+    get,
+    tag = "Text Generation Inference",
+    path = "/metrics",
+    responses((status = 200, description = "Prometheus Metrics", body = String))
+)]
+async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
+    prom_handle.render()
+}
+
 /// Serving method
 #[allow(clippy::too_many_arguments)]
 pub async fn run(

@@ -307,6 +341,7 @@ pub async fn run(
        paths(
            generate,
            generate_stream,
+           metrics,
        ),
        components(
            schemas(

@@ -350,6 +385,12 @@ pub async fn run(
        max_concurrent_requests,
    );

+   // Prometheus handler
+   let builder = PrometheusBuilder::new();
+   let prom_handle = builder
+       .install_recorder()
+       .expect("failed to install metrics recorder");
+
    // Create router
    let app = Router::new()
        .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))

@@ -359,6 +400,8 @@ pub async fn run(
        .route("/", get(health))
        .route("/health", get(health))
        .layer(Extension(infer))
+       .route("/metrics", get(metrics))
+       .layer(Extension(prom_handle))
        .layer(opentelemetry_tracing_layer());

    // Run server
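Taken together, server.rs installs the Prometheus recorder once in run() and exposes the rendered text format on GET /metrics via an axum Extension holding the PrometheusHandle. A minimal self-contained sketch of the same pattern (assuming axum 0.6 and metrics-exporter-prometheus 0.11; the address and binary layout are illustrative, not the router's):

use axum::{extract::Extension, routing::get, Router};
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};

// Render the current metric registry in Prometheus text exposition format.
async fn metrics(Extension(prom_handle): Extension<PrometheusHandle>) -> String {
    prom_handle.render()
}

#[tokio::main]
async fn main() {
    // Install the exporter as the global recorder; from now on every
    // metrics::counter!/gauge!/histogram! call is captured by it.
    let prom_handle = PrometheusBuilder::new()
        .install_recorder()
        .expect("failed to install metrics recorder");

    let app = Router::new()
        .route("/metrics", get(metrics))
        .layer(Extension(prom_handle));

    // Scrape with: curl http://127.0.0.1:3000/metrics
    axum::Server::bind(&"127.0.0.1:3000".parse().unwrap())
        .serve(app.into_make_service())
        .await
        .unwrap();
}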
router/src/validation.rs

@@ -13,7 +13,7 @@ use tracing::{instrument, Span};
 #[derive(Debug, Clone)]
 pub struct Validation {
     /// Channel to communicate with the background validation task
-    sender: mpsc::Sender<ValidationRequest>,
+    sender: mpsc::UnboundedSender<ValidationRequest>,
 }

 impl Validation {

@@ -25,7 +25,7 @@ impl Validation {
         max_total_tokens: usize,
     ) -> Self {
         // Create channel
-        let (validation_sender, validation_receiver) = mpsc::channel(128);
+        let (validation_sender, validation_receiver) = mpsc::unbounded_channel();

         // Launch background validation task
         tokio::spawn(validation_task(

@@ -54,7 +54,6 @@ impl Validation {
         // Unwrap is safe here
         self.sender
             .send((request, sender, Span::current()))
-            .await
             .unwrap();

         // Await on response channel
         // Unwrap is safe here

@@ -70,7 +69,7 @@ async fn validation_task(
     max_stop_sequences: usize,
     max_input_length: usize,
     max_total_tokens: usize,
-    mut receiver: mpsc::Receiver<ValidationRequest>,
+    mut receiver: mpsc::UnboundedReceiver<ValidationRequest>,
 ) {
     let mut workers_senders = Vec::with_capacity(workers);

@@ -131,6 +130,7 @@ fn validation_worker(
                    &mut rng,
                )
                .map_err(|err| {
+                   metrics::increment_counter!("tgi_request_failure", "err" => "validation");
                    tracing::error!("{err}");
                    err
                }),

@@ -214,6 +214,7 @@ fn validate(
        Ok(encoding) => {
            let input_length = encoding.len();
            let total_tokens = input_length + max_new_tokens as usize;
            if input_length > max_input_length {
                Err(ValidationError::InputLength(max_input_length, input_length))
            } else if total_tokens > max_total_tokens {

@@ -237,6 +238,9 @@ fn validate(
                stop_sequences,
            };

+           metrics::histogram!("tgi_request_input_length", input_length as f64);
+           metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);
+
            Ok(ValidGenerateRequest {
                inputs: request.inputs,
                input_length: input_length as u32,
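The validation channel also changes flavor: the bounded mpsc::channel(128) becomes an unbounded channel, so submitting a validation request no longer awaits capacity (the .await removed above). Roughly, as a standalone sketch rather than router code (whether backpressure is needed at this hop is governed elsewhere, e.g. by the concurrent-request limit in infer.rs):

use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    // Before: bounded channel; send is async and applies backpressure
    // once 128 messages are in flight.
    let (bounded_tx, mut bounded_rx) = mpsc::channel::<u32>(128);
    bounded_tx.send(1).await.unwrap();
    assert_eq!(bounded_rx.recv().await, Some(1));

    // After: unbounded channel; send is synchronous and only fails
    // if the receiver has been dropped.
    let (tx, mut rx) = mpsc::unbounded_channel::<u32>();
    tx.send(2).unwrap();
    assert_eq!(rx.recv().await, Some(2));
}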
server/text_generation/utils/convert.py

@@ -49,6 +49,8 @@ def convert_file(pt_file: Path, st_file: Path):
     """
     Convert a pytorch file to a safetensors file
     """
+    logger.info(f"Convert {pt_file} to {st_file}.")
+
     pt_state = torch.load(pt_file, map_location="cpu")
     if "state_dict" in pt_state:
         pt_state = pt_state["state_dict"]
server/text_generation/utils/hub.py

@@ -132,9 +132,9 @@ def download_weights(
         local_file = try_to_load_from_cache(model_id, revision, filename)
         if local_file is not None:
             logger.info(f"File {filename} already present in cache.")
-            return local_file
+            return Path(local_file)

-        logger.info(f"Starting {filename} download.")
+        logger.info(f"Download file: {filename}")
         start_time = time.time()
         local_file = hf_hub_download(
             filename=filename,

@@ -143,7 +143,7 @@ def download_weights(
             local_files_only=False,
         )
         logger.info(
-            f"Downloaded {filename} at {local_file} in {timedelta(seconds=int(time.time() - start_time))}."
+            f"Downloaded {local_file} in {timedelta(seconds=int(time.time() - start_time))}."
         )
         return Path(local_file)