Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
9209b209
Unverified
Commit
9209b209
authored
Sep 24, 2025
by
Chang Su
Committed by
GitHub
Sep 24, 2025
Browse files
router-grpc: Support jinja chat template content format detection (#10832)
parent
adba172f
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
1276 additions
and
353 deletions
+1276
-353
sgl-router/Cargo.toml
sgl-router/Cargo.toml
+1
-1
sgl-router/src/routers/grpc/router.rs
sgl-router/src/routers/grpc/router.rs
+348
-56
sgl-router/src/tokenizer/chat_template.rs
sgl-router/src/tokenizer/chat_template.rs
+277
-98
sgl-router/src/tokenizer/huggingface.rs
sgl-router/src/tokenizer/huggingface.rs
+36
-30
sgl-router/src/tokenizer/mod.rs
sgl-router/src/tokenizer/mod.rs
+0
-2
sgl-router/tests/chat_template_format_detection.rs
sgl-router/tests/chat_template_format_detection.rs
+238
-0
sgl-router/tests/chat_template_integration.rs
sgl-router/tests/chat_template_integration.rs
+314
-0
sgl-router/tests/chat_template_loading.rs
sgl-router/tests/chat_template_loading.rs
+62
-16
sgl-router/tests/test_chat_template.rs
sgl-router/tests/test_chat_template.rs
+0
-150
No files found.
sgl-router/Cargo.toml
View file @
9209b209
...
@@ -57,7 +57,7 @@ tokio-stream = { version = "0.1", features = ["sync"] }
...
@@ -57,7 +57,7 @@ tokio-stream = { version = "0.1", features = ["sync"] }
anyhow
=
"1.0"
anyhow
=
"1.0"
tokenizers
=
{
version
=
"0.22.0"
}
tokenizers
=
{
version
=
"0.22.0"
}
tiktoken-rs
=
{
version
=
"0.7.0"
}
tiktoken-rs
=
{
version
=
"0.7.0"
}
minijinja
=
{
version
=
"2.0"
}
minijinja
=
{
version
=
"2.0"
,
features
=
["unstable_machinery"]
}
rustls
=
{
version
=
"0.23"
,
default-features
=
false
,
features
=
[
"ring"
,
"std"
]
}
rustls
=
{
version
=
"0.23"
,
default-features
=
false
,
features
=
[
"ring"
,
"std"
]
}
hf-hub
=
{
version
=
"0.4.3"
,
features
=
["tokio"]
}
hf-hub
=
{
version
=
"0.4.3"
,
features
=
["tokio"]
}
rmcp
=
{
version
=
"0.6.3"
,
features
=
[
"client"
,
"server"
,
rmcp
=
{
version
=
"0.6.3"
,
features
=
[
"client"
,
"server"
,
...
...
sgl-router/src/routers/grpc/router.rs
View file @
9209b209
// gRPC Router Implementation
// gRPC Router Implementation
use
std
::
collections
::
HashMap
;
use
std
::
sync
::
Arc
;
use
std
::
time
::
Duration
;
use
async_trait
::
async_trait
;
use
axum
::{
body
::
Body
,
extract
::
Request
,
http
::{
HeaderMap
,
StatusCode
},
response
::{
IntoResponse
,
Response
},
};
use
tracing
::{
debug
,
error
,
info
,
warn
};
use
crate
::
config
::
types
::
RetryConfig
;
use
crate
::
config
::
types
::
RetryConfig
;
use
crate
::
core
::{
use
crate
::
core
::{
BasicWorkerBuilder
,
CircuitBreakerConfig
,
HealthConfig
,
WorkerRegistry
,
WorkerType
,
BasicWorkerBuilder
,
CircuitBreakerConfig
,
HealthConfig
,
WorkerRegistry
,
WorkerType
,
...
@@ -7,27 +20,16 @@ use crate::core::{
...
@@ -7,27 +20,16 @@ use crate::core::{
use
crate
::
grpc
::{
proto
,
SglangSchedulerClient
};
use
crate
::
grpc
::{
proto
,
SglangSchedulerClient
};
use
crate
::
metrics
::
RouterMetrics
;
use
crate
::
metrics
::
RouterMetrics
;
use
crate
::
policies
::{
LoadBalancingPolicy
,
PolicyRegistry
};
use
crate
::
policies
::{
LoadBalancingPolicy
,
PolicyRegistry
};
use
crate
::
protocols
::
spec
::{
use
crate
::
protocols
::
spec
::{
ChatCompletionRequest
,
ResponseFormat
,
StringOrArray
};
ChatCompletionRequest
,
ChatMessage
,
ContentPart
,
ResponseFormat
,
StringOrArray
,
UserMessageContent
,
};
use
crate
::
reasoning_parser
::
ParserFactory
;
use
crate
::
reasoning_parser
::
ParserFactory
;
use
crate
::
routers
::
RouterTrait
;
use
crate
::
routers
::
RouterTrait
;
use
crate
::
tokenizer
::
{
chat_template
::
ChatMessage
as
TokenizerChatMessage
,
traits
::
Tokenizer
}
;
use
crate
::
tokenizer
::
traits
::
Tokenizer
;
use
crate
::
tool_parser
::
ParserRegistry
;
use
crate
::
tool_parser
::
ParserRegistry
;
use
async_trait
::
async_trait
;
use
axum
::{
body
::
Body
,
extract
::
Request
,
http
::{
HeaderMap
,
StatusCode
},
response
::{
IntoResponse
,
Response
},
};
use
std
::
collections
::
HashMap
;
use
std
::
sync
::
Arc
;
use
std
::
time
::
Duration
;
use
tracing
::{
debug
,
error
,
info
,
warn
};
use
uuid
::
Uuid
;
use
uuid
::
Uuid
;
use
crate
::
tokenizer
::
chat_template
::
ChatTemplateContentFormat
;
use
serde_json
::
Value
;
// Data structures for processing
// Data structures for processing
#[derive(Debug)]
#[derive(Debug)]
pub
struct
ProcessedMessages
{
pub
struct
ProcessedMessages
{
...
@@ -290,16 +292,19 @@ impl GrpcRouter {
...
@@ -290,16 +292,19 @@ impl GrpcRouter {
&
self
,
&
self
,
request
:
&
ChatCompletionRequest
,
request
:
&
ChatCompletionRequest
,
)
->
Result
<
ProcessedMessages
,
String
>
{
)
->
Result
<
ProcessedMessages
,
String
>
{
let
tokenizer_messages
=
self
.convert_messages_for_tokenizer
(
&
request
.messages
)
?
;
// Use the tokenizer's chat template - we require HuggingFace tokenizer for gRPC
// Use the tokenizer's chat template - we require HuggingFace tokenizer for gRPC
let
formatted_text
=
if
let
Some
(
hf_tokenizer
)
=
let
formatted_text
=
if
let
Some
(
hf_tokenizer
)
=
self
.tokenizer
self
.tokenizer
.as_any
()
.as_any
()
.downcast_ref
::
<
crate
::
tokenizer
::
HuggingFaceTokenizer
>
()
.downcast_ref
::
<
crate
::
tokenizer
::
HuggingFaceTokenizer
>
()
{
{
// Get content format and transform messages accordingly
let
content_format
=
hf_tokenizer
.chat_template_content_format
();
let
transformed_messages
=
Self
::
transform_messages_for_content_format
(
&
request
.messages
,
content_format
)
?
;
hf_tokenizer
hf_tokenizer
.apply_chat_template
(
&
t
okenizer
_messages
,
true
)
.apply_chat_template
(
&
t
ransformed
_messages
,
true
)
.map_err
(|
e
|
format!
(
"Failed to apply chat template: {}"
,
e
))
?
.map_err
(|
e
|
format!
(
"Failed to apply chat template: {}"
,
e
))
?
}
else
{
}
else
{
return
Err
(
return
Err
(
...
@@ -317,46 +322,76 @@ impl GrpcRouter {
...
@@ -317,46 +322,76 @@ impl GrpcRouter {
})
})
}
}
/// Convert spec ChatMessage enum to tokenizer ChatMessage struct
/// Transform messages based on content format for ANY message type
fn
convert_messages_for_tokenizer
(
fn
transform_messages_for_content_format
(
&
self
,
messages
:
&
[
crate
::
protocols
::
spec
::
ChatMessage
],
messages
:
&
[
ChatMessage
],
content_format
:
crate
::
tokenizer
::
chat_template
::
ChatTemplateContentFormat
,
)
->
Result
<
Vec
<
TokenizerChatMessage
>
,
String
>
{
)
->
Result
<
Vec
<
serde_json
::
Value
>
,
String
>
{
let
mut
converted
=
Vec
::
new
();
messages
.iter
()
for
message
in
messages
{
.map
(|
message
|
{
let
tokenizer_msg
=
match
message
{
let
mut
message_json
=
serde_json
::
to_value
(
message
)
ChatMessage
::
System
{
content
,
..
}
=>
TokenizerChatMessage
::
new
(
"system"
,
content
),
.map_err
(|
e
|
format!
(
"Failed to serialize message: {}"
,
e
))
?
;
ChatMessage
::
User
{
content
,
..
}
=>
{
let
text_content
=
match
content
{
if
let
Some
(
obj
)
=
message_json
.as_object_mut
()
{
UserMessageContent
::
Text
(
text
)
=>
text
.clone
(),
if
let
Some
(
content_value
)
=
obj
.get_mut
(
"content"
)
{
UserMessageContent
::
Parts
(
parts
)
=>
{
Self
::
transform_content_field
(
content_value
,
content_format
);
// Simple text extraction for now - multimodal is placeholder
}
parts
.iter
()
.filter_map
(|
part
|
match
part
{
ContentPart
::
Text
{
text
}
=>
Some
(
text
.as_str
()),
ContentPart
::
ImageUrl
{
..
}
=>
None
,
// Skip images for now
})
.collect
::
<
Vec
<&
str
>>
()
.join
(
" "
)
}
};
TokenizerChatMessage
::
new
(
"user"
,
text_content
)
}
ChatMessage
::
Assistant
{
content
,
..
}
=>
{
// Simple content extraction - no special tool/reasoning formatting
TokenizerChatMessage
::
new
(
"assistant"
,
content
.as_deref
()
.unwrap_or
(
""
))
}
}
ChatMessage
::
Tool
{
content
,
..
}
=>
TokenizerChatMessage
::
new
(
"tool"
,
content
),
ChatMessage
::
Function
{
content
,
..
}
=>
{
Ok
(
message_json
)
TokenizerChatMessage
::
new
(
"function"
,
content
)
})
.collect
()
}
/// Transform a single content field based on content format
fn
transform_content_field
(
content_value
:
&
mut
Value
,
content_format
:
ChatTemplateContentFormat
,
)
{
let
Some
(
content_array
)
=
content_value
.as_array
()
else
{
return
;
// Not multimodal, keep as-is
};
match
content_format
{
ChatTemplateContentFormat
::
String
=>
{
// Extract and join text parts only
let
text_parts
:
Vec
<
String
>
=
content_array
.iter
()
.filter_map
(|
part
|
{
part
.as_object
()
?
.get
(
"type"
)
?
.as_str
()
.filter
(|
&
t
|
t
==
"text"
)
.and_then
(|
_
|
part
.as_object
()
?
.get
(
"text"
)
?
.as_str
())
.map
(
String
::
from
)
})
.collect
();
if
!
text_parts
.is_empty
()
{
*
content_value
=
Value
::
String
(
text_parts
.join
(
" "
));
}
}
};
}
converted
.push
(
tokenizer_msg
);
ChatTemplateContentFormat
::
OpenAI
=>
{
}
// Replace media URLs with simple type placeholders
let
processed_parts
:
Vec
<
Value
>
=
content_array
.iter
()
.map
(|
part
|
{
part
.as_object
()
.and_then
(|
obj
|
obj
.get
(
"type"
)
?
.as_str
())
.and_then
(|
type_str
|
match
type_str
{
"image_url"
=>
Some
(
serde_json
::
json!
({
"type"
:
"image"
})),
"video_url"
=>
Some
(
serde_json
::
json!
({
"type"
:
"video"
})),
"audio_url"
=>
Some
(
serde_json
::
json!
({
"type"
:
"audio"
})),
_
=>
None
,
})
.unwrap_or_else
(||
part
.clone
())
})
.collect
();
Ok
(
converted
)
*
content_value
=
Value
::
Array
(
processed_parts
);
}
}
}
}
/// Build gRPC SamplingParams from OpenAI request
/// Build gRPC SamplingParams from OpenAI request
...
@@ -636,3 +671,260 @@ impl RouterTrait for GrpcRouter {
...
@@ -636,3 +671,260 @@ impl RouterTrait for GrpcRouter {
(
StatusCode
::
SERVICE_UNAVAILABLE
)
.into_response
()
(
StatusCode
::
SERVICE_UNAVAILABLE
)
.into_response
()
}
}
}
}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
use
crate
::
protocols
::
spec
::{
ChatMessage
,
ContentPart
,
ImageUrl
,
UserMessageContent
};
use
crate
::
tokenizer
::
chat_template
::
ChatTemplateContentFormat
;
use
serde_json
::
json
;
#[test]
fn
test_transform_messages_string_format
()
{
let
messages
=
vec!
[
ChatMessage
::
User
{
role
:
"user"
.to_string
(),
content
:
UserMessageContent
::
Parts
(
vec!
[
ContentPart
::
Text
{
text
:
"Hello"
.to_string
(),
},
ContentPart
::
ImageUrl
{
image_url
:
ImageUrl
{
url
:
"https://example.com/image.jpg"
.to_string
(),
detail
:
None
,
},
},
ContentPart
::
Text
{
text
:
"World"
.to_string
(),
},
]),
name
:
None
,
}];
let
result
=
GrpcRouter
::
transform_messages_for_content_format
(
&
messages
,
ChatTemplateContentFormat
::
String
,
)
.unwrap
();
assert_eq!
(
result
.len
(),
1
);
let
transformed_message
=
&
result
[
0
];
// Should flatten multimodal content to text only
assert_eq!
(
transformed_message
[
"content"
]
.as_str
()
.unwrap
(),
"Hello World"
);
assert_eq!
(
transformed_message
[
"role"
]
.as_str
()
.unwrap
(),
"user"
);
}
#[test]
fn
test_transform_messages_openai_format
()
{
let
messages
=
vec!
[
ChatMessage
::
User
{
role
:
"user"
.to_string
(),
content
:
UserMessageContent
::
Parts
(
vec!
[
ContentPart
::
Text
{
text
:
"Describe this image:"
.to_string
(),
},
ContentPart
::
ImageUrl
{
image_url
:
ImageUrl
{
url
:
"https://example.com/image.jpg"
.to_string
(),
detail
:
Some
(
"high"
.to_string
()),
},
},
]),
name
:
None
,
}];
let
result
=
GrpcRouter
::
transform_messages_for_content_format
(
&
messages
,
ChatTemplateContentFormat
::
OpenAI
,
)
.unwrap
();
assert_eq!
(
result
.len
(),
1
);
let
transformed_message
=
&
result
[
0
];
// Should replace media URLs with simple type placeholders
let
content_array
=
transformed_message
[
"content"
]
.as_array
()
.unwrap
();
assert_eq!
(
content_array
.len
(),
2
);
// Text part should remain unchanged
assert_eq!
(
content_array
[
0
][
"type"
],
"text"
);
assert_eq!
(
content_array
[
0
][
"text"
],
"Describe this image:"
);
// Image part should be replaced with simple type placeholder
assert_eq!
(
content_array
[
1
],
json!
({
"type"
:
"image"
}));
}
#[test]
fn
test_transform_messages_simple_string_content
()
{
let
messages
=
vec!
[
ChatMessage
::
User
{
role
:
"user"
.to_string
(),
content
:
UserMessageContent
::
Text
(
"Simple text message"
.to_string
()),
name
:
None
,
}];
let
result
=
GrpcRouter
::
transform_messages_for_content_format
(
&
messages
,
ChatTemplateContentFormat
::
String
,
)
.unwrap
();
assert_eq!
(
result
.len
(),
1
);
let
transformed_message
=
&
result
[
0
];
// Simple string content should remain unchanged
assert_eq!
(
transformed_message
[
"content"
]
.as_str
()
.unwrap
(),
"Simple text message"
);
}
#[test]
fn
test_transform_messages_assistant_message
()
{
let
messages
=
vec!
[
ChatMessage
::
Assistant
{
role
:
"assistant"
.to_string
(),
content
:
Some
(
"Assistant response"
.to_string
()),
name
:
None
,
tool_calls
:
None
,
function_call
:
None
,
reasoning_content
:
None
,
}];
let
result
=
GrpcRouter
::
transform_messages_for_content_format
(
&
messages
,
ChatTemplateContentFormat
::
String
,
)
.unwrap
();
assert_eq!
(
result
.len
(),
1
);
let
transformed_message
=
&
result
[
0
];
assert_eq!
(
transformed_message
[
"role"
]
.as_str
()
.unwrap
(),
"assistant"
);
assert_eq!
(
transformed_message
[
"content"
]
.as_str
()
.unwrap
(),
"Assistant response"
);
}
#[test]
fn
test_transform_messages_multiple_messages
()
{
let
messages
=
vec!
[
ChatMessage
::
System
{
role
:
"system"
.to_string
(),
content
:
"System prompt"
.to_string
(),
name
:
None
,
},
ChatMessage
::
User
{
role
:
"user"
.to_string
(),
content
:
UserMessageContent
::
Parts
(
vec!
[
ContentPart
::
Text
{
text
:
"User message"
.to_string
(),
},
ContentPart
::
ImageUrl
{
image_url
:
ImageUrl
{
url
:
"https://example.com/image.jpg"
.to_string
(),
detail
:
None
,
},
},
]),
name
:
None
,
},
];
let
result
=
GrpcRouter
::
transform_messages_for_content_format
(
&
messages
,
ChatTemplateContentFormat
::
String
,
)
.unwrap
();
assert_eq!
(
result
.len
(),
2
);
// System message should remain unchanged
assert_eq!
(
result
[
0
][
"role"
]
.as_str
()
.unwrap
(),
"system"
);
assert_eq!
(
result
[
0
][
"content"
]
.as_str
()
.unwrap
(),
"System prompt"
);
// User message should be flattened to text only
assert_eq!
(
result
[
1
][
"role"
]
.as_str
()
.unwrap
(),
"user"
);
assert_eq!
(
result
[
1
][
"content"
]
.as_str
()
.unwrap
(),
"User message"
);
}
#[test]
fn
test_transform_messages_empty_text_parts
()
{
let
messages
=
vec!
[
ChatMessage
::
User
{
role
:
"user"
.to_string
(),
content
:
UserMessageContent
::
Parts
(
vec!
[
ContentPart
::
ImageUrl
{
image_url
:
ImageUrl
{
url
:
"https://example.com/image.jpg"
.to_string
(),
detail
:
None
,
},
}]),
name
:
None
,
}];
let
result
=
GrpcRouter
::
transform_messages_for_content_format
(
&
messages
,
ChatTemplateContentFormat
::
String
,
)
.unwrap
();
assert_eq!
(
result
.len
(),
1
);
let
transformed_message
=
&
result
[
0
];
// Should keep original multimodal content when no text parts exist
assert
!
(
transformed_message
[
"content"
]
.is_array
());
}
#[test]
fn
test_transform_messages_mixed_content_types
()
{
// Test with both text and multimodal content
let
messages
=
vec!
[
ChatMessage
::
User
{
role
:
"user"
.to_string
(),
content
:
UserMessageContent
::
Text
(
"Plain text"
.to_string
()),
name
:
None
,
},
ChatMessage
::
User
{
role
:
"user"
.to_string
(),
content
:
UserMessageContent
::
Parts
(
vec!
[
ContentPart
::
Text
{
text
:
"With image"
.to_string
(),
},
ContentPart
::
ImageUrl
{
image_url
:
ImageUrl
{
url
:
"https://example.com/image.jpg"
.to_string
(),
detail
:
Some
(
"low"
.to_string
()),
},
},
]),
name
:
None
,
},
];
// Test String format
let
result_string
=
GrpcRouter
::
transform_messages_for_content_format
(
&
messages
,
ChatTemplateContentFormat
::
String
,
)
.unwrap
();
assert_eq!
(
result_string
.len
(),
2
);
assert_eq!
(
result_string
[
0
][
"content"
]
.as_str
()
.unwrap
(),
"Plain text"
);
assert_eq!
(
result_string
[
1
][
"content"
]
.as_str
()
.unwrap
(),
"With image"
);
// Test OpenAI format
let
result_openai
=
GrpcRouter
::
transform_messages_for_content_format
(
&
messages
,
ChatTemplateContentFormat
::
OpenAI
,
)
.unwrap
();
assert_eq!
(
result_openai
.len
(),
2
);
assert_eq!
(
result_openai
[
0
][
"content"
]
.as_str
()
.unwrap
(),
"Plain text"
);
let
content_array
=
result_openai
[
1
][
"content"
]
.as_array
()
.unwrap
();
assert_eq!
(
content_array
.len
(),
2
);
assert_eq!
(
content_array
[
0
][
"type"
],
"text"
);
assert_eq!
(
content_array
[
1
],
json!
({
"type"
:
"image"
}));
}
}
sgl-router/src/tokenizer/chat_template.rs
View file @
9209b209
...
@@ -4,39 +4,291 @@
...
@@ -4,39 +4,291 @@
//! similar to HuggingFace transformers' apply_chat_template method.
//! similar to HuggingFace transformers' apply_chat_template method.
use
anyhow
::{
anyhow
,
Result
};
use
anyhow
::{
anyhow
,
Result
};
use
minijinja
::{
context
,
Environment
,
Value
};
use
minijinja
::{
context
,
machinery
,
Environment
,
Value
};
use
serde
::{
Deserialize
,
Serialize
};
use
serde_json
;
use
serde_json
;
/// Represents a chat message with role and content
/// Chat template content format
#[derive(Debug,
Clone,
Serialize,
Deserialize)]
#[derive(Debug,
Clone,
Copy,
PartialEq,
Eq)]
pub
struct
ChatMessage
{
pub
enum
ChatTemplateContentFormat
{
pub
role
:
String
,
/// Content is a simple string
pub
content
:
String
,
String
,
/// Content is a list of structured parts (OpenAI format)
OpenAI
,
}
}
impl
ChatMessage
{
impl
Default
for
ChatTemplateContentFormat
{
pub
fn
new
(
role
:
impl
Into
<
String
>
,
content
:
impl
Into
<
String
>
)
->
Self
{
fn
default
()
->
Self
{
ChatMessage
{
Self
::
String
role
:
role
.into
(),
}
content
:
content
.into
(),
}
impl
std
::
fmt
::
Display
for
ChatTemplateContentFormat
{
fn
fmt
(
&
self
,
f
:
&
mut
std
::
fmt
::
Formatter
<
'_
>
)
->
std
::
fmt
::
Result
{
match
self
{
Self
::
String
=>
write!
(
f
,
"string"
),
Self
::
OpenAI
=>
write!
(
f
,
"openai"
),
}
}
}
/// Detect the content format expected by a Jinja2 chat template
///
/// This implements the same detection logic as SGLang's detect_jinja_template_content_format
/// which uses AST parsing to look for content iteration patterns.
///
/// Returns:
/// - ChatTemplateContentFormat::OpenAI if template expects structured content (list of parts)
/// - ChatTemplateContentFormat::String if template expects simple string content
pub
fn
detect_chat_template_content_format
(
template
:
&
str
)
->
ChatTemplateContentFormat
{
// Use AST-based detection (enabled by default)
if
let
Some
(
format
)
=
detect_format_with_ast
(
template
)
{
return
format
;
}
// Default to string format if AST parsing fails
ChatTemplateContentFormat
::
String
}
/// AST-based detection using minijinja's unstable machinery
/// This implements the exact same logic as SGLang's _is_var_or_elems_access functions
fn
detect_format_with_ast
(
template
:
&
str
)
->
Option
<
ChatTemplateContentFormat
>
{
use
minijinja
::
machinery
::{
parse
,
WhitespaceConfig
};
use
minijinja
::
syntax
::
SyntaxConfig
;
// Parse the template into AST
let
ast
=
match
parse
(
template
,
"template"
,
SyntaxConfig
{},
WhitespaceConfig
::
default
(),
)
{
Ok
(
ast
)
=>
ast
,
Err
(
_
)
=>
return
Some
(
ChatTemplateContentFormat
::
String
),
};
// Traverse AST looking for patterns that indicate OpenAI format
let
has_iteration
=
find_content_iteration_in_ast
(
&
ast
);
let
has_structure_checks
=
find_content_structure_checks_in_ast
(
&
ast
);
let
has_assignment_patterns
=
find_variable_assignment_patterns_in_ast
(
&
ast
);
if
has_iteration
||
has_structure_checks
||
has_assignment_patterns
{
Some
(
ChatTemplateContentFormat
::
OpenAI
)
}
else
{
Some
(
ChatTemplateContentFormat
::
String
)
}
}
/// Find content iteration patterns in AST
/// Implements the same logic as SGLang's AST traversal
fn
find_content_iteration_in_ast
(
ast
:
&
machinery
::
ast
::
Stmt
)
->
bool
{
use
machinery
::
ast
::
Stmt
;
match
ast
{
Stmt
::
Template
(
template
)
=>
{
// Recursively check all children
template
.children
.iter
()
.any
(|
child
|
find_content_iteration_in_ast
(
child
))
}
Stmt
::
ForLoop
(
for_loop
)
=>
{
// Check if this for-loop iterates over message content
is_var_or_elems_access
(
&
for_loop
.iter
,
"message"
,
"content"
)
||
is_var_or_elems_access
(
&
for_loop
.iter
,
"msg"
,
"content"
)
||
is_var_or_elems_access
(
&
for_loop
.iter
,
"m"
,
"content"
)
||
// Also check the body for nested loops
for_loop
.body
.iter
()
.any
(|
stmt
|
find_content_iteration_in_ast
(
stmt
))
}
Stmt
::
IfCond
(
if_cond
)
=>
{
// Check true and false branches
if_cond
.true_body
.iter
()
.any
(|
stmt
|
find_content_iteration_in_ast
(
stmt
))
||
if_cond
.false_body
.iter
()
.any
(|
stmt
|
find_content_iteration_in_ast
(
stmt
))
}
_
=>
false
,
// Other statement types don't contain loops
}
}
/// Check if expression accesses varname['key'] or varname.key
/// Implements SGLang's _is_var_or_elems_access logic using actual AST nodes
fn
is_var_or_elems_access
(
expr
:
&
machinery
::
ast
::
Expr
,
varname
:
&
str
,
key
:
&
str
)
->
bool
{
use
machinery
::
ast
::
Expr
;
match
expr
{
// Check for attribute access: varname.key
Expr
::
GetAttr
(
getattr
)
=>
is_var_access
(
&
getattr
.expr
,
varname
)
&&
getattr
.name
==
key
,
// Check for item access: varname['key'] or varname["key"]
Expr
::
GetItem
(
getitem
)
=>
{
is_var_access
(
&
getitem
.expr
,
varname
)
&&
is_const_string
(
&
getitem
.subscript_expr
,
key
)
}
// Handle filters and tests that might wrap the access
Expr
::
Filter
(
filter
)
=>
{
if
let
Some
(
ref
expr
)
=
filter
.expr
{
is_var_or_elems_access
(
expr
,
varname
,
key
)
}
else
{
false
}
}
Expr
::
Test
(
test
)
=>
is_var_or_elems_access
(
&
test
.expr
,
varname
,
key
),
_
=>
false
,
}
}
/// Check if expression is a variable access (like {{ varname }})
/// Implements SGLang's _is_var_access logic
fn
is_var_access
(
expr
:
&
machinery
::
ast
::
Expr
,
varname
:
&
str
)
->
bool
{
matches!
(
expr
,
machinery
::
ast
::
Expr
::
Var
(
var
)
if
var
.id
==
varname
)
}
/// Check if expression is a constant string with the given value
fn
is_const_string
(
expr
:
&
machinery
::
ast
::
Expr
,
value
:
&
str
)
->
bool
{
matches!
(
expr
,
machinery
::
ast
::
Expr
::
Const
(
const_expr
)
if
const_expr
.value
.as_str
()
==
Some
(
value
))
}
/// Find content structure checks in AST (like content[0], content|length)
fn
find_content_structure_checks_in_ast
(
ast
:
&
machinery
::
ast
::
Stmt
)
->
bool
{
use
machinery
::
ast
::
Stmt
;
match
ast
{
Stmt
::
Template
(
template
)
=>
template
.children
.iter
()
.any
(|
child
|
find_content_structure_checks_in_ast
(
child
)),
Stmt
::
ForLoop
(
for_loop
)
=>
for_loop
.body
.iter
()
.any
(|
stmt
|
find_content_structure_checks_in_ast
(
stmt
)),
Stmt
::
IfCond
(
if_cond
)
=>
{
// Check if condition has content structure checks
has_content_structure_check_expr
(
&
if_cond
.expr
)
||
if_cond
.true_body
.iter
()
.any
(|
stmt
|
find_content_structure_checks_in_ast
(
stmt
))
||
if_cond
.false_body
.iter
()
.any
(|
stmt
|
find_content_structure_checks_in_ast
(
stmt
))
}
Stmt
::
EmitExpr
(
expr
)
=>
has_content_structure_check_expr
(
&
expr
.expr
),
_
=>
false
,
}
}
/// Find variable assignment patterns like set content = message['content']
fn
find_variable_assignment_patterns_in_ast
(
ast
:
&
machinery
::
ast
::
Stmt
)
->
bool
{
use
machinery
::
ast
::
Stmt
;
match
ast
{
Stmt
::
Template
(
template
)
=>
template
.children
.iter
()
.any
(|
child
|
find_variable_assignment_patterns_in_ast
(
child
)),
Stmt
::
ForLoop
(
for_loop
)
=>
{
// Check if this for-loop body contains both assignment and iteration
let
has_assignment
=
for_loop
.body
.iter
()
.any
(|
stmt
|
is_content_assignment_stmt
(
stmt
));
let
has_iteration
=
for_loop
.body
.iter
()
.any
(|
stmt
|
{
is_content_variable_iteration
(
stmt
)
||
matches!
(
stmt
,
Stmt
::
IfCond
(
if_cond
)
if
if_cond
.true_body
.iter
()
.any
(|
s
|
is_content_variable_iteration
(
s
))
||
if_cond
.false_body
.iter
()
.any
(|
s
|
is_content_variable_iteration
(
s
))
)
});
(
has_assignment
&&
has_iteration
)
||
for_loop
.body
.iter
()
.any
(|
stmt
|
find_variable_assignment_patterns_in_ast
(
stmt
))
}
Stmt
::
IfCond
(
if_cond
)
=>
{
if_cond
.true_body
.iter
()
.any
(|
stmt
|
find_variable_assignment_patterns_in_ast
(
stmt
))
||
if_cond
.false_body
.iter
()
.any
(|
stmt
|
find_variable_assignment_patterns_in_ast
(
stmt
))
}
}
_
=>
false
,
}
}
}
/// Check if expression has content structure checks (index access, length, etc.)
fn
has_content_structure_check_expr
(
expr
:
&
machinery
::
ast
::
Expr
)
->
bool
{
use
machinery
::
ast
::
Expr
;
pub
fn
system
(
content
:
impl
Into
<
String
>
)
->
Self
{
match
expr
{
Self
::
new
(
"system"
,
content
)
// Check for content[0] - index access
Expr
::
GetItem
(
getitem
)
=>
{
is_content_access
(
&
getitem
.expr
)
&&
is_numeric_constant
(
&
getitem
.subscript_expr
)
}
// Check for content|length - filter with length
Expr
::
Filter
(
filter
)
=>
{
if
let
Some
(
ref
filter_expr
)
=
filter
.expr
{
is_content_access
(
filter_expr
)
&&
filter
.name
==
"length"
}
else
{
false
}
}
// Check for content is sequence/iterable
Expr
::
Test
(
test
)
=>
{
is_content_access
(
&
test
.expr
)
&&
(
test
.name
==
"sequence"
||
test
.name
==
"iterable"
)
}
_
=>
false
,
}
}
}
/// Check if statement assigns message content to a variable
fn
is_content_assignment_stmt
(
stmt
:
&
machinery
::
ast
::
Stmt
)
->
bool
{
use
machinery
::
ast
::
Stmt
;
pub
fn
user
(
content
:
impl
Into
<
String
>
)
->
Self
{
match
stmt
{
Self
::
new
(
"user"
,
content
)
Stmt
::
Set
(
set_stmt
)
=>
{
// Check if this is setting content = message['content']
is_var_access
(
&
set_stmt
.target
,
"content"
)
&&
is_var_or_elems_access
(
&
set_stmt
.expr
,
"message"
,
"content"
)
}
_
=>
false
,
}
}
}
pub
fn
assistant
(
content
:
impl
Into
<
String
>
)
->
Self
{
/// Check if statement iterates over content variable
Self
::
new
(
"assistant"
,
content
)
fn
is_content_variable_iteration
(
stmt
:
&
machinery
::
ast
::
Stmt
)
->
bool
{
use
machinery
::
ast
::{
Expr
,
Stmt
};
match
stmt
{
Stmt
::
ForLoop
(
for_loop
)
=>
{
// Check if iterating over a variable named "content"
matches!
(
for_loop
.iter
,
Expr
::
Var
(
ref
var
)
if
var
.id
==
"content"
)
}
_
=>
false
,
}
}
}
}
/// Chat template processor using Jinja2
/// Check if expression accesses content (message.content, message['content'], etc.)
fn
is_content_access
(
expr
:
&
machinery
::
ast
::
Expr
)
->
bool
{
is_var_or_elems_access
(
expr
,
"message"
,
"content"
)
||
is_var_or_elems_access
(
expr
,
"msg"
,
"content"
)
||
is_var_or_elems_access
(
expr
,
"m"
,
"content"
)
}
/// Check if expression is a numeric constant (for index access)
fn
is_numeric_constant
(
expr
:
&
machinery
::
ast
::
Expr
)
->
bool
{
matches!
(
expr
,
machinery
::
ast
::
Expr
::
Const
(
const_expr
)
if
const_expr
.value
.is_number
())
}
/// Chat template processor using Jinja2 - simple wrapper like HuggingFace
pub
struct
ChatTemplateProcessor
{
pub
struct
ChatTemplateProcessor
{
template
:
String
,
template
:
String
,
bos_token
:
Option
<
String
>
,
bos_token
:
Option
<
String
>
,
...
@@ -57,9 +309,10 @@ impl ChatTemplateProcessor {
...
@@ -57,9 +309,10 @@ impl ChatTemplateProcessor {
///
///
/// This mimics the behavior of HuggingFace's apply_chat_template method
/// This mimics the behavior of HuggingFace's apply_chat_template method
/// but returns the formatted string instead of token IDs.
/// but returns the formatted string instead of token IDs.
/// Messages should be pre-processed into the format expected by the template.
pub
fn
apply_chat_template
(
pub
fn
apply_chat_template
(
&
self
,
&
self
,
messages
:
&
[
ChatMessag
e
],
messages
:
&
[
serde_json
::
Valu
e
],
add_generation_prompt
:
bool
,
add_generation_prompt
:
bool
,
)
->
Result
<
String
>
{
)
->
Result
<
String
>
{
let
mut
env
=
Environment
::
new
();
let
mut
env
=
Environment
::
new
();
...
@@ -73,21 +326,13 @@ impl ChatTemplateProcessor {
...
@@ -73,21 +326,13 @@ impl ChatTemplateProcessor {
.get_template
(
"chat"
)
.get_template
(
"chat"
)
.map_err
(|
e
|
anyhow!
(
"Failed to get template: {}"
,
e
))
?
;
.map_err
(|
e
|
anyhow!
(
"Failed to get template: {}"
,
e
))
?
;
// Convert messages to a format Jinja can work with
// Convert ChatMessage to minijinja::Value for rendering using serde like pydantic.model_dump()
let
messages_value
:
Vec
<
Value
>
=
messages
let
minijinja_messages
:
Vec
<
Value
>
=
messages
.iter
()
.map
(
Value
::
from_serialize
)
.collect
();
.iter
()
.map
(|
msg
|
{
context!
{
role
=>
msg
.role
.clone
(),
content
=>
msg
.content
.clone
()
}
})
.collect
();
// Render the template
// Render the template
directly with the provided values
let
rendered
=
tmpl
let
rendered
=
tmpl
.render
(
context!
{
.render
(
context!
{
messages
=>
messages
_value
,
messages
=>
minijinja_
messages
,
add_generation_prompt
=>
add_generation_prompt
,
add_generation_prompt
=>
add_generation_prompt
,
bos_token
=>
self
.bos_token
.clone
()
.unwrap_or_default
(),
bos_token
=>
self
.bos_token
.clone
()
.unwrap_or_default
(),
eos_token
=>
self
.eos_token
.clone
()
.unwrap_or_default
()
eos_token
=>
self
.eos_token
.clone
()
.unwrap_or_default
()
...
@@ -114,69 +359,3 @@ pub fn load_chat_template_from_config(config_path: &str) -> Result<Option<String
...
@@ -114,69 +359,3 @@ pub fn load_chat_template_from_config(config_path: &str) -> Result<Option<String
Ok
(
None
)
Ok
(
None
)
}
}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
#[test]
fn
test_chat_message_creation
()
{
let
msg
=
ChatMessage
::
system
(
"You are a helpful assistant"
);
assert_eq!
(
msg
.role
,
"system"
);
assert_eq!
(
msg
.content
,
"You are a helpful assistant"
);
let
user_msg
=
ChatMessage
::
user
(
"Hello!"
);
assert_eq!
(
user_msg
.role
,
"user"
);
let
assistant_msg
=
ChatMessage
::
assistant
(
"Hi there!"
);
assert_eq!
(
assistant_msg
.role
,
"assistant"
);
}
#[test]
fn
test_simple_chat_template
()
{
// Simple template that formats messages
let
template
=
r#"
{%- for message in messages -%}
{{ message.role }}: {{ message.content }}
{% endfor -%}
{%- if add_generation_prompt -%}
assistant:
{%- endif -%}
"#
;
let
processor
=
ChatTemplateProcessor
::
new
(
template
.to_string
(),
None
,
None
);
let
messages
=
vec!
[
ChatMessage
::
system
(
"You are helpful"
),
ChatMessage
::
user
(
"Hello"
),
];
let
result
=
processor
.apply_chat_template
(
&
messages
,
true
)
.unwrap
();
assert
!
(
result
.contains
(
"system: You are helpful"
));
assert
!
(
result
.contains
(
"user: Hello"
));
assert
!
(
result
.contains
(
"assistant:"
));
}
#[test]
fn
test_chat_template_with_tokens
()
{
// Template that uses special tokens
let
template
=
r#"
{{ bos_token }}
{%- for message in messages -%}
{{ message.role }}: {{ message.content }}{{ eos_token }}
{% endfor -%}
"#
;
let
processor
=
ChatTemplateProcessor
::
new
(
template
.to_string
(),
Some
(
"<s>"
.to_string
()),
Some
(
"</s>"
.to_string
()),
);
let
messages
=
vec!
[
ChatMessage
::
user
(
"Test"
)];
let
result
=
processor
.apply_chat_template
(
&
messages
,
false
)
.unwrap
();
assert
!
(
result
.contains
(
"<s>"
));
assert
!
(
result
.contains
(
"</s>"
));
}
}
sgl-router/src/tokenizer/huggingface.rs
View file @
9209b209
use
super
::
traits
::{
Decoder
,
Encoder
,
Encoding
,
SpecialTokens
,
TokenIdType
,
Tokenizer
as
TokenizerTrait
,
};
use
anyhow
::{
Error
,
Result
};
use
std
::
collections
::
HashMap
;
use
std
::
collections
::
HashMap
;
use
anyhow
::{
Error
,
Result
};
use
tokenizers
::
tokenizer
::
Tokenizer
as
HfTokenizer
;
use
tokenizers
::
tokenizer
::
Tokenizer
as
HfTokenizer
;
use
super
::
chat_template
::{
ChatMessage
,
ChatTemplateProcessor
};
use
super
::
chat_template
::{
detect_chat_template_content_format
,
ChatTemplateContentFormat
,
ChatTemplateProcessor
,
};
use
super
::
traits
::{
Decoder
,
Encoder
,
Encoding
,
SpecialTokens
,
TokenIdType
,
Tokenizer
as
TokenizerTrait
,
};
/// HuggingFace tokenizer wrapper
/// HuggingFace tokenizer wrapper
pub
struct
HuggingFaceTokenizer
{
pub
struct
HuggingFaceTokenizer
{
...
@@ -14,6 +17,8 @@ pub struct HuggingFaceTokenizer {
...
@@ -14,6 +17,8 @@ pub struct HuggingFaceTokenizer {
vocab
:
HashMap
<
String
,
TokenIdType
>
,
vocab
:
HashMap
<
String
,
TokenIdType
>
,
reverse_vocab
:
HashMap
<
TokenIdType
,
String
>
,
reverse_vocab
:
HashMap
<
TokenIdType
,
String
>
,
chat_template
:
Option
<
String
>
,
chat_template
:
Option
<
String
>
,
/// Detected chat template content format (computed once at initialization)
content_format
:
ChatTemplateContentFormat
,
}
}
impl
HuggingFaceTokenizer
{
impl
HuggingFaceTokenizer
{
...
@@ -49,12 +54,20 @@ impl HuggingFaceTokenizer {
...
@@ -49,12 +54,20 @@ impl HuggingFaceTokenizer {
Self
::
load_chat_template
(
file_path
)
Self
::
load_chat_template
(
file_path
)
};
};
// Detect content format once at initialization
let
content_format
=
if
let
Some
(
ref
template
)
=
chat_template
{
detect_chat_template_content_format
(
template
)
}
else
{
ChatTemplateContentFormat
::
String
// Default if no template
};
Ok
(
HuggingFaceTokenizer
{
Ok
(
HuggingFaceTokenizer
{
tokenizer
,
tokenizer
,
special_tokens
,
special_tokens
,
vocab
,
vocab
,
reverse_vocab
,
reverse_vocab
,
chat_template
,
chat_template
,
content_format
,
})
})
}
}
...
@@ -73,6 +86,7 @@ impl HuggingFaceTokenizer {
...
@@ -73,6 +86,7 @@ impl HuggingFaceTokenizer {
vocab
,
vocab
,
reverse_vocab
,
reverse_vocab
,
chat_template
:
None
,
chat_template
:
None
,
content_format
:
ChatTemplateContentFormat
::
String
,
// Default
}
}
}
}
...
@@ -135,13 +149,22 @@ impl HuggingFaceTokenizer {
...
@@ -135,13 +149,22 @@ impl HuggingFaceTokenizer {
/// Set or override the chat template
/// Set or override the chat template
pub
fn
set_chat_template
(
&
mut
self
,
template
:
String
)
{
pub
fn
set_chat_template
(
&
mut
self
,
template
:
String
)
{
// Detect format for the new template
self
.content_format
=
detect_chat_template_content_format
(
&
template
);
self
.chat_template
=
Some
(
template
);
self
.chat_template
=
Some
(
template
);
}
}
/// Get the content format expected by the chat template
pub
fn
chat_template_content_format
(
&
self
)
->
ChatTemplateContentFormat
{
self
.content_format
}
/// Apply chat template if available
/// Apply chat template if available
///
/// Takes transformed JSON Values (already transformed based on content format)
pub
fn
apply_chat_template
(
pub
fn
apply_chat_template
(
&
self
,
&
self
,
messages
:
&
[
ChatMessag
e
],
messages
:
&
[
serde_json
::
Valu
e
],
add_generation_prompt
:
bool
,
add_generation_prompt
:
bool
,
)
->
Result
<
String
>
{
)
->
Result
<
String
>
{
if
let
Some
(
ref
template
)
=
self
.chat_template
{
if
let
Some
(
ref
template
)
=
self
.chat_template
{
...
@@ -150,17 +173,15 @@ impl HuggingFaceTokenizer {
...
@@ -150,17 +173,15 @@ impl HuggingFaceTokenizer {
self
.special_tokens.bos_token
.clone
(),
self
.special_tokens.bos_token
.clone
(),
self
.special_tokens.eos_token
.clone
(),
self
.special_tokens.eos_token
.clone
(),
);
);
processor
.apply_chat_template
(
messages
,
add_generation_prompt
)
processor
.apply_chat_template
(
messages
,
add_generation_prompt
)
}
else
{
}
else
{
// Fallback to simple formatting if no template is available
Err
(
Error
::
msg
(
let
mut
result
=
String
::
new
();
"Cannot use chat template functions because tokenizer.chat_template is not set and no template
\
for
msg
in
messages
{
argument was passed! For information about writing templates and setting the
\
result
.push_str
(
&
format!
(
"{}: {}
\n
"
,
msg
.role
,
msg
.content
));
tokenizer.chat_template attribute, please see the documentation at
\
}
https://huggingface.co/docs/transformers/main/en/chat_templating"
if
add_generation_prompt
{
))
result
.push_str
(
"assistant: "
);
}
Ok
(
result
)
}
}
}
}
}
}
...
@@ -218,21 +239,6 @@ impl TokenizerTrait for HuggingFaceTokenizer {
...
@@ -218,21 +239,6 @@ impl TokenizerTrait for HuggingFaceTokenizer {
#[cfg(test)]
#[cfg(test)]
mod
tests
{
mod
tests
{
use
super
::
ChatMessage
;
#[test]
fn
test_chat_message_creation
()
{
let
msg
=
ChatMessage
::
system
(
"You are a helpful assistant"
);
assert_eq!
(
msg
.role
,
"system"
);
assert_eq!
(
msg
.content
,
"You are a helpful assistant"
);
let
user_msg
=
ChatMessage
::
user
(
"Hello!"
);
assert_eq!
(
user_msg
.role
,
"user"
);
let
assistant_msg
=
ChatMessage
::
assistant
(
"Hi there!"
);
assert_eq!
(
assistant_msg
.role
,
"assistant"
);
}
// Note: Actual tokenizer tests would require a real tokenizer file
// Note: Actual tokenizer tests would require a real tokenizer file
// These would be integration tests rather than unit tests
// These would be integration tests rather than unit tests
}
}
sgl-router/src/tokenizer/mod.rs
View file @
9209b209
...
@@ -33,8 +33,6 @@ pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as Tokeniz
...
@@ -33,8 +33,6 @@ pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as Tokeniz
pub
use
huggingface
::
HuggingFaceTokenizer
;
pub
use
huggingface
::
HuggingFaceTokenizer
;
pub
use
chat_template
::
ChatMessage
;
pub
use
tiktoken
::{
TiktokenModel
,
TiktokenTokenizer
};
pub
use
tiktoken
::{
TiktokenModel
,
TiktokenTokenizer
};
/// Main tokenizer wrapper that provides a unified interface for different tokenizer implementations
/// Main tokenizer wrapper that provides a unified interface for different tokenizer implementations
...
...
sgl-router/tests/chat_template_format_detection.rs
0 → 100644
View file @
9209b209
use
sglang_router_rs
::
protocols
::
spec
;
use
sglang_router_rs
::
tokenizer
::
chat_template
::{
detect_chat_template_content_format
,
ChatTemplateContentFormat
,
ChatTemplateProcessor
,
};
#[test]
fn
test_detect_string_format_deepseek
()
{
// DeepSeek style template - expects string content
let
template
=
r#"
{%- for message in messages %}
{%- if message['role'] == 'user' %}
User: {{ message['content'] }}
{%- elif message['role'] == 'assistant' %}
Assistant: {{ message['content'] }}
{%- endif %}
{%- endfor %}
"#
;
assert_eq!
(
detect_chat_template_content_format
(
template
),
ChatTemplateContentFormat
::
String
);
}
#[test]
fn
test_detect_openai_format_llama4
()
{
// Llama4 style template - expects structured content
let
template
=
r#"
{%- for message in messages %}
{%- if message['content'] is iterable %}
{%- for content in message['content'] %}
{%- if content['type'] == 'text' %}
{{ content['text'] }}
{%- elif content['type'] == 'image' %}
<image>
{%- endif %}
{%- endfor %}
{%- else %}
{{ message['content'] }}
{%- endif %}
{%- endfor %}
"#
;
assert_eq!
(
detect_chat_template_content_format
(
template
),
ChatTemplateContentFormat
::
OpenAI
);
}
#[test]
fn
test_detect_openai_format_dot_notation
()
{
// Template using dot notation
let
template
=
r#"
{%- for message in messages %}
{%- for part in message.content %}
{%- if part.type == 'text' %}
{{ part.text }}
{%- endif %}
{%- endfor %}
{%- endfor %}
"#
;
assert_eq!
(
detect_chat_template_content_format
(
template
),
ChatTemplateContentFormat
::
OpenAI
);
}
#[test]
fn
test_detect_openai_format_variable_assignment
()
{
// Template that assigns content to variable then iterates
let
template
=
r#"
{%- for message in messages %}
{%- set content = message['content'] %}
{%- if content is sequence %}
{%- for item in content %}
{{ item }}
{%- endfor %}
{%- endif %}
{%- endfor %}
"#
;
assert_eq!
(
detect_chat_template_content_format
(
template
),
ChatTemplateContentFormat
::
OpenAI
);
}
#[test]
fn
test_detect_openai_format_glm4v_style
()
{
// GLM4V uses 'msg' instead of 'message'
let
template
=
r#"
{%- for msg in messages %}
{%- for part in msg.content %}
{%- if part.type == 'text' %}{{ part.text }}{%- endif %}
{%- if part.type == 'image' %}<image>{%- endif %}
{%- endfor %}
{%- endfor %}
"#
;
assert_eq!
(
detect_chat_template_content_format
(
template
),
ChatTemplateContentFormat
::
OpenAI
);
}
#[test]
fn
test_detect_openai_format_with_length_check
()
{
// Template that checks content length
let
template
=
r#"
{%- for message in messages %}
{%- if message.content|length > 0 %}
{%- for item in message.content %}
{{ item.text }}
{%- endfor %}
{%- endif %}
{%- endfor %}
"#
;
assert_eq!
(
detect_chat_template_content_format
(
template
),
ChatTemplateContentFormat
::
OpenAI
);
}
#[test]
fn
test_detect_openai_format_with_index_access
()
{
// Template that accesses content by index
let
template
=
r#"
{%- for message in messages %}
{%- if message.content[0] %}
First item: {{ message.content[0].text }}
{%- endif %}
{%- endfor %}
"#
;
assert_eq!
(
detect_chat_template_content_format
(
template
),
ChatTemplateContentFormat
::
OpenAI
);
}
#[test]
fn
test_invalid_template_defaults_to_string
()
{
let
template
=
"Not a valid {% jinja template"
;
assert_eq!
(
detect_chat_template_content_format
(
template
),
ChatTemplateContentFormat
::
String
);
}
#[test]
fn
test_empty_template_defaults_to_string
()
{
assert_eq!
(
detect_chat_template_content_format
(
""
),
ChatTemplateContentFormat
::
String
);
}
#[test]
fn
test_simple_chat_template_unit_test
()
{
let
template
=
r#"
{%- for message in messages %}
{{ message.role }}: {{ message.content }}
{% endfor -%}
{%- if add_generation_prompt %}
assistant:
{%- endif %}
"#
;
let
processor
=
ChatTemplateProcessor
::
new
(
template
.to_string
(),
Some
(
"<s>"
.to_string
()),
Some
(
"</s>"
.to_string
()),
);
let
messages
=
vec!
[
spec
::
ChatMessage
::
System
{
role
:
"system"
.to_string
(),
content
:
"You are helpful"
.to_string
(),
name
:
None
,
},
spec
::
ChatMessage
::
User
{
role
:
"user"
.to_string
(),
content
:
spec
::
UserMessageContent
::
Text
(
"Hello"
.to_string
()),
name
:
None
,
},
];
// Convert to JSON values like the router does
let
message_values
:
Vec
<
serde_json
::
Value
>
=
messages
.iter
()
.map
(|
msg
|
serde_json
::
to_value
(
msg
)
.unwrap
())
.collect
();
let
result
=
processor
.apply_chat_template
(
&
message_values
,
true
)
.unwrap
();
assert
!
(
result
.contains
(
"system: You are helpful"
));
assert
!
(
result
.contains
(
"user: Hello"
));
assert
!
(
result
.contains
(
"assistant:"
));
}
#[test]
fn
test_chat_template_with_tokens_unit_test
()
{
// Template that uses special tokens
let
template
=
r#"
{{ bos_token }}
{%- for message in messages -%}
{{ message.role }}: {{ message.content }}{{ eos_token }}
{% endfor -%}
"#
;
let
processor
=
ChatTemplateProcessor
::
new
(
template
.to_string
(),
Some
(
"<s>"
.to_string
()),
Some
(
"</s>"
.to_string
()),
);
let
messages
=
[
spec
::
ChatMessage
::
User
{
role
:
"user"
.to_string
(),
content
:
spec
::
UserMessageContent
::
Text
(
"Test"
.to_string
()),
name
:
None
,
}];
// Convert to JSON values like the router does
let
message_values
:
Vec
<
serde_json
::
Value
>
=
messages
.iter
()
.map
(|
msg
|
serde_json
::
to_value
(
msg
)
.unwrap
())
.collect
();
let
result
=
processor
.apply_chat_template
(
&
message_values
,
false
)
.unwrap
();
assert
!
(
result
.contains
(
"<s>"
));
assert
!
(
result
.contains
(
"</s>"
));
}
sgl-router/tests/chat_template_integration.rs
0 → 100644
View file @
9209b209
use
sglang_router_rs
::
protocols
::
spec
;
use
sglang_router_rs
::
tokenizer
::
chat_template
::{
detect_chat_template_content_format
,
ChatTemplateContentFormat
,
ChatTemplateProcessor
,
};
#[test]
fn
test_simple_chat_template
()
{
let
template
=
r#"
{%- for message in messages %}
<|{{ message.role }}|>{{ message.content }}<|end|>
{% endfor -%}
{%- if add_generation_prompt %}
<|assistant|>
{%- endif %}
"#
;
let
processor
=
ChatTemplateProcessor
::
new
(
template
.to_string
(),
Some
(
"<s>"
.to_string
()),
Some
(
"</s>"
.to_string
()),
);
let
messages
=
[
spec
::
ChatMessage
::
User
{
role
:
"user"
.to_string
(),
content
:
spec
::
UserMessageContent
::
Text
(
"Test"
.to_string
()),
name
:
None
,
}];
// Convert to JSON values like the router does
let
message_values
:
Vec
<
serde_json
::
Value
>
=
messages
.iter
()
.map
(|
msg
|
serde_json
::
to_value
(
msg
)
.unwrap
())
.collect
();
let
result
=
processor
.apply_chat_template
(
&
message_values
,
true
)
.unwrap
();
assert
!
(
result
.contains
(
"<|user|>Test<|end|>"
));
assert
!
(
result
.contains
(
"<|assistant|>"
));
}
#[test]
fn
test_chat_template_with_tokens
()
{
// Template that uses special tokens
let
template
=
r#"
{{ bos_token }}
{%- for message in messages -%}
{{ message.role }}: {{ message.content }}{{ eos_token }}
{% endfor -%}
"#
;
let
processor
=
ChatTemplateProcessor
::
new
(
template
.to_string
(),
Some
(
"<s>"
.to_string
()),
Some
(
"</s>"
.to_string
()),
);
let
messages
=
[
spec
::
ChatMessage
::
User
{
role
:
"user"
.to_string
(),
content
:
spec
::
UserMessageContent
::
Text
(
"Test"
.to_string
()),
name
:
None
,
}];
// Convert to JSON values like the router does
let
message_values
:
Vec
<
serde_json
::
Value
>
=
messages
.iter
()
.map
(|
msg
|
serde_json
::
to_value
(
msg
)
.unwrap
())
.collect
();
let
result
=
processor
.apply_chat_template
(
&
message_values
,
false
)
.unwrap
();
assert
!
(
result
.contains
(
"<s>"
));
assert
!
(
result
.contains
(
"</s>"
));
}
#[test]
fn
test_llama_style_template
()
{
// Test a Llama-style chat template
let
template
=
r#"
{%- if messages[0]['role'] == 'system' -%}
{%- set system_message = messages[0]['content'] -%}
{%- set messages = messages[1:] -%}
{%- else -%}
{%- set system_message = '' -%}
{%- endif -%}
{{- bos_token }}
{%- if system_message %}
{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>' }}
{%- endif %}
{%- for message in messages %}
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
{%- endif %}
"#
;
let
processor
=
ChatTemplateProcessor
::
new
(
template
.to_string
(),
Some
(
"<|begin_of_text|>"
.to_string
()),
Some
(
"<|end_of_text|>"
.to_string
()),
);
let
messages
=
vec!
[
spec
::
ChatMessage
::
System
{
role
:
"system"
.to_string
(),
content
:
"You are a helpful assistant"
.to_string
(),
name
:
None
,
},
spec
::
ChatMessage
::
User
{
role
:
"user"
.to_string
(),
content
:
spec
::
UserMessageContent
::
Text
(
"What is 2+2?"
.to_string
()),
name
:
None
,
},
];
// Convert to JSON values
let
json_messages
:
Vec
<
serde_json
::
Value
>
=
messages
.iter
()
.map
(|
msg
|
serde_json
::
to_value
(
msg
)
.unwrap
())
.collect
();
let
result
=
processor
.apply_chat_template
(
&
json_messages
,
true
)
.unwrap
();
// Check that the result contains expected markers
assert
!
(
result
.contains
(
"<|begin_of_text|>"
));
assert
!
(
result
.contains
(
"<|start_header_id|>system<|end_header_id|>"
));
assert
!
(
result
.contains
(
"You are a helpful assistant"
));
assert
!
(
result
.contains
(
"<|start_header_id|>user<|end_header_id|>"
));
assert
!
(
result
.contains
(
"What is 2+2?"
));
assert
!
(
result
.contains
(
"<|start_header_id|>assistant<|end_header_id|>"
));
}
#[test]
fn
test_chatml_template
()
{
// Test a ChatML-style template
let
template
=
r#"
{%- for message in messages %}
{{- '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- endif %}
"#
;
let
processor
=
ChatTemplateProcessor
::
new
(
template
.to_string
(),
None
,
None
);
let
messages
=
vec!
[
spec
::
ChatMessage
::
User
{
role
:
"user"
.to_string
(),
content
:
spec
::
UserMessageContent
::
Text
(
"Hello"
.to_string
()),
name
:
None
,
},
spec
::
ChatMessage
::
Assistant
{
role
:
"assistant"
.to_string
(),
content
:
Some
(
"Hi there!"
.to_string
()),
name
:
None
,
tool_calls
:
None
,
function_call
:
None
,
reasoning_content
:
None
,
},
spec
::
ChatMessage
::
User
{
role
:
"user"
.to_string
(),
content
:
spec
::
UserMessageContent
::
Text
(
"How are you?"
.to_string
()),
name
:
None
,
},
];
// Convert to JSON values
let
json_messages
:
Vec
<
serde_json
::
Value
>
=
messages
.iter
()
.map
(|
msg
|
serde_json
::
to_value
(
msg
)
.unwrap
())
.collect
();
let
result
=
processor
.apply_chat_template
(
&
json_messages
,
true
)
.unwrap
();
// Check ChatML format
assert
!
(
result
.contains
(
"<|im_start|>user
\n
Hello<|im_end|>"
));
assert
!
(
result
.contains
(
"<|im_start|>assistant
\n
Hi there!<|im_end|>"
));
assert
!
(
result
.contains
(
"<|im_start|>user
\n
How are you?<|im_end|>"
));
assert
!
(
result
.ends_with
(
"<|im_start|>assistant
\n
"
));
}
#[test]
fn
test_template_without_generation_prompt
()
{
let
template
=
r#"
{%- for message in messages -%}
{{ message.role }}: {{ message.content }}
{% endfor -%}
{%- if add_generation_prompt -%}
assistant:
{%- endif -%}
"#
;
let
processor
=
ChatTemplateProcessor
::
new
(
template
.to_string
(),
None
,
None
);
let
messages
=
[
spec
::
ChatMessage
::
User
{
role
:
"user"
.to_string
(),
content
:
spec
::
UserMessageContent
::
Text
(
"Test"
.to_string
()),
name
:
None
,
}];
// Convert to JSON values
let
json_messages
:
Vec
<
serde_json
::
Value
>
=
messages
.iter
()
.map
(|
msg
|
serde_json
::
to_value
(
msg
)
.unwrap
())
.collect
();
// Test without generation prompt
let
result
=
processor
.apply_chat_template
(
&
json_messages
,
false
)
.unwrap
();
assert_eq!
(
result
.trim
(),
"user: Test"
);
// Test with generation prompt
let
result_with_prompt
=
processor
.apply_chat_template
(
&
json_messages
,
true
)
.unwrap
();
assert
!
(
result_with_prompt
.contains
(
"assistant:"
));
}
#[test]
fn
test_empty_messages_template
()
{
let
template
=
r#"{% for msg in messages %}{{ msg.role }}: {{ msg.content }}\n{% endfor %}"#
;
let
processor
=
ChatTemplateProcessor
::
new
(
template
.to_string
(),
None
,
None
);
let
messages
:
Vec
<
serde_json
::
Value
>
=
vec!
[];
let
result
=
processor
.apply_chat_template
(
&
messages
,
false
)
.unwrap
();
assert_eq!
(
result
,
""
);
}
#[test]
fn
test_content_format_detection
()
{
// Test string format detection
let
string_template
=
r#"
{%- for message in messages -%}
{{ message.role }}: {{ message.content }}
{%- endfor -%}
"#
;
assert_eq!
(
detect_chat_template_content_format
(
string_template
),
ChatTemplateContentFormat
::
String
);
// Test OpenAI format detection
let
openai_template
=
r#"
{%- for message in messages -%}
{%- for content in message.content -%}
{{ content.type }}: {{ content.text }}
{%- endfor -%}
{%- endfor -%}
"#
;
assert_eq!
(
detect_chat_template_content_format
(
openai_template
),
ChatTemplateContentFormat
::
OpenAI
);
}
#[test]
fn
test_template_with_multimodal_content
()
{
// Test that multimodal messages work correctly when serialized to JSON
let
template
=
r#"
{%- for message in messages %}
{{ message.role }}:
{%- if message.content is string %}
{{ message.content }}
{%- else %}
{%- for part in message.content %}
{%- if part.type == "text" %}
{{ part.text }}
{%- elif part.type == "image_url" %}
[IMAGE]
{%- endif %}
{%- endfor %}
{%- endif %}
{% endfor %}
"#
;
let
processor
=
ChatTemplateProcessor
::
new
(
template
.to_string
(),
None
,
None
);
let
messages
=
[
spec
::
ChatMessage
::
User
{
role
:
"user"
.to_string
(),
content
:
spec
::
UserMessageContent
::
Parts
(
vec!
[
spec
::
ContentPart
::
Text
{
text
:
"Look at this:"
.to_string
(),
},
spec
::
ContentPart
::
ImageUrl
{
image_url
:
spec
::
ImageUrl
{
url
:
"https://example.com/image.jpg"
.to_string
(),
detail
:
None
,
},
},
]),
name
:
None
,
}];
// Convert to JSON values
let
json_messages
:
Vec
<
serde_json
::
Value
>
=
messages
.iter
()
.map
(|
msg
|
serde_json
::
to_value
(
msg
)
.unwrap
())
.collect
();
let
result
=
processor
.apply_chat_template
(
&
json_messages
,
false
)
.unwrap
();
// Should contain both text and image parts
assert
!
(
result
.contains
(
"user:"
));
assert
!
(
result
.contains
(
"Look at this:"
));
assert
!
(
result
.contains
(
"[IMAGE]"
));
}
sgl-router/tests/
test_
chat_template_loading.rs
→
sgl-router/tests/chat_template_loading.rs
View file @
9209b209
#[cfg(test)]
#[cfg(test)]
mod
tests
{
mod
tests
{
use
sglang_router_rs
::
protocols
::
spec
;
use
sglang_router_rs
::
tokenizer
::
huggingface
::
HuggingFaceTokenizer
;
use
std
::
fs
;
use
std
::
fs
;
use
tempfile
::
TempDir
;
use
tempfile
::
TempDir
;
#[test]
#[test]
fn
test_load_chat_template_from_file
()
{
fn
test_load_chat_template_from_file
()
{
use
sglang_router_rs
::
tokenizer
::
chat_template
::
ChatMessage
;
use
sglang_router_rs
::
tokenizer
::
huggingface
::
HuggingFaceTokenizer
;
// Create temporary directory
// Create temporary directory
let
temp_dir
=
TempDir
::
new
()
.unwrap
();
let
temp_dir
=
TempDir
::
new
()
.unwrap
();
let
template_path
=
temp_dir
.path
()
.join
(
"template.jinja"
);
let
template_path
=
temp_dir
.path
()
.join
(
"template.jinja"
);
...
@@ -59,11 +58,28 @@ mod tests {
...
@@ -59,11 +58,28 @@ mod tests {
// Test that the custom template is used
// Test that the custom template is used
let
messages
=
vec!
[
let
messages
=
vec!
[
ChatMessage
::
user
(
"Hello"
),
spec
::
ChatMessage
::
User
{
ChatMessage
::
assistant
(
"Hi there"
),
role
:
"user"
.to_string
(),
content
:
spec
::
UserMessageContent
::
Text
(
"Hello"
.to_string
()),
name
:
None
,
},
spec
::
ChatMessage
::
Assistant
{
role
:
"assistant"
.to_string
(),
content
:
Some
(
"Hi there"
.to_string
()),
name
:
None
,
tool_calls
:
None
,
function_call
:
None
,
reasoning_content
:
None
,
},
];
];
let
result
=
tokenizer
.apply_chat_template
(
&
messages
,
true
)
.unwrap
();
// Convert to JSON values like the router does
let
json_messages
:
Vec
<
serde_json
::
Value
>
=
messages
.iter
()
.map
(|
msg
|
serde_json
::
to_value
(
msg
)
.unwrap
())
.collect
();
let
result
=
tokenizer
.apply_chat_template
(
&
json_messages
,
true
)
.unwrap
();
// Verify the custom template format
// Verify the custom template format
assert
!
(
result
.contains
(
"<|user|>Hello"
));
assert
!
(
result
.contains
(
"<|user|>Hello"
));
...
@@ -73,9 +89,6 @@ mod tests {
...
@@ -73,9 +89,6 @@ mod tests {
#[test]
#[test]
fn
test_override_existing_template
()
{
fn
test_override_existing_template
()
{
use
sglang_router_rs
::
tokenizer
::
chat_template
::
ChatMessage
;
use
sglang_router_rs
::
tokenizer
::
huggingface
::
HuggingFaceTokenizer
;
// Create temporary directory
// Create temporary directory
let
temp_dir
=
TempDir
::
new
()
.unwrap
();
let
temp_dir
=
TempDir
::
new
()
.unwrap
();
...
@@ -124,8 +137,21 @@ mod tests {
...
@@ -124,8 +137,21 @@ mod tests {
)
)
.unwrap
();
.unwrap
();
let
messages
=
vec!
[
ChatMessage
::
user
(
"Test"
)];
let
messages
=
[
spec
::
ChatMessage
::
User
{
let
result
=
tokenizer
.apply_chat_template
(
&
messages
,
false
)
.unwrap
();
role
:
"user"
.to_string
(),
content
:
spec
::
UserMessageContent
::
Text
(
"Test"
.to_string
()),
name
:
None
,
}];
// Convert to JSON values
let
json_messages
:
Vec
<
serde_json
::
Value
>
=
messages
.iter
()
.map
(|
msg
|
serde_json
::
to_value
(
msg
)
.unwrap
())
.collect
();
let
result
=
tokenizer
.apply_chat_template
(
&
json_messages
,
false
)
.unwrap
();
// Should use CUSTOM template, not built-in
// Should use CUSTOM template, not built-in
assert
!
(
result
.starts_with
(
"CUSTOM:"
));
assert
!
(
result
.starts_with
(
"CUSTOM:"
));
...
@@ -135,9 +161,6 @@ mod tests {
...
@@ -135,9 +161,6 @@ mod tests {
#[test]
#[test]
fn
test_set_chat_template_after_creation
()
{
fn
test_set_chat_template_after_creation
()
{
use
sglang_router_rs
::
tokenizer
::
chat_template
::
ChatMessage
;
use
sglang_router_rs
::
tokenizer
::
huggingface
::
HuggingFaceTokenizer
;
// Create temporary directory and tokenizer file
// Create temporary directory and tokenizer file
let
temp_dir
=
TempDir
::
new
()
.unwrap
();
let
temp_dir
=
TempDir
::
new
()
.unwrap
();
let
tokenizer_json
=
r#"{
let
tokenizer_json
=
r#"{
...
@@ -173,8 +196,31 @@ mod tests {
...
@@ -173,8 +196,31 @@ mod tests {
"NEW: {% for msg in messages %}{{ msg.role }}: {{ msg.content }}; {% endfor %}"
;
"NEW: {% for msg in messages %}{{ msg.role }}: {{ msg.content }}; {% endfor %}"
;
tokenizer
.set_chat_template
(
new_template
.to_string
());
tokenizer
.set_chat_template
(
new_template
.to_string
());
let
messages
=
vec!
[
ChatMessage
::
user
(
"Hello"
),
ChatMessage
::
assistant
(
"World"
)];
let
messages
=
vec!
[
let
result
=
tokenizer
.apply_chat_template
(
&
messages
,
false
)
.unwrap
();
spec
::
ChatMessage
::
User
{
role
:
"user"
.to_string
(),
content
:
spec
::
UserMessageContent
::
Text
(
"Hello"
.to_string
()),
name
:
None
,
},
spec
::
ChatMessage
::
Assistant
{
role
:
"assistant"
.to_string
(),
content
:
Some
(
"World"
.to_string
()),
name
:
None
,
tool_calls
:
None
,
function_call
:
None
,
reasoning_content
:
None
,
},
];
// Convert to JSON values
let
json_messages
:
Vec
<
serde_json
::
Value
>
=
messages
.iter
()
.map
(|
msg
|
serde_json
::
to_value
(
msg
)
.unwrap
())
.collect
();
let
result
=
tokenizer
.apply_chat_template
(
&
json_messages
,
false
)
.unwrap
();
assert
!
(
result
.starts_with
(
"NEW:"
));
assert
!
(
result
.starts_with
(
"NEW:"
));
assert
!
(
result
.contains
(
"user: Hello;"
));
assert
!
(
result
.contains
(
"user: Hello;"
));
...
...
sgl-router/tests/test_chat_template.rs
deleted
100644 → 0
View file @
adba172f
#[cfg(test)]
mod
tests
{
use
sglang_router_rs
::
tokenizer
::
chat_template
::{
ChatMessage
,
ChatTemplateProcessor
};
#[test]
fn
test_chat_message_helpers
()
{
let
system_msg
=
ChatMessage
::
system
(
"You are a helpful assistant"
);
assert_eq!
(
system_msg
.role
,
"system"
);
assert_eq!
(
system_msg
.content
,
"You are a helpful assistant"
);
let
user_msg
=
ChatMessage
::
user
(
"Hello!"
);
assert_eq!
(
user_msg
.role
,
"user"
);
assert_eq!
(
user_msg
.content
,
"Hello!"
);
let
assistant_msg
=
ChatMessage
::
assistant
(
"Hi there!"
);
assert_eq!
(
assistant_msg
.role
,
"assistant"
);
assert_eq!
(
assistant_msg
.content
,
"Hi there!"
);
}
#[test]
fn
test_llama_style_template
()
{
// Test a Llama-style chat template
let
template
=
r#"
{%- if messages[0]['role'] == 'system' -%}
{%- set system_message = messages[0]['content'] -%}
{%- set messages = messages[1:] -%}
{%- else -%}
{%- set system_message = '' -%}
{%- endif -%}
{{- bos_token }}
{%- if system_message %}
{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>' }}
{%- endif %}
{%- for message in messages %}
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
{%- endif %}
"#
;
let
processor
=
ChatTemplateProcessor
::
new
(
template
.to_string
(),
Some
(
"<|begin_of_text|>"
.to_string
()),
Some
(
"<|end_of_text|>"
.to_string
()),
);
let
messages
=
vec!
[
ChatMessage
::
system
(
"You are a helpful assistant"
),
ChatMessage
::
user
(
"What is 2+2?"
),
];
let
result
=
processor
.apply_chat_template
(
&
messages
,
true
)
.unwrap
();
// Check that the result contains expected markers
assert
!
(
result
.contains
(
"<|begin_of_text|>"
));
assert
!
(
result
.contains
(
"<|start_header_id|>system<|end_header_id|>"
));
assert
!
(
result
.contains
(
"You are a helpful assistant"
));
assert
!
(
result
.contains
(
"<|start_header_id|>user<|end_header_id|>"
));
assert
!
(
result
.contains
(
"What is 2+2?"
));
assert
!
(
result
.contains
(
"<|start_header_id|>assistant<|end_header_id|>"
));
}
#[test]
fn
test_chatml_template
()
{
// Test a ChatML-style template
let
template
=
r#"
{%- for message in messages %}
{{- '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- endif %}
"#
;
let
processor
=
ChatTemplateProcessor
::
new
(
template
.to_string
(),
None
,
None
);
let
messages
=
vec!
[
ChatMessage
::
user
(
"Hello"
),
ChatMessage
::
assistant
(
"Hi there!"
),
ChatMessage
::
user
(
"How are you?"
),
];
let
result
=
processor
.apply_chat_template
(
&
messages
,
true
)
.unwrap
();
// Check ChatML format
assert
!
(
result
.contains
(
"<|im_start|>user
\n
Hello<|im_end|>"
));
assert
!
(
result
.contains
(
"<|im_start|>assistant
\n
Hi there!<|im_end|>"
));
assert
!
(
result
.contains
(
"<|im_start|>user
\n
How are you?<|im_end|>"
));
assert
!
(
result
.ends_with
(
"<|im_start|>assistant
\n
"
));
}
#[test]
fn
test_template_without_generation_prompt
()
{
let
template
=
r#"
{%- for message in messages -%}
{{ message.role }}: {{ message.content }}
{% endfor -%}
{%- if add_generation_prompt -%}
assistant:
{%- endif -%}
"#
;
let
processor
=
ChatTemplateProcessor
::
new
(
template
.to_string
(),
None
,
None
);
let
messages
=
vec!
[
ChatMessage
::
user
(
"Test"
)];
// Test without generation prompt
let
result
=
processor
.apply_chat_template
(
&
messages
,
false
)
.unwrap
();
assert_eq!
(
result
.trim
(),
"user: Test"
);
// Test with generation prompt
let
result_with_prompt
=
processor
.apply_chat_template
(
&
messages
,
true
)
.unwrap
();
assert
!
(
result_with_prompt
.contains
(
"assistant:"
));
}
#[test]
fn
test_template_with_special_tokens
()
{
let
template
=
r#"{{ bos_token }}{% for msg in messages %}{{ msg.content }}{{ eos_token }}{% endfor %}"#
;
let
processor
=
ChatTemplateProcessor
::
new
(
template
.to_string
(),
Some
(
"<s>"
.to_string
()),
Some
(
"</s>"
.to_string
()),
);
let
messages
=
vec!
[
ChatMessage
::
user
(
"Hello"
)];
let
result
=
processor
.apply_chat_template
(
&
messages
,
false
)
.unwrap
();
assert_eq!
(
result
,
"<s>Hello</s>"
);
}
#[test]
fn
test_empty_messages
()
{
let
template
=
r#"{% for msg in messages %}{{ msg.role }}: {{ msg.content }}\n{% endfor %}"#
;
let
processor
=
ChatTemplateProcessor
::
new
(
template
.to_string
(),
None
,
None
);
let
messages
=
vec!
[];
let
result
=
processor
.apply_chat_template
(
&
messages
,
false
)
.unwrap
();
assert_eq!
(
result
,
""
);
}
// Integration test with actual tokenizer file loading would go here
// but requires a real tokenizer_config.json file
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment