OpenDAS / text-generation-inference
Commit 0e9d249b
Authored Apr 29, 2023 by OlivierDehaene
Committed by GitHub on Apr 29, 2023

feat(benchmark): add support for private tokenizers (#262)

Parent: b0b97fd9
Showing 1 changed file with 15 additions and 3 deletions.

benchmark/src/main.rs (+15, -3)
@@ -5,7 +5,7 @@
 use clap::Parser;
 use std::path::Path;
 use text_generation_client::ShardedClient;
-use tokenizers::Tokenizer;
+use tokenizers::{FromPretrainedParameters, Tokenizer};
 use tracing_subscriber::layer::SubscriberExt;
 use tracing_subscriber::util::SubscriberInitExt;
 use tracing_subscriber::EnvFilter;
@@ -16,6 +16,8 @@ use tracing_subscriber::EnvFilter;
 struct Args {
     #[clap(short, long, env)]
     tokenizer_name: String,
+    #[clap(default_value = "main", long, env)]
+    revision: String,
     #[clap(short, long)]
     batch_size: Option<Vec<u32>>,
     #[clap(default_value = "10", short, long, env)]
@@ -36,6 +38,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
     // Pattern match configuration
     let Args {
         tokenizer_name,
+        revision,
         batch_size,
         sequence_length,
         decode_length,
@@ -59,10 +62,19 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         tracing::info!("Found local tokenizer");
         Tokenizer::from_file(local_path.join("tokenizer.json")).unwrap()
     } else {
+        tracing::info!("Downloading tokenizer");
+
+        // Parse Huggingface hub token
+        let auth_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok();
+
         // Download and instantiate tokenizer
         // We need to download it outside of the Tokio runtime
-        tracing::info!("Downloading tokenizer");
-        Tokenizer::from_pretrained(tokenizer_name.clone(), None).unwrap()
+        let params = FromPretrainedParameters {
+            revision,
+            auth_token,
+            ..Default::default()
+        };
+        Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params)).unwrap()
     };
     tracing::info!("Tokenizer loaded");
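
The last hunk reads an optional token from the environment and fills the remaining parameters through struct update syntax. The sketch below isolates that pattern using only the standard library. `HubParams`, its `extra` field, and `build_params` are hypothetical stand-ins introduced for illustration; they are not the `tokenizers` crate's `FromPretrainedParameters`, and the `"main"` default is assumed here to mirror the `revision` default declared in `Args`.

```rust
use std::env;

/// Hypothetical stand-in for the parameter struct used in the diff.
/// Only `revision` and `auth_token` come from the commit; `extra` is a
/// placeholder for whatever other fields the real struct carries.
#[derive(Debug, Clone, PartialEq)]
struct HubParams {
    revision: String,
    auth_token: Option<String>,
    extra: Vec<String>,
}

impl Default for HubParams {
    fn default() -> Self {
        HubParams {
            // Assumed default, matching the `revision` default in `Args`.
            revision: "main".to_string(),
            auth_token: None,
            extra: Vec::new(),
        }
    }
}

/// Set the two fields the caller cares about and take every other
/// field from `Default`, as the diff does with `..Default::default()`.
fn build_params(revision: String, auth_token: Option<String>) -> HubParams {
    HubParams {
        revision,
        auth_token,
        ..Default::default()
    }
}

fn main() {
    // `.ok()` maps an unset (or non-Unicode) variable to `None`,
    // so public tokenizers still load without a token.
    let auth_token = env::var("HUGGING_FACE_HUB_TOKEN").ok();
    let params = build_params("main".to_string(), auth_token);
    println!(
        "revision = {}, token set = {}",
        params.revision,
        params.auth_token.is_some()
    );
}
```

Because the token is an `Option<String>`, the same code path serves both public and private repositories: the token is forwarded only when the variable is set.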