Unverified Commit f5d43414 authored by Nicolas Patry, committed by GitHub
Browse files

Fixing types. (#1906)

# What does this PR do?

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one has reviewed your
PR after a week, don't hesitate to post a new comment @-mentioning the
same people; sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->
parent d8402eaf
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
use crate::validation::{Validation, ValidationError}; use crate::validation::{Validation, ValidationError};
use crate::{ use crate::{
ChatTemplateInputs, ChatTemplateVersions, Entry, GenerateRequest, GenerateStreamResponse, ChatTemplateInputs, ChatTemplateVersions, Entry, GenerateRequest, GenerateStreamResponse,
HubTokenizerConfig, Message, PrefillToken, Queue, Token, HubTokenizerConfig, Message, MessageChunk, PrefillToken, Queue, Text, TextMessage, Token,
}; };
use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools}; use crate::{FunctionRef, FunctionsMap, GrammarType, Properties, Tool, ToolType, Tools};
use futures::future::try_join_all; use futures::future::try_join_all;
...@@ -362,16 +362,15 @@ impl ChatTemplate { ...@@ -362,16 +362,15 @@ impl ChatTemplate {
if self.use_default_tool_template { if self.use_default_tool_template {
if let Some(last_message) = messages.last_mut() { if let Some(last_message) = messages.last_mut() {
if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt { if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
last_message.content = Some(format!( last_message.content.push(MessageChunk::Text(Text {
"{}\n---\n{}\n{}", text: format!("\n---\n{}\n{}", tool_prompt, tools),
last_message.content.as_deref().unwrap_or_default(), }));
tool_prompt,
tools
));
} }
} }
} }
let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
self.template self.template
.render(ChatTemplateInputs { .render(ChatTemplateInputs {
messages, messages,
...@@ -939,8 +938,7 @@ impl InferError { ...@@ -939,8 +938,7 @@ impl InferError {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::infer::raise_exception; use crate::infer::raise_exception;
use crate::ChatTemplateInputs; use crate::{ChatTemplateInputs, TextMessage};
use crate::Message;
use minijinja::Environment; use minijinja::Environment;
#[test] #[test]
...@@ -974,33 +972,21 @@ mod tests { ...@@ -974,33 +972,21 @@ mod tests {
let chat_template_inputs = ChatTemplateInputs { let chat_template_inputs = ChatTemplateInputs {
messages: vec![ messages: vec![
Message { TextMessage {
role: "user".to_string(), role: "user".to_string(),
content: Some("Hi!".to_string()), content: "Hi!".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
Message { TextMessage {
role: "assistant".to_string(), role: "assistant".to_string(),
content: Some("Hello how can I help?".to_string()), content: "Hello how can I help?".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
Message { TextMessage {
role: "user".to_string(), role: "user".to_string(),
content: Some("What is Deep Learning?".to_string()), content: "What is Deep Learning?".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
Message { TextMessage {
role: "assistant".to_string(), role: "assistant".to_string(),
content: Some("magic!".to_string()), content: "magic!".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
], ],
bos_token: Some("[BOS]"), bos_token: Some("[BOS]"),
...@@ -1048,40 +1034,25 @@ mod tests { ...@@ -1048,40 +1034,25 @@ mod tests {
let chat_template_inputs = ChatTemplateInputs { let chat_template_inputs = ChatTemplateInputs {
messages: vec![ messages: vec![
Message { TextMessage {
role: "user".to_string(), role: "user".to_string(),
content: Some("Hi!".to_string()), content: "Hi!".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
Message { TextMessage {
role: "user".to_string(), role: "user".to_string(),
content: Some("Hi again!".to_string()), content: "Hi again!".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
Message { TextMessage {
role: "assistant".to_string(), role: "assistant".to_string(),
content: Some("Hello how can I help?".to_string()), content: "Hello how can I help?".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
Message { TextMessage {
role: "user".to_string(), role: "user".to_string(),
content: Some("What is Deep Learning?".to_string()), content: "What is Deep Learning?".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
Message { TextMessage {
role: "assistant".to_string(), role: "assistant".to_string(),
content: Some("magic!".to_string()), content: "magic!".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
], ],
bos_token: Some("[BOS]"), bos_token: Some("[BOS]"),
...@@ -1134,33 +1105,21 @@ mod tests { ...@@ -1134,33 +1105,21 @@ mod tests {
let chat_template_inputs = ChatTemplateInputs { let chat_template_inputs = ChatTemplateInputs {
messages: vec![ messages: vec![
Message { TextMessage {
role: "user".to_string(), role: "user".to_string(),
content: Some("Hi!".to_string()), content: "Hi!".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
Message { TextMessage {
role: "assistant".to_string(), role: "assistant".to_string(),
content: Some("Hello how can I help?".to_string()), content: "Hello how can I help?".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
Message { TextMessage {
role: "user".to_string(), role: "user".to_string(),
content: Some("What is Deep Learning?".to_string()), content: "What is Deep Learning?".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
Message { TextMessage {
role: "assistant".to_string(), role: "assistant".to_string(),
content: Some("magic!".to_string()), content: "magic!".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
], ],
bos_token: Some("[BOS]"), bos_token: Some("[BOS]"),
...@@ -1197,33 +1156,21 @@ mod tests { ...@@ -1197,33 +1156,21 @@ mod tests {
let chat_template_inputs = ChatTemplateInputs { let chat_template_inputs = ChatTemplateInputs {
messages: vec![ messages: vec![
Message { TextMessage {
role: "user".to_string(), role: "user".to_string(),
content: Some("Hi!".to_string()), content: "Hi!".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
Message { TextMessage {
role: "assistant".to_string(), role: "assistant".to_string(),
content: Some("Hello how can I help?".to_string()), content: "Hello how can I help?".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
Message { TextMessage {
role: "user".to_string(), role: "user".to_string(),
content: Some("What is Deep Learning?".to_string()), content: "What is Deep Learning?".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
Message { TextMessage {
role: "assistant".to_string(), role: "assistant".to_string(),
content: Some("magic!".to_string()), content: "magic!".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
], ],
bos_token: Some("[BOS]"), bos_token: Some("[BOS]"),
...@@ -1246,38 +1193,24 @@ mod tests { ...@@ -1246,38 +1193,24 @@ mod tests {
#[test] #[test]
fn test_many_chat_templates() { fn test_many_chat_templates() {
let example_chat = vec![ let example_chat = vec![
Message { TextMessage {
role: "user".to_string(), role: "user".to_string(),
content: Some("Hello, how are you?".to_string()), content: "Hello, how are you?".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
Message { TextMessage {
role: "assistant".to_string(), role: "assistant".to_string(),
content: Some("I'm doing great. How can I help you today?".to_string()), content: "I'm doing great. How can I help you today?".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
Message { TextMessage {
role: "user".to_string(), role: "user".to_string(),
content: Some("I'd like to show off how chat templating works!".to_string()), content: "I'd like to show off how chat templating works!".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
]; ];
let example_chat_with_system = [Message { let example_chat_with_system = [TextMessage {
role: "system".to_string(), role: "system".to_string(),
content: Some( content: "You are a friendly chatbot who always responds in the style of a pirate"
"You are a friendly chatbot who always responds in the style of a pirate"
.to_string(), .to_string(),
),
name: None,
tool_calls: None,
tool_call_id: None,
}] }]
.iter() .iter()
.chain(&example_chat) .chain(&example_chat)
...@@ -1417,19 +1350,13 @@ mod tests { ...@@ -1417,19 +1350,13 @@ mod tests {
chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}", chat_template: "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}",
input: ChatTemplateInputs { input: ChatTemplateInputs {
messages: vec![ messages: vec![
Message { TextMessage{
role: "system".to_string(), role: "system".to_string(),
content: Some("You are a friendly chatbot who always responds in the style of a pirate".to_string()), content: "You are a friendly chatbot who always responds in the style of a pirate".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
Message { TextMessage{
role: "user".to_string(), role: "user".to_string(),
content: Some("How many helicopters can a human eat in one sitting?".to_string()), content: "How many helicopters can a human eat in one sitting?".to_string(),
name: None,
tool_calls: None,
tool_call_id: None,
}, },
], ],
add_generation_prompt: true, add_generation_prompt: true,
......
...@@ -11,6 +11,7 @@ use queue::{Entry, Queue}; ...@@ -11,6 +11,7 @@ use queue::{Entry, Queue};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::sync::OwnedSemaphorePermit; use tokio::sync::OwnedSemaphorePermit;
use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::warn;
use utoipa::ToSchema; use utoipa::ToSchema;
use validation::Validation; use validation::Validation;
...@@ -440,7 +441,7 @@ pub(crate) struct ChatCompletion { ...@@ -440,7 +441,7 @@ pub(crate) struct ChatCompletion {
#[derive(Clone, Deserialize, Serialize, ToSchema)] #[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionComplete { pub(crate) struct ChatCompletionComplete {
pub index: u32, pub index: u32,
pub message: Message, pub message: OutputMessage,
pub logprobs: Option<ChatCompletionLogprobs>, pub logprobs: Option<ChatCompletionLogprobs>,
pub finish_reason: String, pub finish_reason: String,
} }
...@@ -533,6 +534,30 @@ impl ChatCompletion { ...@@ -533,6 +534,30 @@ impl ChatCompletion {
return_logprobs: bool, return_logprobs: bool,
tool_calls: Option<Vec<ToolCall>>, tool_calls: Option<Vec<ToolCall>>,
) -> Self { ) -> Self {
let message = match (output, tool_calls) {
(Some(content), None) => OutputMessage::ChatMessage(TextMessage {
role: "assistant".into(),
content,
}),
(None, Some(tool_calls)) => OutputMessage::ToolCall(ToolCallMessage {
role: "assistant".to_string(),
tool_calls,
}),
(Some(output), Some(_)) => {
warn!("Received both chat and tool call");
OutputMessage::ChatMessage(TextMessage {
role: "assistant".into(),
content: output,
})
}
(None, None) => {
warn!("Didn't receive an answer");
OutputMessage::ChatMessage(TextMessage {
role: "assistant".into(),
content: "".to_string(),
})
}
};
Self { Self {
id: String::new(), id: String::new(),
object: "text_completion".into(), object: "text_completion".into(),
...@@ -541,13 +566,7 @@ impl ChatCompletion { ...@@ -541,13 +566,7 @@ impl ChatCompletion {
system_fingerprint, system_fingerprint,
choices: vec![ChatCompletionComplete { choices: vec![ChatCompletionComplete {
index: 0, index: 0,
message: Message { message,
role: "assistant".into(),
content: output,
name: None,
tool_calls,
tool_call_id: None,
},
logprobs: return_logprobs logprobs: return_logprobs
.then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))), .then(|| ChatCompletionLogprobs::from((details.tokens, details.top_tokens))),
finish_reason: details.finish_reason.to_string(), finish_reason: details.finish_reason.to_string(),
...@@ -569,6 +588,7 @@ pub(crate) struct CompletionCompleteChunk { ...@@ -569,6 +588,7 @@ pub(crate) struct CompletionCompleteChunk {
pub model: String, pub model: String,
pub system_fingerprint: String, pub system_fingerprint: String,
} }
#[derive(Clone, Deserialize, Serialize, ToSchema)] #[derive(Clone, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionChunk { pub(crate) struct ChatCompletionChunk {
pub id: String, pub id: String,
...@@ -589,21 +609,20 @@ pub(crate) struct ChatCompletionChoice { ...@@ -589,21 +609,20 @@ pub(crate) struct ChatCompletionChoice {
pub finish_reason: Option<String>, pub finish_reason: Option<String>,
} }
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
pub struct ToolCallDelta {
#[schema(example = "assistant")]
role: String,
tool_calls: DeltaToolCall,
}
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema)] #[derive(Clone, Debug, Deserialize, Serialize, ToSchema)]
pub(crate) struct ChatCompletionDelta { enum ChatCompletionDelta {
#[schema(example = "user")] Chat(TextMessage),
// TODO Modify this to a true enum. Tool(ToolCallDelta),
#[serde(default, skip_serializing_if = "Option::is_none")]
pub role: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(example = "What is Deep Learning?")]
pub content: Option<String>,
// default to None
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<DeltaToolCall>,
} }
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)] #[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
pub(crate) struct DeltaToolCall { pub(crate) struct DeltaToolCall {
pub index: u32, pub index: u32,
pub id: String, pub id: String,
...@@ -611,7 +630,7 @@ pub(crate) struct DeltaToolCall { ...@@ -611,7 +630,7 @@ pub(crate) struct DeltaToolCall {
pub function: Function, pub function: Function,
} }
#[derive(Clone, Deserialize, Serialize, ToSchema, Debug)] #[derive(Clone, Deserialize, Serialize, ToSchema, Debug, PartialEq)]
pub(crate) struct Function { pub(crate) struct Function {
pub name: Option<String>, pub name: Option<String>,
pub arguments: String, pub arguments: String,
...@@ -629,15 +648,13 @@ impl ChatCompletionChunk { ...@@ -629,15 +648,13 @@ impl ChatCompletionChunk {
finish_reason: Option<String>, finish_reason: Option<String>,
) -> Self { ) -> Self {
let delta = match (delta, tool_calls) { let delta = match (delta, tool_calls) {
(Some(delta), _) => ChatCompletionDelta { (Some(delta), _) => ChatCompletionDelta::Chat(TextMessage {
role: Some("assistant".to_string()), role: "assistant".to_string(),
content: Some(delta), content: delta,
tool_calls: None, }),
}, (None, Some(tool_calls)) => ChatCompletionDelta::Tool(ToolCallDelta {
(None, Some(tool_calls)) => ChatCompletionDelta { role: "assistant".to_string(),
role: Some("assistant".to_string()), tool_calls: DeltaToolCall {
content: None,
tool_calls: Some(DeltaToolCall {
index: 0, index: 0,
id: String::new(), id: String::new(),
r#type: "function".to_string(), r#type: "function".to_string(),
...@@ -645,13 +662,12 @@ impl ChatCompletionChunk { ...@@ -645,13 +662,12 @@ impl ChatCompletionChunk {
name: None, name: None,
arguments: tool_calls[0].to_string(), arguments: tool_calls[0].to_string(),
}, },
}),
},
(None, None) => ChatCompletionDelta {
role: None,
content: None,
tool_calls: None,
}, },
}),
(None, None) => ChatCompletionDelta::Chat(TextMessage {
role: "assistant".to_string(),
content: "".to_string(),
}),
}; };
Self { Self {
id: String::new(), id: String::new(),
...@@ -852,7 +868,7 @@ where ...@@ -852,7 +868,7 @@ where
state.end() state.end()
} }
#[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default)] #[derive(Clone, Debug, Deserialize, Serialize, ToSchema, Default, PartialEq)]
pub(crate) struct FunctionDefinition { pub(crate) struct FunctionDefinition {
#[serde(default)] #[serde(default)]
pub description: Option<String>, pub description: Option<String>,
...@@ -872,7 +888,7 @@ pub(crate) struct Tool { ...@@ -872,7 +888,7 @@ pub(crate) struct Tool {
#[derive(Clone, Serialize, Deserialize, Default)] #[derive(Clone, Serialize, Deserialize, Default)]
pub(crate) struct ChatTemplateInputs<'a> { pub(crate) struct ChatTemplateInputs<'a> {
messages: Vec<Message>, messages: Vec<TextMessage>,
bos_token: Option<&'a str>, bos_token: Option<&'a str>,
eos_token: Option<&'a str>, eos_token: Option<&'a str>,
add_generation_prompt: bool, add_generation_prompt: bool,
...@@ -880,91 +896,113 @@ pub(crate) struct ChatTemplateInputs<'a> { ...@@ -880,91 +896,113 @@ pub(crate) struct ChatTemplateInputs<'a> {
tools_prompt: Option<&'a str>, tools_prompt: Option<&'a str>,
} }
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)] #[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug, PartialEq)]
pub(crate) struct ToolCall { pub(crate) struct ToolCall {
pub id: String, pub id: String,
pub r#type: String, pub r#type: String,
pub function: FunctionDefinition, pub function: FunctionDefinition,
} }
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)] #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
pub(crate) struct Text { struct Url {
#[serde(default)] url: String,
pub text: String,
} }
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)] #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
pub(crate) struct ImageUrl { struct ImageUrl {
#[serde(default)] image_url: Url,
pub url: String,
} }
#[derive(Clone, Deserialize, Serialize, ToSchema, Default, Debug)] #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
pub(crate) struct Content { struct Text {
pub r#type: String, text: String,
#[serde(default, skip_serializing_if = "Option::is_none")] }
pub text: Option<String>,
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
enum MessageChunk {
Text(Text),
ImageUrl(ImageUrl),
}
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
pub struct Message {
#[schema(example = "user")]
role: String,
#[schema(example = "My name is David and I")]
#[serde(deserialize_with = "message_content_serde::deserialize")]
content: Vec<MessageChunk>,
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub image_url: Option<ImageUrl>, #[schema(example = "\"David\"")]
name: Option<String>,
} }
mod message_content_serde { mod message_content_serde {
use super::*; use super::*;
use serde::de; use serde::{Deserialize, Deserializer};
use serde::Deserializer;
use serde_json::Value;
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<String>, D::Error> pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<MessageChunk>, D::Error>
where where
D: Deserializer<'de>, D: Deserializer<'de>,
{ {
let value = Value::deserialize(deserializer)?; #[derive(Deserialize)]
match value { #[serde(untagged)]
Value::String(s) => Ok(Some(s)), enum Message {
Value::Array(arr) => { Text(String),
let results: Result<Vec<String>, _> = arr Chunks(Vec<MessageChunk>),
.into_iter()
.map(|v| {
let content: Content =
serde_json::from_value(v).map_err(de::Error::custom)?;
match content.r#type.as_str() {
"text" => Ok(content.text.unwrap_or_default()),
"image_url" => {
if let Some(url) = content.image_url {
Ok(format!("![]({})", url.url))
} else {
Ok(String::new())
} }
let message: Message = Deserialize::deserialize(deserializer)?;
let chunks = match message {
Message::Text(text) => {
vec![MessageChunk::Text(Text { text })]
} }
_ => Err(de::Error::custom("invalid content type")), Message::Chunks(s) => s,
};
Ok(chunks)
} }
}) }
.collect();
results.map(|strings| Some(strings.join(""))) #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
pub struct TextMessage {
#[schema(example = "user")]
pub role: String,
#[schema(example = "My name is David and I")]
pub content: String,
}
impl From<Message> for TextMessage {
fn from(value: Message) -> Self {
TextMessage {
role: value.role,
content: value
.content
.into_iter()
.map(|c| match c {
MessageChunk::Text(Text { text }) => text,
MessageChunk::ImageUrl(image) => {
let url = image.image_url.url;
format!("![]({url})")
} }
Value::Null => Ok(None), })
_ => Err(de::Error::custom("invalid token format")), .collect::<Vec<_>>()
.join(""),
} }
} }
} }
#[derive(Clone, Deserialize, ToSchema, Serialize, Debug)] #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
pub(crate) struct Message { pub struct ToolCallMessage {
#[schema(example = "user")] #[schema(example = "assistant")]
pub role: String, role: String,
#[serde(skip_serializing_if = "Option::is_none")] tool_calls: Vec<ToolCall>,
#[schema(example = "My name is David and I")] }
#[serde(default, deserialize_with = "message_content_serde::deserialize")]
pub content: Option<String>, #[derive(Clone, Deserialize, ToSchema, Serialize, Debug, PartialEq)]
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(untagged)]
#[schema(example = "\"David\"")] pub(crate) enum OutputMessage {
pub name: Option<String>, ChatMessage(TextMessage),
#[serde(default, skip_serializing_if = "Option::is_none")] ToolCall(ToolCallMessage),
pub tool_calls: Option<Vec<ToolCall>>,
#[serde(default, skip_serializing_if = "Option::is_none")]
#[schema(example = "\"get_weather\"")]
pub tool_call_id: Option<String>,
} }
#[derive(Clone, Debug, Deserialize, ToSchema)] #[derive(Clone, Debug, Deserialize, ToSchema)]
...@@ -1127,7 +1165,7 @@ pub(crate) struct ErrorResponse { ...@@ -1127,7 +1165,7 @@ pub(crate) struct ErrorResponse {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use serde_json::json;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
pub(crate) async fn get_tokenizer() -> Tokenizer { pub(crate) async fn get_tokenizer() -> Tokenizer {
...@@ -1195,4 +1233,66 @@ mod tests { ...@@ -1195,4 +1233,66 @@ mod tests {
); );
assert_eq!(config.eos_token, Some("<|end▁of▁sentence|>".to_string())); assert_eq!(config.eos_token, Some("<|end▁of▁sentence|>".to_string()));
} }
#[test]
fn test_chat_simple_string() {
let json = json!(
{
"model": "",
"messages": [
{"role": "user",
"content": "What is Deep Learning?"
}
]
});
let request: ChatRequest = serde_json::from_str(json.to_string().as_str()).unwrap();
assert_eq!(
request.messages[0],
Message {
role: "user".to_string(),
content: vec![MessageChunk::Text(Text {
text: "What is Deep Learning?".to_string()
}),],
name: None
}
);
}
#[test]
fn test_chat_request() {
let json = json!(
{
"model": "",
"messages": [
{"role": "user",
"content": [
{"type": "text", "text": "Whats in this image?"},
{
"type": "image_url",
"image_url": {
"url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png"
},
},
]
}
]
});
let request: ChatRequest = serde_json::from_str(json.to_string().as_str()).unwrap();
assert_eq!(
request.messages[0],
Message{
role: "user".to_string(),
content: vec![
MessageChunk::Text(Text { text: "Whats in this image?".to_string() }),
MessageChunk::ImageUrl(ImageUrl { image_url: Url { url: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png".to_string() } })
],
name: None
}
);
}
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment