OpenDAS / text-generation-inference · Commit 09b7c26b (unverified)

feat(server): add frequency penalty (#1541)

Authored Feb 08, 2024 by OlivierDehaene; committed via GitHub on Feb 08, 2024.
Parent commit: 39af000c
22 changed files in total; this page shows 20 of them, with 296 additions and 78 deletions (+296 -78).
Cargo.lock (+36 -2)
benchmark/src/lib.rs (+3 -0)
benchmark/src/main.rs (+7 -0)
benchmark/src/table.rs (+2 -0)
integration-tests/models/test_mamba.py (+9 -2)
proto/generate.proto (+2 -0)
router/client/src/client.rs (+1 -0)
router/src/health.rs (+1 -0)
router/src/lib.rs (+77 -9)
router/src/queue.rs (+1 -0)
router/src/server.rs (+12 -6)
router/src/validation.rs (+9 -0)
server/tests/utils/test_tokens.py (+1 -1)
server/text_generation_server/models/__init__.py (+1 -0)
server/text_generation_server/models/causal_lm.py (+5 -2)
server/text_generation_server/models/custom_modeling/mamba_modeling.py (+75 -32)
server/text_generation_server/models/flash_causal_lm.py (+5 -3)
server/text_generation_server/models/mamba.py (+41 -18)
server/text_generation_server/models/seq2seq_lm.py (+5 -2)
server/text_generation_server/models/types.py (+3 -1)
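This page renders 20 of the 22 changed files; the server-side logits processor that actually applies the new parameter lives in one of the files not shown here. As a rough sketch only (not necessarily this commit's exact implementation), an OpenAI-style frequency penalty subtracts penalty times the number of prior occurrences from the logit of every token id that has already been generated:

import torch

def apply_frequency_penalty(
    logits: torch.Tensor, generated_ids: torch.Tensor, penalty: float
) -> torch.Tensor:
    # Sketch for a single sequence: logits is 1-D over the vocabulary,
    # generated_ids holds the token ids produced so far.
    counts = torch.zeros_like(logits)
    ids, freq = generated_ids.unique(return_counts=True)
    counts[ids] = freq.to(logits.dtype)
    # Positive penalties discourage tokens that already appeared often;
    # negative penalties encourage repetition.
    return logits - penalty * counts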
Cargo.lock

@@ -2787,7 +2787,7 @@ dependencies = [
  "tabled",
  "text-generation-client",
  "thiserror",
- "tokenizers",
+ "tokenizers 0.14.1",
  "tokio",
  "tracing",
  "tracing-subscriber",
@@ -2850,7 +2850,7 @@ dependencies = [
  "serde_json",
  "text-generation-client",
  "thiserror",
- "tokenizers",
+ "tokenizers 0.15.1",
  "tokio",
  "tokio-stream",
  "tower-http",
@@ -2972,6 +2972,40 @@ dependencies = [
  "unicode_categories",
 ]

+[[package]]
+name = "tokenizers"
+version = "0.15.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6db445cceba5dfeb0f9702be7d6bfd91801ddcbe8fe8722defe7f2e96da75812"
+dependencies = [
+ "aho-corasick",
+ "clap",
+ "derive_builder",
+ "esaxx-rs",
+ "getrandom",
+ "hf-hub",
+ "indicatif",
+ "itertools 0.11.0",
+ "lazy_static",
+ "log",
+ "macro_rules_attribute",
+ "monostate",
+ "onig",
+ "paste",
+ "rand",
+ "rayon",
+ "rayon-cond",
+ "regex",
+ "regex-syntax 0.7.5",
+ "serde",
+ "serde_json",
+ "spm_precompiled",
+ "thiserror",
+ "unicode-normalization-alignments",
+ "unicode-segmentation",
+ "unicode_categories",
+]
+
 [[package]]
 name = "tokio"
 version = "1.35.1"
benchmark/src/lib.rs

@@ -30,6 +30,7 @@ pub async fn run(
     top_p: Option<f32>,
     typical_p: Option<f32>,
     repetition_penalty: Option<f32>,
+    frequency_penalty: Option<f32>,
     watermark: bool,
     do_sample: bool,
     client: ShardedClient,
@@ -42,6 +43,7 @@ pub async fn run(
         do_sample,
         seed: 0,
         repetition_penalty: repetition_penalty.unwrap_or(1.0),
+        frequency_penalty: frequency_penalty.unwrap_or(0.0),
         watermark,
     };
@@ -140,6 +142,7 @@ pub async fn run(
         top_p,
         typical_p,
         repetition_penalty,
+        frequency_penalty,
         watermark,
         do_sample,
     );
benchmark/src/main.rs

@@ -84,6 +84,11 @@ struct Args {
     #[clap(long, env)]
     repetition_penalty: Option<f32>,

+    /// Generation parameter in case you want to specifically test/debug particular
+    /// decoding strategies, for full doc refer to the `text-generation-server`
+    #[clap(long, env)]
+    frequency_penalty: Option<f32>,
+
     /// Generation parameter in case you want to specifically test/debug particular
     /// decoding strategies, for full doc refer to the `text-generation-server`
     #[clap(long, env)]
@@ -119,6 +124,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         top_p,
         typical_p,
         repetition_penalty,
+        frequency_penalty,
         watermark,
         do_sample,
         master_shard_uds_path,
@@ -187,6 +193,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         top_p,
         typical_p,
         repetition_penalty,
+        frequency_penalty,
         watermark,
         do_sample,
         sharded_client,
benchmark/src/table.rs

@@ -15,6 +15,7 @@ pub(crate) fn parameters_table(
     top_p: Option<f32>,
     typical_p: Option<f32>,
     repetition_penalty: Option<f32>,
+    frequency_penalty: Option<f32>,
     watermark: bool,
     do_sample: bool,
 ) -> Table {
@@ -33,6 +34,7 @@ pub(crate) fn parameters_table(
     builder.push_record(["Top P", &format!("{top_p:?}")]);
     builder.push_record(["Typical P", &format!("{typical_p:?}")]);
     builder.push_record(["Repetition Penalty", &format!("{repetition_penalty:?}")]);
+    builder.push_record(["Frequency Penalty", &format!("{frequency_penalty:?}")]);
     builder.push_record(["Watermark", &watermark.to_string()]);
     builder.push_record(["Do Sample", &do_sample.to_string()]);
integration-tests/models/test_mamba.py

@@ -24,6 +24,7 @@ async def test_mamba(fused_kernel_mamba, response_snapshot):
     assert response.generated_text == "\n\nDeep learning is a new type of machine"
     assert response == response_snapshot

 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
@@ -44,13 +45,19 @@ async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
     )

     assert response.details.generated_tokens == 10
-    assert response.generated_text == "blue, red, yellow, \nand orange (in the order they appear in"
+    assert (
+        response.generated_text
+        == "blue, red, yellow, \nand orange (in the order they appear in"
+    )
     assert response == response_snapshot

 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_mamba_load(fused_kernel_mamba, generate_load, response_snapshot):
     responses = await generate_load(
         fused_kernel_mamba, "What is Deep Learning?", max_new_tokens=10, n=4
     )

     assert len(responses) == 4
     assert all([r.generated_text == responses[0].generated_text for r in responses])
proto/generate.proto

@@ -66,6 +66,8 @@ message NextTokenChooserParameters {
     uint64 seed = 6;
     /// repetition penalty
     float repetition_penalty = 7;
+    /// frequency penalty
+    float frequency_penalty = 9;
     /// token watermarking using "A Watermark for Large Language Models"
     bool watermark = 8;
 }
router/client/src/client.rs

@@ -125,6 +125,7 @@ impl Client {
                 do_sample: false,
                 seed: 0,
                 repetition_penalty: 1.2,
+                frequency_penalty: 0.1,
                 watermark: true,
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
router/src/health.rs

@@ -43,6 +43,7 @@ impl Health {
                 do_sample: false,
                 seed: 0,
                 repetition_penalty: 1.0,
+                frequency_penalty: 0.0,
                 watermark: false,
             }),
             stopping_parameters: Some(StoppingCriteriaParameters {
router/src/lib.rs

@@ -106,6 +106,14 @@ pub(crate) struct GenerateParameters {
     )]
     pub repetition_penalty: Option<f32>,
+    #[serde(default)]
+    #[schema(
+        exclusive_minimum = -2.0,
+        nullable = true,
+        default = "null",
+        example = 0.1
+    )]
+    pub frequency_penalty: Option<f32>,
     #[serde(default)]
     #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)]
     pub top_k: Option<i32>,
     #[serde(default)]
@@ -172,6 +180,7 @@ fn default_parameters() -> GenerateParameters {
         best_of: None,
         temperature: None,
         repetition_penalty: None,
+        frequency_penalty: None,
         top_k: None,
         top_p: None,
         typical_p: None,
@@ -205,10 +214,71 @@ pub(crate) struct ChatCompletion {
 pub(crate) struct ChatCompletionComplete {
     pub index: u32,
     pub message: Message,
-    pub logprobs: Option<Vec<f32>>,
+    pub logprobs: Option<ChatCompletionLogprobs>,
     pub finish_reason: String,
 }

+#[derive(Clone, Deserialize, Serialize, ToSchema)]
+pub(crate) struct ChatCompletionLogprobs {
+    content: Vec<ChatCompletionLogprob>,
+}
+
+impl From<(Token, Vec<Token>)> for ChatCompletionLogprobs {
+    fn from(value: (Token, Vec<Token>)) -> Self {
+        let (token, top_tokens) = value;
+        Self {
+            content: vec![ChatCompletionLogprob {
+                token: token.text,
+                logprob: token.logprob,
+                top_logprobs: top_tokens
+                    .into_iter()
+                    .map(|t| ChatCompletionTopLogprob {
+                        token: t.text,
+                        logprob: t.logprob,
+                    })
+                    .collect(),
+            }],
+        }
+    }
+}
+
+impl From<(Vec<Token>, Vec<Vec<Token>>)> for ChatCompletionLogprobs {
+    fn from(value: (Vec<Token>, Vec<Vec<Token>>)) -> Self {
+        let (tokens, top_tokens) = value;
+        Self {
+            content: tokens
+                .into_iter()
+                .zip(top_tokens)
+                .map(|(t, top_t)| ChatCompletionLogprob {
+                    token: t.text,
+                    logprob: t.logprob,
+                    top_logprobs: top_t
+                        .into_iter()
+                        .map(|t| ChatCompletionTopLogprob {
+                            token: t.text,
+                            logprob: t.logprob,
+                        })
+                        .collect(),
+                })
+                .collect(),
+        }
+    }
+}
+
+#[derive(Clone, Deserialize, Serialize, ToSchema)]
+pub(crate) struct ChatCompletionLogprob {
+    token: String,
+    logprob: f32,
+    top_logprobs: Vec<ChatCompletionTopLogprob>,
+}
+
+#[derive(Clone, Deserialize, Serialize, ToSchema)]
+pub(crate) struct ChatCompletionTopLogprob {
+    token: String,
+    logprob: f32,
+}
+
 #[derive(Clone, Deserialize, Serialize)]
 pub(crate) struct Usage {
     pub prompt_tokens: u32,
@@ -238,7 +308,7 @@ impl ChatCompletion {
                     content: output,
                 },
-                logprobs: return_logprobs
-                    .then(|| details.tokens.iter().map(|t| t.logprob).collect()),
+                logprobs: return_logprobs
+                    .then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))),
                 finish_reason: details.finish_reason.to_string(),
             }],
             usage: Usage {
@@ -266,7 +336,7 @@ pub(crate) struct ChatCompletionChunk {
 pub(crate) struct ChatCompletionChoice {
     pub index: u32,
     pub delta: ChatCompletionDelta,
-    pub logprobs: Option<f32>,
+    pub logprobs: Option<ChatCompletionLogprobs>,
     pub finish_reason: Option<String>,
 }
@@ -285,7 +355,7 @@ impl ChatCompletionChunk {
         delta: String,
         created: u64,
         index: u32,
-        logprobs: Option<f32>,
+        logprobs: Option<ChatCompletionLogprobs>,
         finish_reason: Option<String>,
     ) -> Self {
         Self {
@@ -319,8 +389,8 @@ pub(crate) struct ChatRequest {
     /// UNUSED
     #[schema(example = "mistralai/Mistral-7B-Instruct-v0.2")]
     /// ID of the model to use. See the model endpoint compatibility table for details on which models work with the Chat API.
     pub model: String, /* NOTE: UNUSED */

     /// A list of messages comprising the conversation so far.
     #[serde(default = "default_request_messages")]
     pub messages: Vec<Message>,
@@ -346,7 +416,6 @@ pub(crate) struct ChatRequest {
     #[schema(example = "false")]
     pub logprobs: Option<bool>,
-    /// UNUSED
     /// An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with
     /// an associated log probability. logprobs must be set to true if this parameter is used.
     #[serde(default)]
@@ -365,7 +434,6 @@ pub(crate) struct ChatRequest {
     #[schema(nullable = true, example = "2")]
     pub n: Option<u32>,
-    /// UNUSED
     /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
     /// increasing the model's likelihood to talk about new topics
     #[serde(default)]
@@ -447,7 +515,7 @@ pub struct PrefillToken {
     logprob: f32,
 }

-#[derive(Debug, Serialize, ToSchema)]
+#[derive(Debug, Serialize, ToSchema, Clone)]
 pub struct Token {
     #[schema(example = 0)]
     id: u32,
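The new ChatCompletionLogprobs / ChatCompletionLogprob / ChatCompletionTopLogprob structs replace the bare Option<Vec<f32>> and serialize to an OpenAI-style nested object. A sketch of the resulting JSON shape, with purely illustrative token strings and logprob values:

# Illustrative values only; the field names mirror the structs above.
example_logprobs = {
    "content": [
        {
            "token": " world",
            "logprob": -0.31,
            "top_logprobs": [
                {"token": " world", "logprob": -0.31},
                {"token": " there", "logprob": -1.87},
            ],
        }
    ]
}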
router/src/queue.rs

@@ -355,6 +355,7 @@ mod tests {
                 do_sample: false,
                 seed: 0,
                 repetition_penalty: 0.0,
+                frequency_penalty: 0.0,
                 watermark: false,
             },
             stopping_parameters: StoppingCriteriaParameters {
router/src/server.rs

@@ -4,9 +4,10 @@ use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
 use crate::{
     BestOfSequence, ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionDelta,
-    ChatRequest, CompatGenerateRequest, Details, ErrorResponse, FinishReason, GenerateParameters,
-    GenerateRequest, GenerateResponse, HubModelInfo, HubTokenizerConfig, Infer, Info, Message,
-    PrefillToken, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeResponse, Validation,
+    ChatCompletionLogprobs, ChatRequest, CompatGenerateRequest, Details, ErrorResponse,
+    FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo,
+    HubTokenizerConfig, Infer, Info, Message, PrefillToken, SimpleToken, StreamDetails,
+    StreamResponse, Token, TokenizeResponse, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -570,8 +571,8 @@ async fn chat_completions(
     let stream = req.stream;
     let max_new_tokens = req.max_tokens.or(Some(100));
     let repetition_penalty = req
-        .frequency_penalty
-        // rescale frequency_penalty from (-2.0, 2.0) to (0.0, 4.0)
+        .presence_penalty
+        // rescale repetition_penalty from (-2.0, 2.0) to (0.0, 4.0)
         .map(|x| x + 2.0);
     let logprobs = req.logprobs.unwrap_or(false);
     let seed = req.seed;
@@ -599,6 +600,7 @@ async fn chat_completions(
         best_of: None,
         temperature: req.temperature,
         repetition_penalty,
+        frequency_penalty: req.frequency_penalty,
         top_k: None,
         top_p: req.top_p,
         typical_p: None,
@@ -630,6 +632,10 @@ async fn chat_completions(
                     .unwrap_or_else(|_| std::time::Duration::from_secs(0))
                     .as_secs();

+                let logprobs = logprobs.then(|| {
+                    ChatCompletionLogprobs::from((stream_token.token.clone(), stream_token.top_tokens))
+                });
+
                 event
                     .json_data(ChatCompletionChunk::new(
                         model_id.clone(),
@@ -637,7 +643,7 @@ async fn chat_completions(
                         stream_token.token.text,
                         current_time,
                         stream_token.index,
-                        logprobs.then_some(stream_token.token.logprob),
+                        logprobs,
                         stream_token.details.map(|d| d.finish_reason.to_string()),
                     ))
                     .map_or_else(
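Before this change the chat endpoint reused frequency_penalty as a stand-in for repetition_penalty; it now maps presence_penalty into the strictly positive range that repetition_penalty expects and forwards frequency_penalty untouched. A minimal sketch of that shift, written in Python purely for illustration:

from typing import Optional

def presence_to_repetition_penalty(presence_penalty: Optional[float]) -> Optional[float]:
    # Shift the OpenAI-style presence_penalty range (-2.0, 2.0) into the
    # (0.0, 4.0) range used by repetition_penalty, as the router does above.
    return None if presence_penalty is None else presence_penalty + 2.0

assert presence_to_repetition_penalty(-1.5) == 0.5
assert presence_to_repetition_penalty(2.0) == 4.0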
router/src/validation.rs

@@ -170,6 +170,7 @@ impl Validation {
             best_of,
             temperature,
             repetition_penalty,
+            frequency_penalty,
             top_k,
             top_p,
             typical_p,
@@ -206,6 +207,11 @@ impl Validation {
             return Err(ValidationError::RepetitionPenalty);
         }

+        let frequency_penalty = frequency_penalty.unwrap_or(0.0);
+        if !(-2.0..=2.0).contains(&frequency_penalty) {
+            return Err(ValidationError::FrequencyPenalty);
+        }
+
         // Different because the proto default value is not a valid value
         // for the user
         let top_p = top_p
@@ -289,6 +295,7 @@ impl Validation {
         let parameters = NextTokenChooserParameters {
             temperature,
             repetition_penalty,
+            frequency_penalty,
             top_k,
             top_p,
             typical_p,
@@ -420,6 +427,8 @@ pub enum ValidationError {
     Temperature,
     #[error("`repetition_penalty` must be strictly positive")]
     RepetitionPenalty,
+    #[error("`frequency_penalty` must be >= -2.0 and <= 2.0")]
+    FrequencyPenalty,
     #[error("`top_p` must be > 0.0 and < 1.0")]
     TopP,
     #[error("`top_k` must be strictly positive")]
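With the new validation, a frequency_penalty outside [-2.0, 2.0] is rejected with the FrequencyPenalty error above. A hypothetical client call against a local deployment (URL, prompt, and parameter values are placeholders) could pass the parameter like this:

import requests

# Hypothetical local endpoint; adjust the URL for your deployment.
response = requests.post(
    "http://localhost:3000/generate",
    json={
        "inputs": "What is Deep Learning?",
        "parameters": {
            "max_new_tokens": 32,
            "frequency_penalty": 0.5,  # must stay within [-2.0, 2.0]
        },
    },
)
print(response.json())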
server/tests/utils/test_tokens.py

@@ -70,7 +70,7 @@ def test_batch_top_tokens():
     # Now let's make second member of the batch be speculated
     inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5 * 2)
     accepted_ids[1] = 2

     topn_tok_ids, topn_tok_logprobs = batch_top_tokens(
         top_n_tokens, top_n_tokens_tensor, inp_logprobs, accepted_ids
     )
...
server/text_generation_server/models/__init__.py
View file @
09b7c26b
...
@@ -86,6 +86,7 @@ except ImportError as e:
...
@@ -86,6 +86,7 @@ except ImportError as e:
if
MAMBA_AVAILABLE
:
if
MAMBA_AVAILABLE
:
__all__
.
append
(
Mamba
)
__all__
.
append
(
Mamba
)
def
get_model
(
def
get_model
(
model_id
:
str
,
model_id
:
str
,
revision
:
Optional
[
str
],
revision
:
Optional
[
str
],
...
...
server/text_generation_server/models/causal_lm.py

@@ -696,14 +696,17 @@ class CausalLM(Model):
             if top_n_tokens > 0:
                 all_top_tokens = []
                 for (top_token_ids, top_token_logprobs) in zip(
                     top_token_ids, top_token_logprobs
                 ):
                     toptoken_texts = self.tokenizer.batch_decode(
                         top_token_ids,
                         clean_up_tokenization_spaces=False,
                         skip_special_tokens=False,
                     )
                     special_toptokens = [
                         token_id in self.all_special_ids for token_id in top_token_ids
                     ]
                     top_tokens = Tokens(
                         top_token_ids,
server/text_generation_server/models/custom_modeling/mamba_modeling.py

@@ -19,6 +19,7 @@ from einops import rearrange
 from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
 import math

 class MambaConfig(PretrainedConfig):
     def __init__(
         self,
@@ -53,6 +54,7 @@ class MambaConfig(PretrainedConfig):
             **kwargs,
         )

 class MambaBlock(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
@@ -60,10 +62,14 @@ class MambaBlock(nn.Module):
         self.in_proj = FastLinear.load(config, f"{prefix}.in_proj", weights, bias=False)
         self.x_proj = FastLinear.load(config, f"{prefix}.x_proj", weights, bias=False)
         self.dt_proj = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=True)
-        self.dt_proj_no_bias = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=False)
-        self.out_proj = FastLinear.load(config, f"{prefix}.out_proj", weights, bias=False)
+        self.dt_proj_no_bias = FastLinear.load(
+            config, f"{prefix}.dt_proj", weights, bias=False
+        )
+        self.out_proj = FastLinear.load(
+            config, f"{prefix}.out_proj", weights, bias=False
+        )
         self.conv1d = FastLinear.load(config, f"{prefix}.conv1d", weights, bias=True)
         self.negA = -torch.exp(weights.get_tensor(f"{prefix}.A_log").float())
         self.D = weights.get_tensor(f"{prefix}.D")
         self.activation = "silu"
         self.dt_rank = config.dt_rank
@@ -80,12 +86,14 @@ class MambaBlock(nn.Module):
             out, conv_state, ssm_state = self.step(hidden_states, conv_state, ssm_state)
             return out, conv_state, ssm_state

         projected_states = self.in_proj(hidden_states).transpose(1, 2)
         x, z = projected_states.chunk(2, dim=1)
         conv_state = F.pad(x, (self.d_conv - seqlen, 0))
         x = causal_conv1d_fn(
             x=x,
-            weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)),
+            weight=self.conv1d.weight.view(
+                self.conv1d.weight.size(0), self.conv1d.weight.size(2)
+            ),
             bias=self.conv1d.bias,
             activation=self.activation,
         )
@@ -94,7 +102,9 @@ class MambaBlock(nn.Module):
         # We want dt to have d as the slowest moving dimension
         # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
         x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)
         dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
         dt = self.dt_proj.weight @ dt.t()
         dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
         B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
@@ -118,28 +128,39 @@ class MambaBlock(nn.Module):
     def step(self, hidden_states, conv_state, ssm_state):
         _xz = self.in_proj(hidden_states)
         _x, _z = _xz.chunk(2, dim=-1)  # (B D)
         conv_state_new = torch.cat([conv_state, _x.transpose(1, 2)], dim=-1)
         conv_out = causal_conv1d_fn(
             x=conv_state_new,
-            weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)),
-            bias=self.conv1d.bias,
-            activation=self.activation
+            weight=self.conv1d.weight.view(
+                self.conv1d.weight.size(0), self.conv1d.weight.size(2)
+            ),
+            bias=self.conv1d.bias,
+            activation=self.activation,
         )
         conv_state = conv_state_new[:, :, 1:]
         bsz, seqlen, dim = hidden_states.shape
         output_tensor = torch.zeros(
             (bsz, seqlen, dim),
             device=hidden_states.device,
             dtype=hidden_states.dtype
         )
         for i in range(0, bsz):
             x = conv_out[i : i + 1, :, -1]
             z = _z[i : i + 1, -1, :]
             x_db = self.x_proj(x)
             dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
             dt = F.linear(dt, self.dt_proj.weight)
             y = selective_state_update(
-                ssm_state[i:i+1,:,:], x, dt, self.negA, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
+                ssm_state[i : i + 1, :, :],
+                x,
+                dt,
+                self.negA,
+                B,
+                C,
+                self.D,
+                z=z,
+                dt_bias=self.dt_proj.bias,
+                dt_softplus=True,
             )
             out = self.out_proj(y)
             output_tensor[i] = out
@@ -147,48 +168,70 @@ class MambaBlock(nn.Module):
         return output_tensor, conv_state, ssm_state

 class ResidualBlock(nn.Module):
     def __init__(self, layer_id, config, weights):
         super().__init__()
-        self.mamba_block = MambaBlock(prefix=f"{layer_id}.mixer", config=config, weights=weights)
-        self.layer_norm = FastRMSNorm.load(prefix=f"{layer_id}.norm", weights=weights, eps=config.layer_norm_epsilon)
+        self.mamba_block = MambaBlock(
+            prefix=f"{layer_id}.mixer", config=config, weights=weights
+        )
+        self.layer_norm = FastRMSNorm.load(
+            prefix=f"{layer_id}.norm", weights=weights, eps=config.layer_norm_epsilon
+        )

     def forward(
         self,
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
         inference_params: Optional[Any] = None,
     ):
         residual = (hidden_states + residual) if residual is not None else hidden_states
         shape = residual.shape
         hidden_states, _ = self.layer_norm(residual.view(-1, shape[-1]))
-        hidden_states, conv_state, last_ssm_state = self.mamba_block(hidden_states.view(*shape), inference_params)
+        hidden_states, conv_state, last_ssm_state = self.mamba_block(
+            hidden_states.view(*shape), inference_params
+        )
         return hidden_states, residual, conv_state, last_ssm_state

 class MambaModel(nn.Module):
     def __init__(self, config, weights):
         super().__init__()
         prefix = "backbone"
         self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights)
         self.blocks = nn.ModuleList(
-            [ResidualBlock(f"{prefix}.layers.{i}", config, weights) for i in range(config.n_layer)]
+            [
+                ResidualBlock(f"{prefix}.layers.{i}", config, weights)
+                for i in range(config.n_layer)
+            ]
         )
-        self.norm_f = FastRMSNorm.load(f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon)
-        self.lm_head = FastLinear.load(config, f"{prefix}.embedding", weights, bias=False)
+        self.norm_f = FastRMSNorm.load(
+            f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon
+        )
+        self.lm_head = FastLinear.load(
+            config, f"{prefix}.embedding", weights, bias=False
+        )
         self.config = config

     def forward(
         self, input_ids: torch.Tensor, inference_params=None, residual=None
     ) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]:
         hidden_states = self.embed_tokens(input_ids)
         for block in self.blocks:
-            hidden_states, residual, conv_state, ssm_state = block(hidden_states, residual, inference_params)
-            inference_params.key_value_memory_dict[block.mamba_block.layer_idx] = (conv_state, ssm_state)
+            hidden_states, residual, conv_state, ssm_state = block(
+                hidden_states, residual, inference_params
+            )
+            inference_params.key_value_memory_dict[block.mamba_block.layer_idx] = (
+                conv_state,
+                ssm_state,
+            )

-        hidden_states = hidden_states + residual if residual is not None else hidden_states
+        hidden_states = (
+            hidden_states + residual if residual is not None else hidden_states
+        )
         hidden_states, _ = self.norm_f(hidden_states.view(-1, hidden_states.size(-1)))
         hidden_states = hidden_states.view(residual.shape)
         logits = self.lm_head(hidden_states)

         # update the offset for the next inference using these params
         inference_params.seqlen_offset += input_ids.size(1)
         return logits, input_ids, inference_params
\ No newline at end of file
server/text_generation_server/models/flash_causal_lm.py

@@ -842,7 +842,6 @@ class FlashCausalLM(Model):
         else:
             next_token_logits = out

         speculate = get_speculate()
         (
             next_input_ids,
@@ -1064,14 +1063,17 @@ class FlashCausalLM(Model):
             if top_n_tokens > 0:
                 all_top_tokens = []
                 for (top_token_ids, top_token_logprobs) in zip(
                     top_token_ids, top_token_logprobs
                 ):
                     toptoken_texts = self.tokenizer.batch_decode(
                         top_token_ids,
                         clean_up_tokenization_spaces=False,
                         skip_special_tokens=False,
                     )
                     special_toptokens = [
                         token_id in self.all_special_ids for token_id in top_token_ids
                     ]
                     top_tokens = Tokens(
                         top_token_ids,
server/text_generation_server/models/mamba.py

@@ -26,6 +26,7 @@ from dataclasses import dataclass
 from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
 from mamba_ssm.utils.generation import InferenceParams

 @dataclass
 class MambaBatch(Batch):
     batch_id: int
@@ -69,7 +70,7 @@ class MambaBatch(Batch):
             size=len(self),
             max_tokens=self.max_tokens,
         )

     @classmethod
     def from_pb(
         cls,
@@ -196,7 +197,7 @@ class MambaBatch(Batch):
             new_padding_right_offset = max(
                 new_padding_right_offset, remaining_decode_tokens
             )

         # Apply indices to input_ids, attention mask, past key values and other items that need to be cached
         input_ids = self.input_ids[keep_indices]
@@ -218,10 +219,13 @@ class MambaBatch(Batch):
         self.padding_right_offset = new_padding_right_offset
         self.max_tokens = max_tokens

         # TODO
         # Kept it simple by just updating the state, maybe updating the other CPU values is necessary.
         key_value_memory_dict = {}
-        for i, (conv_state, ssm_state) in self.inference_params.key_value_memory_dict.items():
+        for i, (
+            conv_state,
+            ssm_state,
+        ) in self.inference_params.key_value_memory_dict.items():
             key_value_memory_dict[i] = (conv_state[indices], ssm_state[indices])
         self.inference_params.key_value_memory_dict = key_value_memory_dict
@@ -305,8 +309,9 @@ class MambaBatch(Batch):
             start_index = end_index

-        (_, d_model, d_conv) = batches[0].inference_params.key_value_memory_dict[0][0].shape
+        (_, d_model, d_conv) = (
+            batches[0].inference_params.key_value_memory_dict[0][0].shape
+        )
         (_, _, d_state) = batches[0].inference_params.key_value_memory_dict[0][1].shape
         n_blocks = len(batches[0].inference_params.key_value_memory_dict)
         dtype = batches[0].inference_params.key_value_memory_dict[0][0].dtype
@@ -344,9 +349,15 @@ class MambaBatch(Batch):
             for i in range(n_blocks):
                 conv_state, ssm_state = batch.inference_params.key_value_memory_dict[i]
                 batch_size = batch.inference_params.max_batch_size
-                inference_params.key_value_memory_dict[i][0][current_batch : current_batch + batch_size] = conv_state
-                inference_params.key_value_memory_dict[i][1][current_batch : current_batch + batch_size] = ssm_state
-                inference_params.lengths_per_sample[current_batch : current_batch + batch_size] = batch.inference_params.lengths_per_sample
+                inference_params.key_value_memory_dict[i][0][
+                    current_batch : current_batch + batch_size
+                ] = conv_state
+                inference_params.key_value_memory_dict[i][1][
+                    current_batch : current_batch + batch_size
+                ] = ssm_state
+                inference_params.lengths_per_sample[
+                    current_batch : current_batch + batch_size
+                ] = batch.inference_params.lengths_per_sample
             current_batch += batch_size

         return cls(
@@ -366,12 +377,13 @@ class MambaBatch(Batch):
             padding_right_offset=padding_right_offset,
             keys_head_dim_last=batches[0].keys_head_dim_last,
             max_tokens=max_tokens,
-            inference_params=inference_params
+            inference_params=inference_params,
         )

     def __len__(self):
         return len(self.requests)

 class Mamba(Model):
     def __init__(
         self,
@@ -428,7 +440,7 @@ class Mamba(Model):
     def warmup(self, batch) -> Optional[int]:
         # TODO: implement warmup for Mamba if needed
         return None

     def forward(
         self,
         input_ids: torch.Tensor,
@@ -441,7 +453,9 @@ class Mamba(Model):
     def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
         start = time.time_ns()
-        input_ids = batch.input_ids  # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids
+        input_ids = (
+            batch.input_ids
+        )  # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids

         batch_size = input_ids.shape[0]
         max_seqlen = input_ids.shape[1]
@@ -450,8 +464,11 @@ class Mamba(Model):
         # Inference params
         seqlen_og = 0
         inf_cache = {}
-        lengths_per_sample = torch.ones(batch_size, dtype=torch.int32, device=input_ids.device) * max_seqlen
+        lengths_per_sample = (
+            torch.ones(batch_size, dtype=torch.int32, device=input_ids.device)
+            * max_seqlen
+        )

         if batch.inference_params is None:
             inference_params = InferenceParams(
                 max_seqlen=max_seqlen,
@@ -478,11 +495,16 @@ class Mamba(Model):
                     device=block.dt_proj.weight.device,
                     dtype=block.dt_proj.weight.dtype,
                 )
-                inference_params.key_value_memory_dict[block.layer_idx] = (conv_state, ssm_state)
+                inference_params.key_value_memory_dict[block.layer_idx] = (
+                    conv_state,
+                    ssm_state,
+                )
             batch.inference_params = inference_params

         # Forward pass
-        logits, past_input_ids, new_inference_params = self.model(input_ids, batch.inference_params)
+        logits, past_input_ids, new_inference_params = self.model(
+            input_ids, batch.inference_params
+        )
         batch.inference_params = new_inference_params

         # Results
@@ -564,7 +586,8 @@ class Mamba(Model):
                 prefix_offset=len(all_input_ids)
                 - stopping_criteria.current_tokens
                 - 1,
                 read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
                 skip_special_tokens=True,
             )

             # Get seed
server/text_generation_server/models/seq2seq_lm.py

@@ -750,14 +750,17 @@ class Seq2SeqLM(Model):
             if top_n_tokens > 0:
                 all_top_tokens = []
                 for (top_token_ids, top_token_logprobs) in zip(
                     top_token_ids, top_token_logprobs
                 ):
                     toptoken_texts = self.tokenizer.batch_decode(
                         top_token_ids,
                         clean_up_tokenization_spaces=False,
                         skip_special_tokens=False,
                     )
                     special_toptokens = [
                         token_id in self.all_special_ids for token_id in top_token_ids
                     ]
                     top_tokens = Tokens(
                         top_token_ids,
server/text_generation_server/models/types.py

@@ -95,5 +95,7 @@ class Generation:
             generated_text=self.generated_text.to_pb()
             if self.generated_text is not None
             else None,
             top_tokens=[top_tokens.to_pb() for top_tokens in self.top_tokens]
             if self.top_tokens is not None
             else None,
         )