OpenDAS / text-generation-inference / Commits

Commit 343437c7 (unverified)
feat(router): add device and dtype info (#215)
Authored Apr 21, 2023 by OlivierDehaene; committed via GitHub on Apr 21, 2023.
Parent: ac8c0f6f
Changes: 21 files in total. Showing 20 changed files with 117 additions and 15 deletions on this page (the remaining file appears on the next page of the diff).
proto/generate.proto  +11 -0
router/client/src/client.rs  +8 -0
router/client/src/lib.rs  +1 -0
router/client/src/sharded_client.rs  +12 -1
router/src/lib.rs  +5 -1
router/src/main.rs  +9 -3
router/src/server.rs  +12 -4
server/text_generation_server/models/bloom.py  +5 -1
server/text_generation_server/models/causal_lm.py  +5 -1
server/text_generation_server/models/flash_causal_lm.py  +5 -1
server/text_generation_server/models/flash_llama.py  +4 -0
server/text_generation_server/models/flash_neox.py  +2 -0
server/text_generation_server/models/flash_santacoder.py  +7 -1
server/text_generation_server/models/galactica.py  +2 -0
server/text_generation_server/models/gpt_neox.py  +2 -0
server/text_generation_server/models/model.py  +13 -0
server/text_generation_server/models/opt.py  +2 -0
server/text_generation_server/models/santacoder.py  +5 -1
server/text_generation_server/models/seq2seq_lm.py  +5 -1
server/text_generation_server/models/t5.py  +2 -0
proto/generate.proto

@@ -3,6 +3,8 @@ syntax = "proto3";
 package generate.v1;

 service TextGenerationService {
+    /// Model Info
+    rpc Info (InfoRequest) returns (InfoResponse) {}
     /// Service discovery
     rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {}
     /// Empties batch cache
@@ -13,6 +15,15 @@ service TextGenerationService {
     rpc Decode (DecodeRequest) returns (DecodeResponse);
 }

+/// Empty request
+message InfoRequest {}
+
+message InfoResponse {
+    bool requires_padding = 1;
+    string dtype = 2;
+    string device_type = 3;
+}
+
 /// Empty request
 message ServiceDiscoveryRequest {}
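The new Info RPC lets the router ask a shard for its runtime characteristics. On the Python side the reply is built by the info property added to the base Model class (see server/text_generation_server/models/model.py further down); the gRPC servicer that wires the RPC up is not among the files shown on this page, so the following is only a minimal sketch of what such a handler could look like, assuming the protoc-generated generate_pb2 / generate_pb2_grpc modules:

# Sketch only: the servicer change itself is not in the files listed on this page.
# Assumes the protoc-generated modules and a model exposing the new info property.
from text_generation_server.pb import generate_pb2, generate_pb2_grpc


class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
    def __init__(self, model):
        self.model = model

    async def Info(
        self, request: generate_pb2.InfoRequest, context
    ) -> generate_pb2.InfoResponse:
        # Model.info already packs requires_padding, dtype and device_type.
        return self.model.info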
router/client/src/client.rs

@@ -54,6 +54,14 @@ impl Client {
         Ok(urls)
     }

+    /// Get model info
+    #[instrument(skip(self))]
+    pub async fn info(&mut self) -> Result<InfoResponse> {
+        let request = tonic::Request::new(InfoRequest {}).inject_context();
+        let response = self.stub.info(request).await?.into_inner();
+        Ok(response)
+    }
+
     /// Clear the past generations cache
     #[instrument(skip(self))]
     pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
router/client/src/lib.rs

@@ -6,6 +6,7 @@ mod pb;
 mod sharded_client;

 pub use client::Client;
+pub use pb::generate::v1::InfoResponse as ShardInfo;
 pub use pb::generate::v1::{
     Batch, FinishReason, GeneratedText, Generation, NextTokenChooserParameters, PrefillTokens,
     Request, StoppingCriteriaParameters,
router/client/src/sharded_client.rs

 /// Multi shard Client
 use crate::Result;
-use crate::{Batch, Client, Generation};
+use crate::{Batch, Client, Generation, ShardInfo};
 use futures::future::join_all;
 use tonic::transport::Uri;
 use tracing::instrument;

@@ -37,6 +37,17 @@ impl ShardedClient {
         Self::from_master_client(master_client).await
     }

+    /// Get the model info
+    #[instrument(skip(self))]
+    pub async fn info(&mut self) -> Result<ShardInfo> {
+        let futures: Vec<_> = self
+            .clients
+            .iter_mut()
+            .map(|client| client.info())
+            .collect();
+        join_all(futures).await.pop().unwrap()
+    }
+
     /// Clear the past generations cache
     #[instrument(skip(self))]
     pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> {
router/src/lib.rs

@@ -12,7 +12,7 @@ use validation::Validation;
 /// Hub type
 #[derive(Clone, Debug, Deserialize)]
-pub struct ModelInfo {
+pub struct HubModelInfo {
     #[serde(rename(deserialize = "id"))]
     pub model_id: String,
     pub sha: Option<String>,
@@ -25,6 +25,10 @@ pub struct Info {
     pub model_id: String,
     #[schema(nullable = true, example = "e985a63cdc139290c5f700ff1929f0b5942cced2")]
     pub model_sha: Option<String>,
+    #[schema(example = "torch.float16")]
+    pub model_dtype: String,
+    #[schema(example = "cuda")]
+    pub model_device_type: String,
     #[schema(nullable = true, example = "text-generation")]
     pub model_pipeline_tag: Option<String>,
     #[schema(example = "0.5.0")]
router/src/main.rs

@@ -10,7 +10,7 @@ use opentelemetry_otlp::WithExportConfig;
 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
 use std::path::Path;
 use text_generation_client::ShardedClient;
-use text_generation_router::{server, ModelInfo};
+use text_generation_router::{server, HubModelInfo};
 use tokenizers::{FromPretrainedParameters, Tokenizer};
 use tower_http::cors::AllowOrigin;
 use tracing_subscriber::layer::SubscriberExt;
@@ -128,7 +128,7 @@ fn main() -> Result<(), std::io::Error> {
     // Get Model info
     let model_info = match local_model {
-        true => ModelInfo {
+        true => HubModelInfo {
             model_id: tokenizer_name.clone(),
             sha: None,
             pipeline_tag: None,
@@ -154,6 +154,11 @@ fn main() -> Result<(), std::io::Error> {
                 .clear_cache(None)
                 .await
                 .expect("Unable to clear cache");
+            // Get info from the shard
+            let shard_info = sharded_client
+                .info()
+                .await
+                .expect("Unable to get shard info");
             tracing::info!("Connected");

             // Binds on localhost
@@ -162,6 +167,7 @@ fn main() -> Result<(), std::io::Error> {
             // Run server
             server::run(
                 model_info,
+                shard_info,
                 compat_return_full_text,
                 max_concurrent_requests,
                 max_best_of,
@@ -237,7 +243,7 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
 }

 /// get model info from the Huggingface Hub
-pub async fn get_model_info(model_id: &str, revision: &str, token: Option<String>) -> ModelInfo {
+pub async fn get_model_info(model_id: &str, revision: &str, token: Option<String>) -> HubModelInfo {
     let client = reqwest::Client::new();
     let mut builder = client.get(format!(
         "https://huggingface.co/api/models/{model_id}/revision/{revision}"
router/src/server.rs

@@ -3,7 +3,7 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
-    GenerateParameters, GenerateRequest, GenerateResponse, Infer, Info, ModelInfo, PrefillToken,
+    GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken,
     StreamDetails, StreamResponse, Token, Validation,
 };
 use axum::extract::Extension;
@@ -18,7 +18,7 @@ use futures::Stream;
 use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
 use std::convert::Infallible;
 use std::net::SocketAddr;
-use text_generation_client::ShardedClient;
+use text_generation_client::{ShardInfo, ShardedClient};
 use tokenizers::Tokenizer;
 use tokio::signal;
 use tokio::time::Instant;
@@ -78,13 +78,19 @@ async fn compat_generate(
     responses((status = 200, description = "Served model info", body = Info))
 )]
 #[instrument]
-async fn get_model_info(model_info: Extension<ModelInfo>) -> Json<Info> {
+async fn get_model_info(
+    model_info: Extension<HubModelInfo>,
+    shard_info: Extension<ShardInfo>,
+) -> Json<Info> {
     let model_info = model_info.0;
+    let shard_info = shard_info.0;
     let info = Info {
         version: env!("CARGO_PKG_VERSION"),
         sha: option_env!("VERGEN_GIT_SHA"),
         model_id: model_info.model_id,
         model_sha: model_info.sha,
+        model_dtype: shard_info.dtype,
+        model_device_type: shard_info.device_type,
         model_pipeline_tag: model_info.pipeline_tag,
     };
     Json(info)
@@ -497,7 +503,8 @@ async fn metrics(prom_handle: Extension<PrometheusHandle>) -> String {
 /// Serving method
 #[allow(clippy::too_many_arguments)]
 pub async fn run(
-    model_info: ModelInfo,
+    model_info: HubModelInfo,
+    shard_info: ShardInfo,
     compat_return_full_text: bool,
     max_concurrent_requests: usize,
     max_best_of: usize,
@@ -641,6 +648,7 @@ pub async fn run(
         // Prometheus metrics route
         .route("/metrics", get(metrics))
         .layer(Extension(model_info))
+        .layer(Extension(shard_info))
         .layer(Extension(compat_return_full_text))
         .layer(Extension(infer))
         .layer(Extension(prom_handle))
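With the router changes above, the /info route now reports the shard's dtype and device type next to the Hub metadata. A quick usage sketch (assuming a server already running on the router's default port 3000, which is not part of this diff):

# Usage sketch, not part of the commit: query the extended /info route.
# Assumes a text-generation-inference router listening on localhost:3000.
import requests

info = requests.get("http://localhost:3000/info").json()
# Fields added by this commit; actual values depend on the deployed model.
print(info["model_dtype"])        # e.g. "torch.float16"
print(info["model_device_type"])  # e.g. "cuda"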
server/text_generation_server/models/bloom.py

@@ -100,7 +100,11 @@ class BLOOMSharded(BLOOM):
         self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
-            tokenizer=tokenizer, device=device, decode_buffer=1
+            tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
+            device=device,
+            decode_buffer=1,
         )

     @staticmethod
server/text_generation_server/models/causal_lm.py

@@ -400,7 +400,11 @@ class CausalLM(Model):
         )

         super(CausalLM, self).__init__(
-            tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
+            tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
+            device=device,
+            decode_buffer=decode_buffer,
         )

     @property
server/text_generation_server/models/flash_causal_lm.py

@@ -343,7 +343,11 @@ class FlashCausalLM(Model):
         )

         super(FlashCausalLM, self).__init__(
-            tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
+            tokenizer=tokenizer,
+            requires_padding=False,
+            dtype=dtype,
+            device=device,
+            decode_buffer=decode_buffer,
         )

     @property
server/text_generation_server/models/flash_llama.py

@@ -63,6 +63,8 @@ class FlashLlama(FlashCausalLM):
         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer,
+            requires_padding=False,
+            dtype=dtype,
             device=device,
         )
@@ -184,6 +186,8 @@ class FlashLlamaSharded(FlashLlama):
         torch.distributed.barrier(group=self.process_group)
         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer,
+            requires_padding=False,
+            dtype=dtype,
             device=device,
         )
server/text_generation_server/models/flash_neox.py

@@ -70,6 +70,8 @@ class FlashNeoXSharded(FlashNeoX):
         torch.distributed.barrier(group=self.process_group)
         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer,
+            requires_padding=False,
+            dtype=dtype,
             device=device,
         )
server/text_generation_server/models/flash_santacoder.py

@@ -65,7 +65,11 @@ class FlashSantacoder(FlashCausalLM):
         self.model = model.eval().to(device)

         super(FlashCausalLM, self).__init__(
-            tokenizer=tokenizer, device=device, decode_buffer=1
+            tokenizer=tokenizer,
+            requires_padding=False,
+            dtype=dtype,
+            device=device,
+            decode_buffer=1,
         )

     @staticmethod
@@ -206,6 +210,8 @@ class FlashSantacoderSharded(FlashSantacoder):
         torch.distributed.barrier(group=self.process_group)
         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer,
+            requires_padding=False,
+            dtype=dtype,
             device=device,
         )
server/text_generation_server/models/galactica.py

@@ -228,6 +228,8 @@ class GalacticaSharded(Galactica):
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
             device=device,
         )
server/text_generation_server/models/gpt_neox.py

@@ -72,6 +72,8 @@ class GPTNeoxSharded(CausalLM):
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
             device=device,
         )
server/text_generation_server/models/model.py

@@ -5,6 +5,7 @@ from typing import List, Tuple, Optional, TypeVar, Type
 from transformers import PreTrainedTokenizerBase

 from text_generation_server.models.types import Batch, GeneratedText
+from text_generation_server.pb.generate_pb2 import InfoResponse

 B = TypeVar("B", bound=Batch)
@@ -13,6 +14,8 @@ class Model(ABC):
     def __init__(
         self,
         tokenizer: PreTrainedTokenizerBase,
+        requires_padding: bool,
+        dtype: torch.dtype,
         device: torch.device,
         decode_buffer: int = 3,
     ):
@@ -21,9 +24,19 @@ class Model(ABC):
         self.tokenizer = tokenizer
         self.all_special_ids = set(tokenizer.all_special_ids)
+        self.requires_padding = requires_padding
+        self.dtype = dtype
         self.device = device
         self.decode_buffer = decode_buffer

+    @property
+    def info(self) -> InfoResponse:
+        return InfoResponse(
+            requires_padding=self.requires_padding,
+            dtype=str(self.dtype),
+            device_type=self.device.type,
+        )
+
     @property
     @abstractmethod
     def batch_type(self) -> Type[B]:
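For illustration, here is what the new info property evaluates to for a hypothetical padded model loaded in float16 on a CUDA device (values chosen for the example, not taken from the commit):

# Illustration only: the InfoResponse produced for a hypothetical CUDA / float16 model.
import torch

from text_generation_server.pb.generate_pb2 import InfoResponse

device = torch.device("cuda:0")
dtype = torch.float16

info = InfoResponse(
    requires_padding=True,      # CausalLM-style models pad their batches
    dtype=str(dtype),           # "torch.float16"
    device_type=device.type,    # "cuda"
)
print(info)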
server/text_generation_server/models/opt.py

@@ -88,6 +88,8 @@ class OPTSharded(OPT):
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
             device=device,
         )
server/text_generation_server/models/santacoder.py

@@ -54,7 +54,11 @@ class SantaCoder(CausalLM):
         )

         super(CausalLM, self).__init__(
-            tokenizer=tokenizer, device=device, decode_buffer=1
+            tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
+            device=device,
+            decode_buffer=1,
         )

     def decode(self, generated_ids: List[int]) -> str:
server/text_generation_server/models/seq2seq_lm.py

@@ -460,7 +460,11 @@ class Seq2SeqLM(Model):
         tokenizer.bos_token_id = self.model.config.decoder_start_token_id

         super(Seq2SeqLM, self).__init__(
-            tokenizer=tokenizer, device=device, decode_buffer=decode_buffer
+            tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
+            device=device,
+            decode_buffer=decode_buffer,
         )

     @property
server/text_generation_server/models/t5.py

@@ -72,6 +72,8 @@ class T5Sharded(Seq2SeqLM):
         torch.distributed.barrier(group=self.process_group)
         super(Seq2SeqLM, self).__init__(
             tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
             device=device,
         )