Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
1fc455e8
Unverified
Commit
1fc455e8
authored
Jul 20, 2025
by
Simo Lin
Committed by
GitHub
Jul 20, 2025
Browse files
[router] add ut for pd request, metrics and config (#8184)
parent
465968b2
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
2003 additions
and
72 deletions
+2003
-72
sgl-router/src/config/types.rs
sgl-router/src/config/types.rs
+578
-71
sgl-router/src/metrics.rs
sgl-router/src/metrics.rs
+411
-0
sgl-router/src/routers/pd_types.rs
sgl-router/src/routers/pd_types.rs
+1
-1
sgl-router/src/routers/request_adapter.rs
sgl-router/src/routers/request_adapter.rs
+1013
-0
No files found.
sgl-router/src/config/types.rs
View file @
1fc455e8
...
@@ -214,83 +214,590 @@ impl RouterConfig {
...
@@ -214,83 +214,590 @@ impl RouterConfig {
pub
fn
has_metrics
(
&
self
)
->
bool
{
pub
fn
has_metrics
(
&
self
)
->
bool
{
self
.metrics
.is_some
()
self
.metrics
.is_some
()
}
}
}
/* Commented out - no longer needed without compatibility layer
#[cfg(test)]
/// Convert to routing PolicyConfig for internal use
mod
tests
{
pub fn to_routing_policy_config(&self) -> ConfigResult<crate::router::PolicyConfig> {
use
super
::
*
;
match (&self.mode, &self.policy) {
(
// ============= RouterConfig Tests =============
RoutingMode::PrefillDecode {
prefill_urls,
#[test]
decode_urls,
fn
test_router_config_default
()
{
let
config
=
RouterConfig
::
default
();
assert
!
(
matches!
(
config
.mode
,
RoutingMode
::
Regular
{
worker_urls
}
if
worker_urls
.is_empty
())
);
assert
!
(
matches!
(
config
.policy
,
PolicyConfig
::
Random
));
assert_eq!
(
config
.host
,
"127.0.0.1"
);
assert_eq!
(
config
.port
,
3001
);
assert_eq!
(
config
.max_payload_size
,
268_435_456
);
assert_eq!
(
config
.request_timeout_secs
,
600
);
assert_eq!
(
config
.worker_startup_timeout_secs
,
300
);
assert_eq!
(
config
.worker_startup_check_interval_secs
,
10
);
assert
!
(
config
.discovery
.is_none
());
assert
!
(
config
.metrics
.is_none
());
assert
!
(
config
.log_dir
.is_none
());
assert
!
(
config
.log_level
.is_none
());
}
#[test]
fn
test_router_config_new
()
{
let
mode
=
RoutingMode
::
Regular
{
worker_urls
:
vec!
[
"http://worker1"
.to_string
(),
"http://worker2"
.to_string
()],
};
let
policy
=
PolicyConfig
::
RoundRobin
;
let
config
=
RouterConfig
::
new
(
mode
,
policy
);
match
config
.mode
{
RoutingMode
::
Regular
{
worker_urls
}
=>
{
assert_eq!
(
worker_urls
.len
(),
2
);
assert_eq!
(
worker_urls
[
0
],
"http://worker1"
);
assert_eq!
(
worker_urls
[
1
],
"http://worker2"
);
}
_
=>
panic!
(
"Expected Regular mode"
),
}
assert
!
(
matches!
(
config
.policy
,
PolicyConfig
::
RoundRobin
));
// Other fields should be default
assert_eq!
(
config
.host
,
"127.0.0.1"
);
assert_eq!
(
config
.port
,
3001
);
}
#[test]
fn
test_router_config_serialization
()
{
let
config
=
RouterConfig
{
mode
:
RoutingMode
::
Regular
{
worker_urls
:
vec!
[
"http://worker1"
.to_string
()],
},
},
policy,
policy
:
PolicyConfig
::
Random
,
) => {
host
:
"0.0.0.0"
.to_string
(),
// Map policy to PDSelectionPolicy
port
:
8080
,
let selection_policy = match policy {
max_payload_size
:
1024
,
PolicyConfig::Random => crate::pd_types::PDSelectionPolicy::Random,
request_timeout_secs
:
30
,
PolicyConfig::PowerOfTwo { .. } => {
worker_startup_timeout_secs
:
60
,
crate::pd_types::PDSelectionPolicy::PowerOfTwo
worker_startup_check_interval_secs
:
5
,
}
discovery
:
Some
(
DiscoveryConfig
::
default
()),
PolicyConfig::CacheAware { .. } => {
metrics
:
Some
(
MetricsConfig
::
default
()),
return Err(ConfigError::IncompatibleConfig {
log_dir
:
Some
(
"/var/log"
.to_string
()),
reason: "CacheAware policy is not supported in PD disaggregated mode"
log_level
:
Some
(
"debug"
.to_string
()),
.to_string(),
};
});
}
let
json
=
serde_json
::
to_string
(
&
config
)
.unwrap
();
PolicyConfig::RoundRobin => {
let
deserialized
:
RouterConfig
=
serde_json
::
from_str
(
&
json
)
.unwrap
();
return Err(ConfigError::IncompatibleConfig {
reason: "RoundRobin policy is not supported in PD disaggregated mode"
assert_eq!
(
config
.host
,
deserialized
.host
);
.to_string(),
assert_eq!
(
config
.port
,
deserialized
.port
);
});
assert_eq!
(
config
.max_payload_size
,
deserialized
.max_payload_size
);
assert
!
(
deserialized
.discovery
.is_some
());
assert
!
(
deserialized
.metrics
.is_some
());
}
// ============= RoutingMode Tests =============
#[test]
fn
test_routing_mode_is_pd_mode
()
{
let
regular
=
RoutingMode
::
Regular
{
worker_urls
:
vec!
[
"http://worker1"
.to_string
()],
};
assert
!
(
!
regular
.is_pd_mode
());
let
pd
=
RoutingMode
::
PrefillDecode
{
prefill_urls
:
vec!
[(
"http://prefill1"
.to_string
(),
Some
(
8001
))],
decode_urls
:
vec!
[
"http://decode1"
.to_string
()],
};
assert
!
(
pd
.is_pd_mode
());
}
#[test]
fn
test_routing_mode_worker_count
()
{
let
regular
=
RoutingMode
::
Regular
{
worker_urls
:
vec!
[
"http://worker1"
.to_string
(),
"http://worker2"
.to_string
(),
"http://worker3"
.to_string
(),
],
};
assert_eq!
(
regular
.worker_count
(),
3
);
let
pd
=
RoutingMode
::
PrefillDecode
{
prefill_urls
:
vec!
[
(
"http://prefill1"
.to_string
(),
Some
(
8001
)),
(
"http://prefill2"
.to_string
(),
None
),
],
decode_urls
:
vec!
[
"http://decode1"
.to_string
(),
"http://decode2"
.to_string
(),
"http://decode3"
.to_string
(),
],
};
assert_eq!
(
pd
.worker_count
(),
5
);
let
empty_regular
=
RoutingMode
::
Regular
{
worker_urls
:
vec!
[],
};
assert_eq!
(
empty_regular
.worker_count
(),
0
);
}
#[test]
fn
test_routing_mode_serialization
()
{
// Test Regular mode
let
regular
=
RoutingMode
::
Regular
{
worker_urls
:
vec!
[
"http://worker1"
.to_string
()],
};
let
json
=
serde_json
::
to_string
(
&
regular
)
.unwrap
();
assert
!
(
json
.contains
(
"
\"
type
\"
:
\"
regular
\"
"
));
assert
!
(
json
.contains
(
"
\"
worker_urls
\"
"
));
// Test PrefillDecode mode
let
pd
=
RoutingMode
::
PrefillDecode
{
prefill_urls
:
vec!
[(
"http://prefill1"
.to_string
(),
Some
(
8001
))],
decode_urls
:
vec!
[
"http://decode1"
.to_string
()],
};
let
json
=
serde_json
::
to_string
(
&
pd
)
.unwrap
();
assert
!
(
json
.contains
(
"
\"
type
\"
:
\"
prefill_decode
\"
"
));
assert
!
(
json
.contains
(
"
\"
prefill_urls
\"
"
));
assert
!
(
json
.contains
(
"
\"
decode_urls
\"
"
));
}
// ============= PolicyConfig Tests =============
#[test]
fn
test_policy_config_name
()
{
assert_eq!
(
PolicyConfig
::
Random
.name
(),
"random"
);
assert_eq!
(
PolicyConfig
::
RoundRobin
.name
(),
"round_robin"
);
let
cache_aware
=
PolicyConfig
::
CacheAware
{
cache_threshold
:
0.8
,
balance_abs_threshold
:
10
,
balance_rel_threshold
:
1.5
,
eviction_interval_secs
:
300
,
max_tree_size
:
1000
,
};
assert_eq!
(
cache_aware
.name
(),
"cache_aware"
);
let
power_of_two
=
PolicyConfig
::
PowerOfTwo
{
load_check_interval_secs
:
60
,
};
assert_eq!
(
power_of_two
.name
(),
"power_of_two"
);
}
#[test]
fn
test_policy_config_serialization
()
{
// Test Random
let
random
=
PolicyConfig
::
Random
;
let
json
=
serde_json
::
to_string
(
&
random
)
.unwrap
();
assert_eq!
(
json
,
r#"{"type":"random"}"#
);
// Test CacheAware with all parameters
let
cache_aware
=
PolicyConfig
::
CacheAware
{
cache_threshold
:
0.8
,
balance_abs_threshold
:
10
,
balance_rel_threshold
:
1.5
,
eviction_interval_secs
:
300
,
max_tree_size
:
1000
,
};
let
json
=
serde_json
::
to_string
(
&
cache_aware
)
.unwrap
();
assert
!
(
json
.contains
(
"
\"
type
\"
:
\"
cache_aware
\"
"
));
assert
!
(
json
.contains
(
"
\"
cache_threshold
\"
:0.8"
));
assert
!
(
json
.contains
(
"
\"
balance_abs_threshold
\"
:10"
));
// Test PowerOfTwo
let
power_of_two
=
PolicyConfig
::
PowerOfTwo
{
load_check_interval_secs
:
60
,
};
let
json
=
serde_json
::
to_string
(
&
power_of_two
)
.unwrap
();
assert
!
(
json
.contains
(
"
\"
type
\"
:
\"
power_of_two
\"
"
));
assert
!
(
json
.contains
(
"
\"
load_check_interval_secs
\"
:60"
));
}
}
#[test]
fn
test_cache_aware_parameters
()
{
let
cache_aware
=
PolicyConfig
::
CacheAware
{
cache_threshold
:
0.75
,
balance_abs_threshold
:
20
,
balance_rel_threshold
:
2.0
,
eviction_interval_secs
:
600
,
max_tree_size
:
5000
,
};
};
Ok(crate::router::PolicyConfig::PrefillDecodeConfig {
match
cache_aware
{
selection_policy,
prefill_urls: prefill_urls.clone(),
decode_urls: decode_urls.clone(),
timeout_secs: self.worker_startup_timeout_secs,
interval_secs: self.worker_startup_check_interval_secs,
})
}
(RoutingMode::Regular { .. }, PolicyConfig::Random) => {
Ok(crate::router::PolicyConfig::RandomConfig {
timeout_secs: self.worker_startup_timeout_secs,
interval_secs: self.worker_startup_check_interval_secs,
})
}
(RoutingMode::Regular { .. }, PolicyConfig::RoundRobin) => {
Ok(crate::router::PolicyConfig::RoundRobinConfig {
timeout_secs: self.worker_startup_timeout_secs,
interval_secs: self.worker_startup_check_interval_secs,
})
}
(
RoutingMode::Regular { .. },
PolicyConfig
::
CacheAware
{
PolicyConfig
::
CacheAware
{
cache_threshold
,
cache_threshold
,
balance_abs_threshold
,
balance_abs_threshold
,
balance_rel_threshold
,
balance_rel_threshold
,
eviction_interval_secs
,
eviction_interval_secs
,
max_tree_size
,
max_tree_size
,
}
=>
{
assert
!
((
cache_threshold
-
0.75
)
.abs
()
<
0.0001
);
assert_eq!
(
balance_abs_threshold
,
20
);
assert
!
((
balance_rel_threshold
-
2.0
)
.abs
()
<
0.0001
);
assert_eq!
(
eviction_interval_secs
,
600
);
assert_eq!
(
max_tree_size
,
5000
);
}
_
=>
panic!
(
"Expected CacheAware"
),
}
}
#[test]
fn
test_power_of_two_parameters
()
{
let
power_of_two
=
PolicyConfig
::
PowerOfTwo
{
load_check_interval_secs
:
120
,
};
match
power_of_two
{
PolicyConfig
::
PowerOfTwo
{
load_check_interval_secs
,
}
=>
{
assert_eq!
(
load_check_interval_secs
,
120
);
}
_
=>
panic!
(
"Expected PowerOfTwo"
),
}
}
// ============= DiscoveryConfig Tests =============
#[test]
fn
test_discovery_config_default
()
{
let
config
=
DiscoveryConfig
::
default
();
assert
!
(
!
config
.enabled
);
assert
!
(
config
.namespace
.is_none
());
assert_eq!
(
config
.port
,
8000
);
assert_eq!
(
config
.check_interval_secs
,
60
);
assert
!
(
config
.selector
.is_empty
());
assert
!
(
config
.prefill_selector
.is_empty
());
assert
!
(
config
.decode_selector
.is_empty
());
assert_eq!
(
config
.bootstrap_port_annotation
,
"sglang.ai/bootstrap-port"
);
}
#[test]
fn
test_discovery_config_with_selectors
()
{
let
mut
selector
=
HashMap
::
new
();
selector
.insert
(
"app"
.to_string
(),
"sglang"
.to_string
());
selector
.insert
(
"role"
.to_string
(),
"worker"
.to_string
());
let
config
=
DiscoveryConfig
{
enabled
:
true
,
namespace
:
Some
(
"default"
.to_string
()),
port
:
9000
,
check_interval_secs
:
30
,
selector
:
selector
.clone
(),
prefill_selector
:
selector
.clone
(),
decode_selector
:
selector
.clone
(),
bootstrap_port_annotation
:
"custom.io/port"
.to_string
(),
};
assert
!
(
config
.enabled
);
assert_eq!
(
config
.namespace
,
Some
(
"default"
.to_string
()));
assert_eq!
(
config
.port
,
9000
);
assert_eq!
(
config
.selector
.len
(),
2
);
assert_eq!
(
config
.selector
.get
(
"app"
),
Some
(
&
"sglang"
.to_string
()));
}
#[test]
fn
test_discovery_config_namespace
()
{
// Test None namespace (all namespaces)
let
config
=
DiscoveryConfig
{
namespace
:
None
,
..
Default
::
default
()
};
assert
!
(
config
.namespace
.is_none
());
// Test specific namespace
let
config
=
DiscoveryConfig
{
namespace
:
Some
(
"production"
.to_string
()),
..
Default
::
default
()
};
assert_eq!
(
config
.namespace
,
Some
(
"production"
.to_string
()));
}
// ============= MetricsConfig Tests =============
#[test]
fn
test_metrics_config_default
()
{
let
config
=
MetricsConfig
::
default
();
assert_eq!
(
config
.port
,
29000
);
assert_eq!
(
config
.host
,
"127.0.0.1"
);
}
#[test]
fn
test_metrics_config_custom
()
{
let
config
=
MetricsConfig
{
port
:
9090
,
host
:
"0.0.0.0"
.to_string
(),
};
assert_eq!
(
config
.port
,
9090
);
assert_eq!
(
config
.host
,
"0.0.0.0"
);
}
// ============= RouterConfig Utility Methods Tests =============
#[test]
fn
test_mode_type
()
{
let
config
=
RouterConfig
{
mode
:
RoutingMode
::
Regular
{
worker_urls
:
vec!
[],
},
..
Default
::
default
()
};
assert_eq!
(
config
.mode_type
(),
"regular"
);
let
config
=
RouterConfig
{
mode
:
RoutingMode
::
PrefillDecode
{
prefill_urls
:
vec!
[],
decode_urls
:
vec!
[],
},
},
) => Ok(crate::router::PolicyConfig::CacheAwareConfig {
..
Default
::
default
()
cache_threshold: *cache_threshold,
};
balance_abs_threshold: *balance_abs_threshold,
assert_eq!
(
config
.mode_type
(),
"prefill_decode"
);
balance_rel_threshold: *balance_rel_threshold,
}
eviction_interval_secs: *eviction_interval_secs,
max_tree_size: *max_tree_size,
#[test]
timeout_secs: self.worker_startup_timeout_secs,
fn
test_has_service_discovery
()
{
interval_secs: self.worker_startup_check_interval_secs,
let
config
=
RouterConfig
::
default
();
assert
!
(
!
config
.has_service_discovery
());
let
config
=
RouterConfig
{
discovery
:
Some
(
DiscoveryConfig
{
enabled
:
false
,
..
Default
::
default
()
}),
}),
(RoutingMode::Regular { .. }, PolicyConfig::PowerOfTwo { .. }) => {
..
Default
::
default
()
Err(ConfigError::IncompatibleConfig {
};
reason: "PowerOfTwo policy is only supported in PD disaggregated mode"
assert
!
(
!
config
.has_service_discovery
());
.to_string(),
})
let
config
=
RouterConfig
{
discovery
:
Some
(
DiscoveryConfig
{
enabled
:
true
,
..
Default
::
default
()
}),
..
Default
::
default
()
};
assert
!
(
config
.has_service_discovery
());
}
#[test]
fn
test_has_metrics
()
{
let
config
=
RouterConfig
::
default
();
assert
!
(
!
config
.has_metrics
());
let
config
=
RouterConfig
{
metrics
:
Some
(
MetricsConfig
::
default
()),
..
Default
::
default
()
};
assert
!
(
config
.has_metrics
());
}
// ============= Edge Cases =============
#[test]
fn
test_large_worker_lists
()
{
let
large_urls
:
Vec
<
String
>
=
(
0
..
1000
)
.map
(|
i
|
format!
(
"http://worker{}"
,
i
))
.collect
();
let
mode
=
RoutingMode
::
Regular
{
worker_urls
:
large_urls
.clone
(),
};
assert_eq!
(
mode
.worker_count
(),
1000
);
// Test serialization with large list
let
config
=
RouterConfig
{
mode
,
..
Default
::
default
()
};
let
json
=
serde_json
::
to_string
(
&
config
)
.unwrap
();
let
deserialized
:
RouterConfig
=
serde_json
::
from_str
(
&
json
)
.unwrap
();
match
deserialized
.mode
{
RoutingMode
::
Regular
{
worker_urls
}
=>
{
assert_eq!
(
worker_urls
.len
(),
1000
);
}
_
=>
panic!
(
"Expected Regular mode"
),
}
}
}
}
#[test]
fn
test_unicode_in_config
()
{
let
config
=
RouterConfig
{
mode
:
RoutingMode
::
Regular
{
worker_urls
:
vec!
[
"http://работник1"
.to_string
(),
"http://工作者2"
.to_string
()],
},
log_dir
:
Some
(
"/日志/目录"
.to_string
()),
..
Default
::
default
()
};
let
json
=
serde_json
::
to_string
(
&
config
)
.unwrap
();
let
deserialized
:
RouterConfig
=
serde_json
::
from_str
(
&
json
)
.unwrap
();
match
deserialized
.mode
{
RoutingMode
::
Regular
{
worker_urls
}
=>
{
assert_eq!
(
worker_urls
[
0
],
"http://работник1"
);
assert_eq!
(
worker_urls
[
1
],
"http://工作者2"
);
}
_
=>
panic!
(
"Expected Regular mode"
),
}
assert_eq!
(
deserialized
.log_dir
,
Some
(
"/日志/目录"
.to_string
()));
}
#[test]
fn
test_empty_string_fields
()
{
let
config
=
RouterConfig
{
host
:
""
.to_string
(),
log_dir
:
Some
(
""
.to_string
()),
log_level
:
Some
(
""
.to_string
()),
..
Default
::
default
()
};
assert_eq!
(
config
.host
,
""
);
assert_eq!
(
config
.log_dir
,
Some
(
""
.to_string
()));
assert_eq!
(
config
.log_level
,
Some
(
""
.to_string
()));
}
// ============= Complex Configuration Tests =============
#[test]
fn
test_full_pd_mode_config
()
{
let
config
=
RouterConfig
{
mode
:
RoutingMode
::
PrefillDecode
{
prefill_urls
:
vec!
[
(
"http://prefill1:8000"
.to_string
(),
Some
(
8001
)),
(
"http://prefill2:8000"
.to_string
(),
None
),
],
decode_urls
:
vec!
[
"http://decode1:8000"
.to_string
(),
"http://decode2:8000"
.to_string
(),
],
},
policy
:
PolicyConfig
::
PowerOfTwo
{
load_check_interval_secs
:
30
,
},
host
:
"0.0.0.0"
.to_string
(),
port
:
3000
,
max_payload_size
:
1048576
,
request_timeout_secs
:
120
,
worker_startup_timeout_secs
:
60
,
worker_startup_check_interval_secs
:
5
,
discovery
:
Some
(
DiscoveryConfig
{
enabled
:
true
,
namespace
:
Some
(
"sglang"
.to_string
()),
..
Default
::
default
()
}),
metrics
:
Some
(
MetricsConfig
{
port
:
9090
,
host
:
"0.0.0.0"
.to_string
(),
}),
log_dir
:
Some
(
"/var/log/sglang"
.to_string
()),
log_level
:
Some
(
"info"
.to_string
()),
};
assert
!
(
config
.mode
.is_pd_mode
());
assert_eq!
(
config
.mode
.worker_count
(),
4
);
assert_eq!
(
config
.policy
.name
(),
"power_of_two"
);
assert
!
(
config
.has_service_discovery
());
assert
!
(
config
.has_metrics
());
}
#[test]
fn
test_full_regular_mode_config
()
{
let
mut
selector
=
HashMap
::
new
();
selector
.insert
(
"app"
.to_string
(),
"sglang"
.to_string
());
let
config
=
RouterConfig
{
mode
:
RoutingMode
::
Regular
{
worker_urls
:
vec!
[
"http://worker1:8000"
.to_string
(),
"http://worker2:8000"
.to_string
(),
"http://worker3:8000"
.to_string
(),
],
},
policy
:
PolicyConfig
::
CacheAware
{
cache_threshold
:
0.9
,
balance_abs_threshold
:
5
,
balance_rel_threshold
:
1.2
,
eviction_interval_secs
:
600
,
max_tree_size
:
10000
,
},
host
:
"0.0.0.0"
.to_string
(),
port
:
3001
,
max_payload_size
:
536870912
,
request_timeout_secs
:
300
,
worker_startup_timeout_secs
:
180
,
worker_startup_check_interval_secs
:
15
,
discovery
:
Some
(
DiscoveryConfig
{
enabled
:
true
,
namespace
:
None
,
port
:
8080
,
check_interval_secs
:
45
,
selector
,
..
Default
::
default
()
}),
metrics
:
Some
(
MetricsConfig
::
default
()),
log_dir
:
None
,
log_level
:
Some
(
"debug"
.to_string
()),
};
assert
!
(
!
config
.mode
.is_pd_mode
());
assert_eq!
(
config
.mode
.worker_count
(),
3
);
assert_eq!
(
config
.policy
.name
(),
"cache_aware"
);
assert
!
(
config
.has_service_discovery
());
assert
!
(
config
.has_metrics
());
}
#[test]
fn
test_config_with_all_options
()
{
let
mut
selectors
=
HashMap
::
new
();
selectors
.insert
(
"env"
.to_string
(),
"prod"
.to_string
());
selectors
.insert
(
"version"
.to_string
(),
"v1"
.to_string
());
let
config
=
RouterConfig
{
mode
:
RoutingMode
::
Regular
{
worker_urls
:
vec!
[
"http://worker1"
.to_string
()],
},
policy
:
PolicyConfig
::
RoundRobin
,
host
:
"::1"
.to_string
(),
// IPv6
port
:
8888
,
max_payload_size
:
1024
*
1024
*
512
,
// 512MB
request_timeout_secs
:
900
,
worker_startup_timeout_secs
:
600
,
worker_startup_check_interval_secs
:
20
,
discovery
:
Some
(
DiscoveryConfig
{
enabled
:
true
,
namespace
:
Some
(
"production"
.to_string
()),
port
:
8443
,
check_interval_secs
:
120
,
selector
:
selectors
.clone
(),
prefill_selector
:
selectors
.clone
(),
decode_selector
:
selectors
,
bootstrap_port_annotation
:
"mycompany.io/bootstrap"
.to_string
(),
}),
metrics
:
Some
(
MetricsConfig
{
port
:
9999
,
host
:
"::"
.to_string
(),
// IPv6 any
}),
log_dir
:
Some
(
"/opt/logs/sglang"
.to_string
()),
log_level
:
Some
(
"trace"
.to_string
()),
};
assert
!
(
config
.has_service_discovery
());
assert
!
(
config
.has_metrics
());
assert_eq!
(
config
.mode_type
(),
"regular"
);
// Test round-trip serialization
let
json
=
serde_json
::
to_string_pretty
(
&
config
)
.unwrap
();
let
deserialized
:
RouterConfig
=
serde_json
::
from_str
(
&
json
)
.unwrap
();
assert_eq!
(
deserialized
.host
,
"::1"
);
assert_eq!
(
deserialized
.port
,
8888
);
assert_eq!
(
deserialized
.discovery
.unwrap
()
.namespace
,
Some
(
"production"
.to_string
())
);
}
}
*/
}
}
sgl-router/src/metrics.rs
View file @
1fc455e8
...
@@ -322,3 +322,414 @@ impl RouterMetrics {
...
@@ -322,3 +322,414 @@ impl RouterMetrics {
.set
(
count
as
f64
);
.set
(
count
as
f64
);
}
}
}
}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
use
std
::
net
::
TcpListener
;
// ============= PrometheusConfig Tests =============
#[test]
fn
test_prometheus_config_default
()
{
let
config
=
PrometheusConfig
::
default
();
assert_eq!
(
config
.port
,
29000
);
assert_eq!
(
config
.host
,
"0.0.0.0"
);
}
#[test]
fn
test_prometheus_config_custom
()
{
let
config
=
PrometheusConfig
{
port
:
8080
,
host
:
"127.0.0.1"
.to_string
(),
};
assert_eq!
(
config
.port
,
8080
);
assert_eq!
(
config
.host
,
"127.0.0.1"
);
}
#[test]
fn
test_prometheus_config_clone
()
{
let
config
=
PrometheusConfig
{
port
:
9090
,
host
:
"192.168.1.1"
.to_string
(),
};
let
cloned
=
config
.clone
();
assert_eq!
(
cloned
.port
,
config
.port
);
assert_eq!
(
cloned
.host
,
config
.host
);
}
// ============= IP Address Parsing Tests =============
#[test]
fn
test_valid_ipv4_parsing
()
{
let
test_cases
=
vec!
[
"127.0.0.1"
,
"192.168.1.1"
,
"0.0.0.0"
];
for
ip_str
in
test_cases
{
let
config
=
PrometheusConfig
{
port
:
29000
,
host
:
ip_str
.to_string
(),
};
let
ip_addr
:
IpAddr
=
config
.host
.parse
()
.unwrap
();
assert
!
(
matches!
(
ip_addr
,
IpAddr
::
V4
(
_
)));
}
}
#[test]
fn
test_valid_ipv6_parsing
()
{
let
test_cases
=
vec!
[
"::1"
,
"2001:db8::1"
,
"::"
];
for
ip_str
in
test_cases
{
let
config
=
PrometheusConfig
{
port
:
29000
,
host
:
ip_str
.to_string
(),
};
let
ip_addr
:
IpAddr
=
config
.host
.parse
()
.unwrap
();
assert
!
(
matches!
(
ip_addr
,
IpAddr
::
V6
(
_
)));
}
}
#[test]
fn
test_invalid_ip_parsing
()
{
let
test_cases
=
vec!
[
"invalid"
,
"256.256.256.256"
,
"hostname"
];
for
ip_str
in
test_cases
{
let
config
=
PrometheusConfig
{
port
:
29000
,
host
:
ip_str
.to_string
(),
};
let
ip_addr
:
IpAddr
=
config
.host
.parse
()
.unwrap_or
(
IpAddr
::
V4
(
Ipv4Addr
::
new
(
0
,
0
,
0
,
0
)));
// Should fall back to 0.0.0.0
assert_eq!
(
ip_addr
,
IpAddr
::
V4
(
Ipv4Addr
::
new
(
0
,
0
,
0
,
0
)));
}
}
// ============= Socket Address Creation Tests =============
#[test]
fn
test_socket_addr_creation
()
{
let
test_cases
=
vec!
[(
"127.0.0.1"
,
8080
),
(
"0.0.0.0"
,
29000
),
(
"::1"
,
9090
)];
for
(
host
,
port
)
in
test_cases
{
let
config
=
PrometheusConfig
{
port
,
host
:
host
.to_string
(),
};
let
ip_addr
:
IpAddr
=
config
.host
.parse
()
.unwrap
();
let
socket_addr
=
SocketAddr
::
new
(
ip_addr
,
config
.port
);
assert_eq!
(
socket_addr
.port
(),
port
);
assert_eq!
(
socket_addr
.ip
()
.to_string
(),
host
);
}
}
#[test]
fn
test_socket_addr_with_different_ports
()
{
let
ports
=
vec!
[
0
,
80
,
8080
,
65535
];
for
port
in
ports
{
let
config
=
PrometheusConfig
{
port
,
host
:
"127.0.0.1"
.to_string
(),
};
let
ip_addr
:
IpAddr
=
config
.host
.parse
()
.unwrap
();
let
socket_addr
=
SocketAddr
::
new
(
ip_addr
,
config
.port
);
assert_eq!
(
socket_addr
.port
(),
port
);
}
}
// ============= Duration Bucket Tests =============
#[test]
fn
test_duration_bucket_values
()
{
let
expected_buckets
=
vec!
[
0.001
,
0.005
,
0.01
,
0.025
,
0.05
,
0.1
,
0.25
,
0.5
,
1.0
,
2.5
,
5.0
,
10.0
,
15.0
,
30.0
,
45.0
,
60.0
,
90.0
,
120.0
,
180.0
,
240.0
,
];
// The buckets are defined in start_prometheus function
assert_eq!
(
expected_buckets
.len
(),
20
);
// Verify proper ordering
for
i
in
1
..
expected_buckets
.len
()
{
assert
!
(
expected_buckets
[
i
]
>
expected_buckets
[
i
-
1
]);
}
}
#[test]
fn
test_duration_bucket_coverage
()
{
let
test_cases
=
vec!
[
(
0.0005
,
"sub-millisecond"
),
(
0.005
,
"5ms"
),
(
0.05
,
"50ms"
),
(
1.0
,
"1s"
),
(
10.0
,
"10s"
),
(
60.0
,
"1m"
),
(
240.0
,
"4m"
),
];
let
buckets
=
vec!
[
0.001
,
0.005
,
0.01
,
0.025
,
0.05
,
0.1
,
0.25
,
0.5
,
1.0
,
2.5
,
5.0
,
10.0
,
15.0
,
30.0
,
45.0
,
60.0
,
90.0
,
120.0
,
180.0
,
240.0
,
];
for
(
duration
,
label
)
in
test_cases
{
let
bucket_found
=
buckets
.iter
()
.any
(|
&
b
|
((
b
-
duration
)
as
f64
)
.abs
()
<
0.0001
||
b
>
duration
);
assert
!
(
bucket_found
,
"No bucket found for {} ({})"
,
duration
,
label
);
}
}
// ============= Matcher Configuration Tests =============
#[test]
fn
test_duration_suffix_matcher
()
{
let
matcher
=
Matcher
::
Suffix
(
String
::
from
(
"duration_seconds"
));
// Test matching behavior
let
_
matching_metrics
=
vec!
[
"request_duration_seconds"
,
"response_duration_seconds"
,
"sgl_router_request_duration_seconds"
,
];
let
_
non_matching_metrics
=
vec!
[
"duration_total"
,
"duration_seconds_total"
,
"other_metric"
];
// Note: We can't directly test Matcher matching without the internals,
// but we can verify the matcher is created correctly
match
matcher
{
Matcher
::
Suffix
(
suffix
)
=>
assert_eq!
(
suffix
,
"duration_seconds"
),
_
=>
panic!
(
"Expected Suffix matcher"
),
}
}
// ============= Builder Configuration Tests =============
#[test]
fn
test_prometheus_builder_configuration
()
{
// This test verifies the builder configuration without actually starting Prometheus
let
_
config
=
PrometheusConfig
::
default
();
let
duration_matcher
=
Matcher
::
Suffix
(
String
::
from
(
"duration_seconds"
));
let
duration_bucket
=
[
0.001
,
0.005
,
0.01
,
0.025
,
0.05
,
0.1
,
0.25
,
0.5
,
1.0
,
2.5
,
5.0
,
10.0
,
15.0
,
30.0
,
45.0
,
60.0
,
90.0
,
120.0
,
180.0
,
240.0
,
];
// Verify bucket configuration
assert_eq!
(
duration_bucket
.len
(),
20
);
// Verify matcher is suffix type
match
duration_matcher
{
Matcher
::
Suffix
(
s
)
=>
assert_eq!
(
s
,
"duration_seconds"
),
_
=>
panic!
(
"Expected Suffix matcher"
),
}
}
// ============= Upkeep Timeout Tests =============
#[test]
fn
test_upkeep_timeout_duration
()
{
let
timeout
=
Duration
::
from_secs
(
5
*
60
);
assert_eq!
(
timeout
.as_secs
(),
300
);
}
// ============= Custom Bucket Tests =============
#[test]
fn
test_custom_buckets_for_different_metrics
()
{
// Test that we can create different bucket configurations
let
request_buckets
=
vec!
[
0.001
,
0.01
,
0.1
,
1.0
,
10.0
];
let
generate_buckets
=
vec!
[
0.1
,
0.5
,
1.0
,
5.0
,
30.0
,
60.0
];
assert_eq!
(
request_buckets
.len
(),
5
);
assert_eq!
(
generate_buckets
.len
(),
6
);
// Verify each set is sorted
for
i
in
1
..
request_buckets
.len
()
{
assert
!
(
request_buckets
[
i
]
>
request_buckets
[
i
-
1
]);
}
for
i
in
1
..
generate_buckets
.len
()
{
assert
!
(
generate_buckets
[
i
]
>
generate_buckets
[
i
-
1
]);
}
}
// ============= RouterMetrics Tests =============
#[test]
fn
test_metrics_static_methods
()
{
// Test that all static methods can be called without panic
RouterMetrics
::
record_request
(
"/generate"
);
RouterMetrics
::
record_request_duration
(
"/generate"
,
Duration
::
from_millis
(
100
));
RouterMetrics
::
record_request_error
(
"/generate"
,
"timeout"
);
RouterMetrics
::
record_retry
(
"/generate"
);
RouterMetrics
::
set_active_workers
(
5
);
RouterMetrics
::
set_worker_health
(
"http://worker1"
,
true
);
RouterMetrics
::
set_worker_load
(
"http://worker1"
,
10
);
RouterMetrics
::
record_processed_request
(
"http://worker1"
);
RouterMetrics
::
record_policy_decision
(
"random"
,
"http://worker1"
);
RouterMetrics
::
record_cache_hit
();
RouterMetrics
::
record_cache_miss
();
RouterMetrics
::
set_tree_size
(
"http://worker1"
,
1000
);
RouterMetrics
::
record_load_balancing_event
();
RouterMetrics
::
set_load_range
(
20
,
5
);
RouterMetrics
::
record_pd_request
(
"/v1/chat/completions"
);
RouterMetrics
::
record_pd_request_duration
(
"/v1/chat/completions"
,
Duration
::
from_secs
(
1
));
RouterMetrics
::
record_pd_prefill_request
(
"http://prefill1"
);
RouterMetrics
::
record_pd_decode_request
(
"http://decode1"
);
RouterMetrics
::
record_pd_error
(
"invalid_request"
);
RouterMetrics
::
record_pd_prefill_error
(
"http://prefill1"
);
RouterMetrics
::
record_pd_decode_error
(
"http://decode1"
);
RouterMetrics
::
record_pd_stream_error
(
"http://decode1"
);
RouterMetrics
::
record_discovery_update
(
3
,
1
);
RouterMetrics
::
record_generate_duration
(
Duration
::
from_secs
(
2
));
RouterMetrics
::
set_running_requests
(
"http://worker1"
,
15
);
}
// ============= Port Availability Tests =============
#[test]
fn
test_port_already_in_use
()
{
// Skip this test if we can't bind to the port
let
port
=
29123
;
// Use a different port to avoid conflicts
if
let
Ok
(
_
listener
)
=
TcpListener
::
bind
((
"127.0.0.1"
,
port
))
{
// Port is available, we can test
let
config
=
PrometheusConfig
{
port
,
host
:
"127.0.0.1"
.to_string
(),
};
// Just verify config is created correctly
assert_eq!
(
config
.port
,
port
);
}
}
// ============= Integration Test Helpers =============
#[test]
fn
test_metrics_endpoint_accessibility
()
{
// This would be an integration test in practice
// Here we just verify the configuration
let
config
=
PrometheusConfig
{
port
:
29000
,
host
:
"127.0.0.1"
.to_string
(),
};
let
ip_addr
:
IpAddr
=
config
.host
.parse
()
.unwrap
();
let
socket_addr
=
SocketAddr
::
new
(
ip_addr
,
config
.port
);
assert_eq!
(
socket_addr
.to_string
(),
"127.0.0.1:29000"
);
}
#[test]
fn
test_concurrent_metric_updates
()
{
// Test that metric updates can be called concurrently
use
std
::
sync
::
atomic
::{
AtomicBool
,
Ordering
};
use
std
::
sync
::
Arc
;
use
std
::
thread
;
let
done
=
Arc
::
new
(
AtomicBool
::
new
(
false
));
let
mut
handles
=
vec!
[];
for
i
in
0
..
3
{
let
done_clone
=
done
.clone
();
let
handle
=
thread
::
spawn
(
move
||
{
let
worker
=
format!
(
"http://worker{}"
,
i
);
while
!
done_clone
.load
(
Ordering
::
Relaxed
)
{
RouterMetrics
::
set_worker_load
(
&
worker
,
i
*
10
);
RouterMetrics
::
record_processed_request
(
&
worker
);
thread
::
sleep
(
Duration
::
from_millis
(
1
));
}
});
handles
.push
(
handle
);
}
// Let threads run briefly
thread
::
sleep
(
Duration
::
from_millis
(
10
));
done
.store
(
true
,
Ordering
::
Relaxed
);
// Wait for all threads
for
handle
in
handles
{
handle
.join
()
.unwrap
();
}
// If we get here without panic, concurrent access works
assert
!
(
true
);
}
// ============= Edge Cases Tests =============
#[test]
fn
test_empty_string_metrics
()
{
// Test that empty strings don't cause issues
RouterMetrics
::
record_request
(
""
);
RouterMetrics
::
set_worker_health
(
""
,
true
);
RouterMetrics
::
record_policy_decision
(
""
,
""
);
// If we get here without panic, empty strings are handled
assert
!
(
true
);
}
#[test]
fn
test_very_long_metric_labels
()
{
let
long_label
=
"a"
.repeat
(
1000
);
RouterMetrics
::
record_request
(
&
long_label
);
RouterMetrics
::
set_worker_health
(
&
long_label
,
false
);
// If we get here without panic, long labels are handled
assert
!
(
true
);
}
#[test]
fn
test_special_characters_in_labels
()
{
let
special_labels
=
vec!
[
"test/with/slashes"
,
"test-with-dashes"
,
"test_with_underscores"
,
"test.with.dots"
,
"test:with:colons"
,
];
for
label
in
special_labels
{
RouterMetrics
::
record_request
(
label
);
RouterMetrics
::
set_worker_health
(
label
,
true
);
}
// If we get here without panic, special characters are handled
assert
!
(
true
);
}
#[test]
fn
test_extreme_metric_values
()
{
// Test extreme values
RouterMetrics
::
set_active_workers
(
0
);
RouterMetrics
::
set_active_workers
(
usize
::
MAX
);
RouterMetrics
::
set_worker_load
(
"worker"
,
0
);
RouterMetrics
::
set_worker_load
(
"worker"
,
usize
::
MAX
);
RouterMetrics
::
record_request_duration
(
"route"
,
Duration
::
from_nanos
(
1
));
RouterMetrics
::
record_request_duration
(
"route"
,
Duration
::
from_secs
(
86400
));
// 24 hours
// If we get here without panic, extreme values are handled
assert
!
(
true
);
}
}
sgl-router/src/routers/pd_types.rs
View file @
1fc455e8
...
@@ -58,7 +58,7 @@ pub enum PDSelectionPolicy {
...
@@ -58,7 +58,7 @@ pub enum PDSelectionPolicy {
},
},
}
}
// Bootstrap types from PDLB
// Bootstrap types from PDLB
#[derive(Debug,
Deserialize,
Serialize)]
#[derive(Debug,
Deserialize,
Serialize
,
PartialEq
)]
#[serde(untagged)]
#[serde(untagged)]
pub
enum
SingleOrBatch
<
T
>
{
pub
enum
SingleOrBatch
<
T
>
{
Single
(
T
),
Single
(
T
),
...
...
sgl-router/src/routers/request_adapter.rs
View file @
1fc455e8
...
@@ -211,6 +211,7 @@ impl ToPdRequest for ChatCompletionRequest {
...
@@ -211,6 +211,7 @@ impl ToPdRequest for ChatCompletionRequest {
self
.temperature
=>
"temperature"
,
self
.temperature
=>
"temperature"
,
self
.top_p
=>
"top_p"
,
self
.top_p
=>
"top_p"
,
self
.n
=>
"n"
,
self
.n
=>
"n"
,
self
.stream_options
=>
"stream_options"
,
self
.stop
=>
"stop"
,
self
.stop
=>
"stop"
,
self
.max_tokens
=>
"max_tokens"
,
self
.max_tokens
=>
"max_tokens"
,
self
.max_completion_tokens
=>
"max_completion_tokens"
,
self
.max_completion_tokens
=>
"max_completion_tokens"
,
...
@@ -262,3 +263,1015 @@ pub trait RouteableRequest: GenerationRequest + serde::Serialize + Clone {
...
@@ -262,3 +263,1015 @@ pub trait RouteableRequest: GenerationRequest + serde::Serialize + Clone {
impl
RouteableRequest
for
GenerateRequest
{}
impl
RouteableRequest
for
GenerateRequest
{}
impl
RouteableRequest
for
CompletionRequest
{}
impl
RouteableRequest
for
CompletionRequest
{}
impl
RouteableRequest
for
ChatCompletionRequest
{}
impl
RouteableRequest
for
ChatCompletionRequest
{}
#[cfg(test)]
mod
tests
{
use
super
::
*
;
use
crate
::
openai_api_types
::
*
;
use
serde_json
::
json
;
use
std
::
collections
::
HashMap
;
// ============= GenerateRequest to_pd_request Tests =============
#[test]
fn
test_generate_to_pd_request_with_text_only
()
{
let
req
=
GenerateRequest
{
text
:
Some
(
"Hello world"
.to_string
()),
prompt
:
None
,
input_ids
:
None
,
stream
:
false
,
parameters
:
None
,
sampling_params
:
None
,
return_logprob
:
false
,
};
let
pd_req
=
req
.to_pd_request
();
// Check text field conversion
assert
!
(
matches!
(
pd_req
.text
,
Some
(
SingleOrBatch
::
Single
(
ref
s
))
if
s
==
"Hello world"
));
assert
!
(
pd_req
.input_ids
.is_none
());
// Check bootstrap fields are None
assert
!
(
pd_req
.bootstrap_host
.is_none
());
assert
!
(
pd_req
.bootstrap_port
.is_none
());
assert
!
(
pd_req
.bootstrap_room
.is_none
());
// Check stream flag
assert_eq!
(
pd_req
.stream
,
false
);
// Check other fields
let
other
=
pd_req
.other
.as_object
()
.unwrap
();
assert_eq!
(
other
.get
(
"stream"
),
Some
(
&
json!
(
false
)));
assert_eq!
(
other
.get
(
"return_logprob"
),
Some
(
&
json!
(
false
)));
}
#[test]
fn
test_generate_to_pd_request_with_prompt_string
()
{
let
req
=
GenerateRequest
{
text
:
None
,
prompt
:
Some
(
StringOrArray
::
String
(
"Test prompt"
.to_string
())),
input_ids
:
None
,
stream
:
true
,
parameters
:
None
,
sampling_params
:
None
,
return_logprob
:
true
,
};
let
pd_req
=
req
.to_pd_request
();
assert
!
(
matches!
(
pd_req
.text
,
Some
(
SingleOrBatch
::
Single
(
ref
s
))
if
s
==
"Test prompt"
));
assert
!
(
pd_req
.input_ids
.is_none
());
assert_eq!
(
pd_req
.stream
,
true
);
let
other
=
pd_req
.other
.as_object
()
.unwrap
();
assert_eq!
(
other
.get
(
"stream"
),
Some
(
&
json!
(
true
)));
assert_eq!
(
other
.get
(
"return_logprob"
),
Some
(
&
json!
(
true
)));
}
#[test]
fn
test_generate_to_pd_request_with_prompt_array
()
{
let
req
=
GenerateRequest
{
text
:
None
,
prompt
:
Some
(
StringOrArray
::
Array
(
vec!
[
"Prompt 1"
.to_string
(),
"Prompt 2"
.to_string
(),
"Prompt 3"
.to_string
(),
])),
input_ids
:
None
,
stream
:
false
,
parameters
:
None
,
sampling_params
:
None
,
return_logprob
:
false
,
};
let
pd_req
=
req
.to_pd_request
();
match
pd_req
.text
{
Some
(
SingleOrBatch
::
Batch
(
ref
batch
))
=>
{
assert_eq!
(
batch
.len
(),
3
);
assert_eq!
(
batch
[
0
],
"Prompt 1"
);
assert_eq!
(
batch
[
1
],
"Prompt 2"
);
assert_eq!
(
batch
[
2
],
"Prompt 3"
);
}
_
=>
panic!
(
"Expected batch text"
),
}
}
#[test]
fn
test_generate_to_pd_request_with_single_input_ids
()
{
let
req
=
GenerateRequest
{
text
:
None
,
prompt
:
None
,
input_ids
:
Some
(
InputIds
::
Single
(
vec!
[
100
,
200
,
300
,
400
])),
stream
:
false
,
parameters
:
None
,
sampling_params
:
None
,
return_logprob
:
false
,
};
let
pd_req
=
req
.to_pd_request
();
assert
!
(
pd_req
.text
.is_none
());
assert
!
(
matches!
(
pd_req
.input_ids
,
Some
(
SingleOrBatch
::
Single
(
ref
ids
))
if
ids
==
&
vec!
[
100
,
200
,
300
,
400
]
));
}
#[test]
fn
test_generate_to_pd_request_with_batch_input_ids
()
{
let
req
=
GenerateRequest
{
text
:
None
,
prompt
:
None
,
input_ids
:
Some
(
InputIds
::
Batch
(
vec!
[
vec!
[
1
,
2
,
3
],
vec!
[
4
,
5
,
6
,
7
],
vec!
[
8
,
9
],
])),
stream
:
false
,
parameters
:
None
,
sampling_params
:
None
,
return_logprob
:
false
,
};
let
pd_req
=
req
.to_pd_request
();
match
pd_req
.input_ids
{
Some
(
SingleOrBatch
::
Batch
(
ref
batch
))
=>
{
assert_eq!
(
batch
.len
(),
3
);
assert_eq!
(
batch
[
0
],
vec!
[
1
,
2
,
3
]);
assert_eq!
(
batch
[
1
],
vec!
[
4
,
5
,
6
,
7
]);
assert_eq!
(
batch
[
2
],
vec!
[
8
,
9
]);
}
_
=>
panic!
(
"Expected batch input_ids"
),
}
}
#[test]
fn
test_generate_to_pd_request_priority_text_over_prompt
()
{
let
req
=
GenerateRequest
{
text
:
Some
(
"SGLang text"
.to_string
()),
prompt
:
Some
(
StringOrArray
::
String
(
"OpenAI prompt"
.to_string
())),
input_ids
:
Some
(
InputIds
::
Single
(
vec!
[
1
,
2
,
3
])),
stream
:
false
,
parameters
:
None
,
sampling_params
:
None
,
return_logprob
:
false
,
};
let
pd_req
=
req
.to_pd_request
();
// text should take priority
assert
!
(
matches!
(
pd_req
.text
,
Some
(
SingleOrBatch
::
Single
(
ref
s
))
if
s
==
"SGLang text"
));
assert
!
(
pd_req
.input_ids
.is_none
());
}
#[test]
fn
test_generate_to_pd_request_priority_prompt_over_input_ids
()
{
let
req
=
GenerateRequest
{
text
:
None
,
prompt
:
Some
(
StringOrArray
::
String
(
"OpenAI prompt"
.to_string
())),
input_ids
:
Some
(
InputIds
::
Single
(
vec!
[
1
,
2
,
3
])),
stream
:
false
,
parameters
:
None
,
sampling_params
:
None
,
return_logprob
:
false
,
};
let
pd_req
=
req
.to_pd_request
();
// prompt should take priority over input_ids
assert
!
(
matches!
(
pd_req
.text
,
Some
(
SingleOrBatch
::
Single
(
ref
s
))
if
s
==
"OpenAI prompt"
));
assert
!
(
pd_req
.input_ids
.is_none
());
}
#[test]
fn
test_generate_to_pd_request_with_parameters
()
{
let
params
=
GenerateParameters
{
max_new_tokens
:
Some
(
100
),
temperature
:
Some
(
0.8
),
top_p
:
Some
(
0.95
),
seed
:
Some
(
12345
),
stop
:
Some
(
vec!
[
"END"
.to_string
(),
"STOP"
.to_string
()]),
repetition_penalty
:
Some
(
1.1
),
..
Default
::
default
()
};
let
req
=
GenerateRequest
{
text
:
Some
(
"test"
.to_string
()),
prompt
:
None
,
input_ids
:
None
,
stream
:
false
,
parameters
:
Some
(
params
),
sampling_params
:
None
,
return_logprob
:
false
,
};
let
pd_req
=
req
.to_pd_request
();
let
other
=
pd_req
.other
.as_object
()
.unwrap
();
// Check that max_new_tokens and temperature were extracted to top level
assert_eq!
(
other
.get
(
"max_new_tokens"
),
Some
(
&
json!
(
100
)));
assert
!
(
other
.get
(
"temperature"
)
.unwrap
()
.as_f64
()
.unwrap
()
-
0.8
<
0.0001
);
// Check that other parameters remain under "parameters"
let
params
=
other
.get
(
"parameters"
)
.unwrap
()
.as_object
()
.unwrap
();
assert
!
(
params
.get
(
"top_p"
)
.unwrap
()
.as_f64
()
.unwrap
()
-
0.95
<
0.0001
);
assert_eq!
(
params
.get
(
"seed"
),
Some
(
&
json!
(
12345
)));
assert_eq!
(
params
.get
(
"stop"
),
Some
(
&
json!
(
vec!
[
"END"
,
"STOP"
])));
assert
!
(
params
.get
(
"repetition_penalty"
)
.unwrap
()
.as_f64
()
.unwrap
()
-
1.1
<
0.0001
);
}
#[test]
fn
test_generate_to_pd_request_with_sampling_params
()
{
let
sampling
=
SamplingParams
{
max_new_tokens
:
Some
(
200
),
temperature
:
Some
(
0.7
),
top_p
:
Some
(
0.9
),
top_k
:
Some
(
50
),
frequency_penalty
:
Some
(
0.1
),
presence_penalty
:
Some
(
0.2
),
repetition_penalty
:
Some
(
1.05
),
..
Default
::
default
()
};
let
req
=
GenerateRequest
{
text
:
Some
(
"test"
.to_string
()),
prompt
:
None
,
input_ids
:
None
,
stream
:
false
,
parameters
:
None
,
sampling_params
:
Some
(
sampling
),
return_logprob
:
false
,
};
let
pd_req
=
req
.to_pd_request
();
let
other
=
pd_req
.other
.as_object
()
.unwrap
();
// Check extracted top-level fields
assert_eq!
(
other
.get
(
"max_new_tokens"
),
Some
(
&
json!
(
200
)));
assert
!
(
other
.get
(
"temperature"
)
.unwrap
()
.as_f64
()
.unwrap
()
-
0.7
<
0.0001
);
// Check full sampling_params is preserved
let
sampling
=
other
.get
(
"sampling_params"
)
.unwrap
()
.as_object
()
.unwrap
();
assert_eq!
(
sampling
.get
(
"max_new_tokens"
),
Some
(
&
json!
(
200
)));
assert
!
(
sampling
.get
(
"temperature"
)
.unwrap
()
.as_f64
()
.unwrap
()
-
0.7
<
0.0001
);
assert
!
(
sampling
.get
(
"top_p"
)
.unwrap
()
.as_f64
()
.unwrap
()
-
0.9
<
0.0001
);
assert_eq!
(
sampling
.get
(
"top_k"
),
Some
(
&
json!
(
50
)));
assert
!
(
sampling
.get
(
"frequency_penalty"
)
.unwrap
()
.as_f64
()
.unwrap
()
-
0.1
<
0.0001
);
assert
!
(
sampling
.get
(
"presence_penalty"
)
.unwrap
()
.as_f64
()
.unwrap
()
-
0.2
<
0.0001
);
}
#[test]
fn
test_generate_to_pd_request_sampling_params_override_parameters
()
{
// When both parameters and sampling_params have max_new_tokens/temperature,
// sampling_params should take precedence (processed last)
let
params
=
GenerateParameters
{
max_new_tokens
:
Some
(
100
),
temperature
:
Some
(
0.5
),
..
Default
::
default
()
};
let
sampling
=
SamplingParams
{
max_new_tokens
:
Some
(
200
),
temperature
:
Some
(
0.9
),
..
Default
::
default
()
};
let
req
=
GenerateRequest
{
text
:
Some
(
"test"
.to_string
()),
prompt
:
None
,
input_ids
:
None
,
stream
:
false
,
parameters
:
Some
(
params
),
sampling_params
:
Some
(
sampling
),
return_logprob
:
false
,
};
let
pd_req
=
req
.to_pd_request
();
let
other
=
pd_req
.other
.as_object
()
.unwrap
();
// Should use values from sampling_params since they're processed last
assert_eq!
(
other
.get
(
"max_new_tokens"
),
Some
(
&
json!
(
200
)));
assert
!
(
other
.get
(
"temperature"
)
.unwrap
()
.as_f64
()
.unwrap
()
-
0.9
<
0.0001
);
}
#[test]
fn
test_generate_to_pd_request_empty_parameters
()
{
let
params
=
GenerateParameters
::
default
();
let
req
=
GenerateRequest
{
text
:
Some
(
"test"
.to_string
()),
prompt
:
None
,
input_ids
:
None
,
stream
:
false
,
parameters
:
Some
(
params
),
sampling_params
:
None
,
return_logprob
:
false
,
};
let
pd_req
=
req
.to_pd_request
();
let
other
=
pd_req
.other
.as_object
()
.unwrap
();
// Should not have parameters field if all values are None/default
assert
!
(
!
other
.contains_key
(
"parameters"
));
assert
!
(
!
other
.contains_key
(
"max_new_tokens"
));
assert
!
(
!
other
.contains_key
(
"temperature"
));
}
#[test]
fn
test_generate_to_pd_request_all_fields
()
{
let
params
=
GenerateParameters
{
max_new_tokens
:
Some
(
150
),
temperature
:
Some
(
0.6
),
top_k
:
Some
(
40
),
..
Default
::
default
()
};
let
sampling
=
SamplingParams
{
max_new_tokens
:
Some
(
250
),
// Will override parameters
temperature
:
Some
(
0.8
),
// Will override parameters
presence_penalty
:
Some
(
0.1
),
..
Default
::
default
()
};
let
req
=
GenerateRequest
{
text
:
Some
(
"Complex test"
.to_string
()),
prompt
:
Some
(
StringOrArray
::
String
(
"Ignored prompt"
.to_string
())),
input_ids
:
None
,
stream
:
true
,
parameters
:
Some
(
params
),
sampling_params
:
Some
(
sampling
),
return_logprob
:
true
,
};
let
pd_req
=
req
.to_pd_request
();
// Verify all fields
assert
!
(
matches!
(
pd_req
.text
,
Some
(
SingleOrBatch
::
Single
(
ref
s
))
if
s
==
"Complex test"
));
assert
!
(
pd_req
.input_ids
.is_none
());
assert_eq!
(
pd_req
.stream
,
true
);
assert
!
(
pd_req
.bootstrap_host
.is_none
());
assert
!
(
pd_req
.bootstrap_port
.is_none
());
assert
!
(
pd_req
.bootstrap_room
.is_none
());
let
other
=
pd_req
.other
.as_object
()
.unwrap
();
assert_eq!
(
other
.get
(
"stream"
),
Some
(
&
json!
(
true
)));
assert_eq!
(
other
.get
(
"return_logprob"
),
Some
(
&
json!
(
true
)));
// Sampling params override parameters
assert_eq!
(
other
.get
(
"max_new_tokens"
),
Some
(
&
json!
(
250
)));
assert
!
(
other
.get
(
"temperature"
)
.unwrap
()
.as_f64
()
.unwrap
()
-
0.8
<
0.0001
);
assert
!
(
other
.contains_key
(
"parameters"
));
assert
!
(
other
.contains_key
(
"sampling_params"
));
}
// ============= CompletionRequest to_pd_request Tests =============
#[test]
fn
test_completion_to_pd_request_basic
()
{
let
req
=
CompletionRequest
{
model
:
"gpt-3.5-turbo"
.to_string
(),
prompt
:
StringOrArray
::
String
(
"Complete this sentence"
.to_string
()),
max_tokens
:
None
,
temperature
:
None
,
top_p
:
None
,
n
:
None
,
stream
:
false
,
stream_options
:
None
,
logprobs
:
None
,
echo
:
false
,
stop
:
None
,
presence_penalty
:
None
,
frequency_penalty
:
None
,
best_of
:
None
,
logit_bias
:
None
,
user
:
None
,
seed
:
None
,
suffix
:
None
,
};
let
pd_req
=
req
.to_pd_request
();
assert
!
(
matches!
(
pd_req
.text
,
Some
(
SingleOrBatch
::
Single
(
ref
s
))
if
s
==
"Complete this sentence"
)
);
assert
!
(
pd_req
.input_ids
.is_none
());
assert_eq!
(
pd_req
.stream
,
false
);
let
other
=
pd_req
.other
.as_object
()
.unwrap
();
assert_eq!
(
other
.get
(
"model"
),
Some
(
&
json!
(
"gpt-3.5-turbo"
)));
assert_eq!
(
other
.get
(
"stream"
),
Some
(
&
json!
(
false
)));
}
#[test]
fn
test_completion_to_pd_request_array_prompt
()
{
let
req
=
CompletionRequest
{
model
:
"test"
.to_string
(),
prompt
:
StringOrArray
::
Array
(
vec!
[
"First prompt"
.to_string
(),
"Second prompt"
.to_string
(),
]),
max_tokens
:
None
,
temperature
:
None
,
top_p
:
None
,
n
:
None
,
stream
:
false
,
stream_options
:
None
,
logprobs
:
None
,
echo
:
false
,
stop
:
None
,
presence_penalty
:
None
,
frequency_penalty
:
None
,
best_of
:
None
,
logit_bias
:
None
,
user
:
None
,
seed
:
None
,
suffix
:
None
,
};
let
pd_req
=
req
.to_pd_request
();
match
pd_req
.text
{
Some
(
SingleOrBatch
::
Batch
(
ref
batch
))
=>
{
assert_eq!
(
batch
.len
(),
2
);
assert_eq!
(
batch
[
0
],
"First prompt"
);
assert_eq!
(
batch
[
1
],
"Second prompt"
);
}
_
=>
panic!
(
"Expected batch text"
),
}
}
#[test]
fn
test_completion_to_pd_request_parameter_mapping
()
{
let
req
=
CompletionRequest
{
model
:
"test"
.to_string
(),
prompt
:
StringOrArray
::
String
(
"test"
.to_string
()),
max_tokens
:
Some
(
150
),
// -> max_new_tokens
temperature
:
Some
(
0.75
),
top_p
:
Some
(
0.92
),
n
:
Some
(
3
),
// -> best_of
stream
:
true
,
stream_options
:
None
,
logprobs
:
Some
(
10
),
// -> top_n_tokens
echo
:
true
,
// -> return_full_text
stop
:
Some
(
StringOrArray
::
Array
(
vec!
[
"
\\
n"
.to_string
(),
"END"
.to_string
(),
])),
presence_penalty
:
Some
(
0.5
),
// -> repetition_penalty = 1.5
frequency_penalty
:
Some
(
0.2
),
best_of
:
Some
(
5
),
logit_bias
:
None
,
user
:
Some
(
"user123"
.to_string
()),
seed
:
Some
(
42
),
suffix
:
Some
(
"..."
.to_string
()),
};
let
pd_req
=
req
.to_pd_request
();
let
other
=
pd_req
.other
.as_object
()
.unwrap
();
let
params
=
other
.get
(
"parameters"
)
.unwrap
()
.as_object
()
.unwrap
();
// Check parameter mappings
assert_eq!
(
params
.get
(
"max_new_tokens"
),
Some
(
&
json!
(
150
)));
assert
!
(
params
.get
(
"temperature"
)
.unwrap
()
.as_f64
()
.unwrap
()
-
0.75
<
0.0001
);
assert
!
(
params
.get
(
"top_p"
)
.unwrap
()
.as_f64
()
.unwrap
()
-
0.92
<
0.0001
);
assert_eq!
(
params
.get
(
"best_of"
),
Some
(
&
json!
(
3
)));
assert_eq!
(
params
.get
(
"top_n_tokens"
),
Some
(
&
json!
(
10
)));
assert_eq!
(
params
.get
(
"return_full_text"
),
Some
(
&
json!
(
true
)));
assert_eq!
(
params
.get
(
"stop"
),
Some
(
&
json!
(
vec!
[
"
\\
n"
,
"END"
])));
assert
!
(
params
.get
(
"repetition_penalty"
)
.unwrap
()
.as_f64
()
.unwrap
()
-
1.5
<
0.0001
);
assert_eq!
(
params
.get
(
"seed"
),
Some
(
&
json!
(
42
)));
// Check other fields
assert_eq!
(
other
.get
(
"model"
),
Some
(
&
json!
(
"test"
)));
assert_eq!
(
other
.get
(
"stream"
),
Some
(
&
json!
(
true
)));
}
#[test]
fn
test_completion_to_pd_request_stop_string
()
{
let
req
=
CompletionRequest
{
model
:
"test"
.to_string
(),
prompt
:
StringOrArray
::
String
(
"test"
.to_string
()),
stop
:
Some
(
StringOrArray
::
String
(
"STOP"
.to_string
())),
max_tokens
:
None
,
temperature
:
None
,
top_p
:
None
,
n
:
None
,
stream
:
false
,
stream_options
:
None
,
logprobs
:
None
,
echo
:
false
,
presence_penalty
:
None
,
frequency_penalty
:
None
,
best_of
:
None
,
logit_bias
:
None
,
user
:
None
,
seed
:
None
,
suffix
:
None
,
};
let
pd_req
=
req
.to_pd_request
();
let
other
=
pd_req
.other
.as_object
()
.unwrap
();
let
params
=
other
.get
(
"parameters"
)
.unwrap
()
.as_object
()
.unwrap
();
// Single string stop should be converted to array
assert_eq!
(
params
.get
(
"stop"
),
Some
(
&
json!
(
vec!
[
"STOP"
])));
}
#[test]
fn
test_completion_to_pd_request_no_presence_penalty
()
{
let
req
=
CompletionRequest
{
model
:
"test"
.to_string
(),
prompt
:
StringOrArray
::
String
(
"test"
.to_string
()),
presence_penalty
:
None
,
max_tokens
:
None
,
temperature
:
None
,
top_p
:
None
,
n
:
None
,
stream
:
false
,
stream_options
:
None
,
logprobs
:
None
,
echo
:
false
,
stop
:
None
,
frequency_penalty
:
None
,
best_of
:
None
,
logit_bias
:
None
,
user
:
None
,
seed
:
None
,
suffix
:
None
,
};
let
pd_req
=
req
.to_pd_request
();
let
other
=
pd_req
.other
.as_object
()
.unwrap
();
let
params
=
other
.get
(
"parameters"
)
.unwrap
()
.as_object
()
.unwrap
();
// Should not have repetition_penalty if presence_penalty is None
assert
!
(
!
params
.contains_key
(
"repetition_penalty"
));
}
// ============= ChatCompletionRequest to_pd_request Tests =============
#[test]
fn
test_chat_to_pd_request_basic
()
{
let
messages
=
vec!
[
ChatMessage
::
System
{
role
:
"system"
.to_string
(),
content
:
"You are a helpful assistant"
.to_string
(),
name
:
None
,
},
ChatMessage
::
User
{
role
:
"user"
.to_string
(),
content
:
UserMessageContent
::
Text
(
"Hello!"
.to_string
()),
name
:
None
,
},
];
let
req
=
ChatCompletionRequest
{
messages
,
model
:
"gpt-4"
.to_string
(),
temperature
:
None
,
top_p
:
None
,
n
:
None
,
stream
:
false
,
stream_options
:
None
,
stop
:
None
,
max_tokens
:
None
,
max_completion_tokens
:
None
,
presence_penalty
:
None
,
frequency_penalty
:
None
,
logit_bias
:
None
,
logprobs
:
false
,
top_logprobs
:
None
,
user
:
None
,
seed
:
None
,
response_format
:
None
,
tools
:
None
,
tool_choice
:
None
,
parallel_tool_calls
:
None
,
functions
:
None
,
function_call
:
None
,
};
let
pd_req
=
req
.to_pd_request
();
assert_eq!
(
pd_req
.stream
,
false
);
assert
!
(
pd_req
.bootstrap_host
.is_none
());
assert
!
(
pd_req
.bootstrap_port
.is_none
());
assert
!
(
pd_req
.bootstrap_room
.is_none
());
let
other
=
pd_req
.other
.as_object
()
.unwrap
();
assert
!
(
other
.contains_key
(
"messages"
));
assert_eq!
(
other
.get
(
"model"
),
Some
(
&
json!
(
"gpt-4"
)));
assert_eq!
(
other
.get
(
"stream"
),
Some
(
&
json!
(
false
)));
// Check messages are preserved
let
messages
=
other
.get
(
"messages"
)
.unwrap
()
.as_array
()
.unwrap
();
assert_eq!
(
messages
.len
(),
2
);
}
#[test]
fn
test_chat_to_pd_request_with_all_optional_fields
()
{
let
messages
=
vec!
[
ChatMessage
::
User
{
role
:
"user"
.to_string
(),
content
:
UserMessageContent
::
Text
(
"Test"
.to_string
()),
name
:
Some
(
"test_user"
.to_string
()),
}];
let
mut
logit_bias
=
HashMap
::
new
();
logit_bias
.insert
(
"50256"
.to_string
(),
-
100
);
let
tool
=
Tool
{
tool_type
:
"function"
.to_string
(),
function
:
Function
{
name
:
"get_weather"
.to_string
(),
description
:
Some
(
"Get weather info"
.to_string
()),
parameters
:
json!
({
"type"
:
"object"
}),
},
};
let
req
=
ChatCompletionRequest
{
messages
,
model
:
"gpt-4"
.to_string
(),
temperature
:
Some
(
0.8
),
top_p
:
Some
(
0.95
),
n
:
Some
(
2
),
stream
:
true
,
stream_options
:
Some
(
StreamOptions
{
include_usage
:
Some
(
true
),
}),
stop
:
Some
(
StringOrArray
::
String
(
"
\\
n
\\
n"
.to_string
())),
max_tokens
:
Some
(
200
),
max_completion_tokens
:
Some
(
150
),
presence_penalty
:
Some
(
0.1
),
frequency_penalty
:
Some
(
0.2
),
logit_bias
:
Some
(
logit_bias
),
logprobs
:
true
,
top_logprobs
:
Some
(
5
),
user
:
Some
(
"user456"
.to_string
()),
seed
:
Some
(
12345
),
response_format
:
Some
(
ResponseFormat
::
JsonObject
),
tools
:
Some
(
vec!
[
tool
]),
tool_choice
:
Some
(
ToolChoice
::
Auto
),
parallel_tool_calls
:
Some
(
false
),
functions
:
None
,
function_call
:
None
,
};
let
pd_req
=
req
.to_pd_request
();
let
other
=
pd_req
.other
.as_object
()
.unwrap
();
// Check all fields are preserved
assert
!
(
other
.get
(
"temperature"
)
.unwrap
()
.as_f64
()
.unwrap
()
-
0.8
<
0.0001
);
assert
!
(
other
.get
(
"top_p"
)
.unwrap
()
.as_f64
()
.unwrap
()
-
0.95
<
0.0001
);
assert_eq!
(
other
.get
(
"n"
),
Some
(
&
json!
(
2
)));
assert_eq!
(
other
.get
(
"stream"
),
Some
(
&
json!
(
true
)));
assert
!
(
other
.contains_key
(
"stream_options"
));
assert
!
(
other
.contains_key
(
"stop"
));
assert_eq!
(
other
.get
(
"max_tokens"
),
Some
(
&
json!
(
200
)));
assert_eq!
(
other
.get
(
"max_completion_tokens"
),
Some
(
&
json!
(
150
)));
assert
!
(
other
.get
(
"presence_penalty"
)
.unwrap
()
.as_f64
()
.unwrap
()
-
0.1
<
0.0001
);
assert
!
(
other
.get
(
"frequency_penalty"
)
.unwrap
()
.as_f64
()
.unwrap
()
-
0.2
<
0.0001
);
assert
!
(
other
.contains_key
(
"logit_bias"
));
assert_eq!
(
other
.get
(
"logprobs"
),
Some
(
&
json!
(
true
)));
assert_eq!
(
other
.get
(
"top_logprobs"
),
Some
(
&
json!
(
5
)));
assert_eq!
(
other
.get
(
"user"
),
Some
(
&
json!
(
"user456"
)));
assert_eq!
(
other
.get
(
"seed"
),
Some
(
&
json!
(
12345
)));
assert
!
(
other
.contains_key
(
"response_format"
));
assert
!
(
other
.contains_key
(
"tools"
));
assert
!
(
other
.contains_key
(
"tool_choice"
));
assert_eq!
(
other
.get
(
"parallel_tool_calls"
),
Some
(
&
json!
(
false
)));
}
#[test]
fn
test_chat_to_pd_request_multimodal_content
()
{
let
messages
=
vec!
[
ChatMessage
::
User
{
role
:
"user"
.to_string
(),
content
:
UserMessageContent
::
Parts
(
vec!
[
ContentPart
::
Text
{
text
:
"What's in this image?"
.to_string
(),
},
ContentPart
::
ImageUrl
{
image_url
:
ImageUrl
{
url
:
"https://example.com/image.jpg"
.to_string
(),
detail
:
Some
(
"high"
.to_string
()),
},
},
]),
name
:
None
,
}];
let
req
=
ChatCompletionRequest
{
messages
,
model
:
"gpt-4-vision"
.to_string
(),
temperature
:
None
,
top_p
:
None
,
n
:
None
,
stream
:
false
,
stream_options
:
None
,
stop
:
None
,
max_tokens
:
None
,
max_completion_tokens
:
None
,
presence_penalty
:
None
,
frequency_penalty
:
None
,
logit_bias
:
None
,
logprobs
:
false
,
top_logprobs
:
None
,
user
:
None
,
seed
:
None
,
response_format
:
None
,
tools
:
None
,
tool_choice
:
None
,
parallel_tool_calls
:
None
,
functions
:
None
,
function_call
:
None
,
};
let
pd_req
=
req
.to_pd_request
();
let
other
=
pd_req
.other
.as_object
()
.unwrap
();
// Messages with multimodal content should be preserved
assert
!
(
other
.contains_key
(
"messages"
));
let
messages
=
other
.get
(
"messages"
)
.unwrap
()
.as_array
()
.unwrap
();
assert_eq!
(
messages
.len
(),
1
);
// Verify the message structure is preserved
let
msg
=
&
messages
[
0
];
assert_eq!
(
msg
[
"role"
],
"user"
);
assert
!
(
msg
[
"content"
]
.is_array
());
}
#[test]
fn
test_chat_to_pd_request_logprobs_boolean
()
{
let
messages
=
vec!
[
ChatMessage
::
User
{
role
:
"user"
.to_string
(),
content
:
UserMessageContent
::
Text
(
"Test"
.to_string
()),
name
:
None
,
}];
let
req
=
ChatCompletionRequest
{
messages
,
model
:
"test"
.to_string
(),
logprobs
:
true
,
// Boolean logprobs flag
top_logprobs
:
Some
(
3
),
temperature
:
None
,
top_p
:
None
,
n
:
None
,
stream
:
false
,
stream_options
:
None
,
stop
:
None
,
max_tokens
:
None
,
max_completion_tokens
:
None
,
presence_penalty
:
None
,
frequency_penalty
:
None
,
logit_bias
:
None
,
user
:
None
,
seed
:
None
,
response_format
:
None
,
tools
:
None
,
tool_choice
:
None
,
parallel_tool_calls
:
None
,
functions
:
None
,
function_call
:
None
,
};
let
pd_req
=
req
.to_pd_request
();
let
other
=
pd_req
.other
.as_object
()
.unwrap
();
assert_eq!
(
other
.get
(
"logprobs"
),
Some
(
&
json!
(
true
)));
assert_eq!
(
other
.get
(
"top_logprobs"
),
Some
(
&
json!
(
3
)));
}
#[test]
fn
test_chat_to_pd_request_minimal_fields
()
{
let
messages
=
vec!
[
ChatMessage
::
Assistant
{
role
:
"assistant"
.to_string
(),
content
:
Some
(
"I can help with that."
.to_string
()),
name
:
None
,
tool_calls
:
None
,
function_call
:
None
,
}];
let
req
=
ChatCompletionRequest
{
messages
,
model
:
"gpt-3.5-turbo"
.to_string
(),
temperature
:
None
,
top_p
:
None
,
n
:
None
,
stream
:
false
,
stream_options
:
None
,
stop
:
None
,
max_tokens
:
None
,
max_completion_tokens
:
None
,
presence_penalty
:
None
,
frequency_penalty
:
None
,
logit_bias
:
None
,
logprobs
:
false
,
top_logprobs
:
None
,
user
:
None
,
seed
:
None
,
response_format
:
None
,
tools
:
None
,
tool_choice
:
None
,
parallel_tool_calls
:
None
,
functions
:
None
,
function_call
:
None
,
};
let
pd_req
=
req
.to_pd_request
();
let
other
=
pd_req
.other
.as_object
()
.unwrap
();
// Should only have required fields
assert
!
(
other
.contains_key
(
"messages"
));
assert
!
(
other
.contains_key
(
"model"
));
assert
!
(
other
.contains_key
(
"stream"
));
// Optional fields should not be present
assert
!
(
!
other
.contains_key
(
"temperature"
));
assert
!
(
!
other
.contains_key
(
"top_p"
));
assert
!
(
!
other
.contains_key
(
"max_tokens"
));
assert
!
(
!
other
.contains_key
(
"stop"
));
}
#[test]
fn
test_routeable_request_to_json
()
{
let
req
=
GenerateRequest
{
text
:
Some
(
"test"
.to_string
()),
prompt
:
None
,
input_ids
:
None
,
stream
:
false
,
parameters
:
None
,
sampling_params
:
None
,
return_logprob
:
false
,
};
let
json
=
req
.to_json
()
.unwrap
();
assert_eq!
(
json
[
"text"
],
"test"
);
assert_eq!
(
json
[
"stream"
],
false
);
}
// ============= Macro Tests =============
#[test]
fn
test_insert_if_some_macro
()
{
let
mut
map
=
serde_json
::
Map
::
new
();
let
some_value
:
Option
<
i32
>
=
Some
(
42
);
let
none_value
:
Option
<
i32
>
=
None
;
insert_if_some!
(
map
,
some_value
=>
"present"
,
none_value
=>
"absent"
);
assert_eq!
(
map
.get
(
"present"
),
Some
(
&
json!
(
42
)));
assert
!
(
!
map
.contains_key
(
"absent"
));
}
#[test]
fn
test_insert_value_macro
()
{
let
mut
map
=
serde_json
::
Map
::
new
();
let
value1
=
"test"
;
let
value2
=
42
;
insert_value!
(
map
,
value1
=>
"string_field"
,
value2
=>
"int_field"
);
assert_eq!
(
map
.get
(
"string_field"
),
Some
(
&
json!
(
"test"
)));
assert_eq!
(
map
.get
(
"int_field"
),
Some
(
&
json!
(
42
)));
}
// ============= Edge Cases and Error Handling =============
#[test]
fn
test_null_value_handling
()
{
let
params
=
GenerateParameters
{
max_new_tokens
:
None
,
temperature
:
None
,
..
Default
::
default
()
};
let
req
=
GenerateRequest
{
text
:
Some
(
"test"
.to_string
()),
prompt
:
None
,
input_ids
:
None
,
stream
:
false
,
parameters
:
Some
(
params
),
sampling_params
:
None
,
return_logprob
:
false
,
};
let
pd_req
=
req
.to_pd_request
();
let
other
=
pd_req
.other
.as_object
()
.unwrap
();
// Should not have parameters field if all fields are None
assert
!
(
!
other
.contains_key
(
"parameters"
));
}
#[test]
fn
test_large_batch_conversion
()
{
let
large_batch
:
Vec
<
String
>
=
(
0
..
1000
)
.map
(|
i
|
format!
(
"item_{}"
,
i
))
.collect
();
let
req
=
GenerateRequest
{
text
:
None
,
prompt
:
Some
(
StringOrArray
::
Array
(
large_batch
.clone
())),
input_ids
:
None
,
stream
:
false
,
parameters
:
None
,
sampling_params
:
None
,
return_logprob
:
false
,
};
let
pd_req
=
req
.to_pd_request
();
if
let
Some
(
SingleOrBatch
::
Batch
(
batch
))
=
pd_req
.text
{
assert_eq!
(
batch
.len
(),
1000
);
assert_eq!
(
batch
[
0
],
"item_0"
);
assert_eq!
(
batch
[
999
],
"item_999"
);
}
else
{
panic!
(
"Expected batch text"
);
}
}
#[test]
fn
test_unicode_string_handling
()
{
let
unicode_text
=
"Hello 世界 🌍 नमस्ते мир"
.to_string
();
let
req
=
GenerateRequest
{
text
:
Some
(
unicode_text
.clone
()),
prompt
:
None
,
input_ids
:
None
,
stream
:
false
,
parameters
:
None
,
sampling_params
:
None
,
return_logprob
:
false
,
};
let
pd_req
=
req
.to_pd_request
();
if
let
Some
(
SingleOrBatch
::
Single
(
text
))
=
pd_req
.text
{
assert_eq!
(
text
,
unicode_text
);
}
else
{
panic!
(
"Expected single text"
);
}
}
#[test]
fn
test_deeply_nested_parameters
()
{
let
mut
nested_params
=
serde_json
::
Map
::
new
();
nested_params
.insert
(
"nested"
.to_string
(),
json!
({
"level1"
:
{
"level2"
:
{
"level3"
:
"value"
}
}
}),
);
let
params
=
GenerateParameters
{
max_new_tokens
:
Some
(
100
),
..
Default
::
default
()
};
let
req
=
GenerateRequest
{
text
:
Some
(
"test"
.to_string
()),
prompt
:
None
,
input_ids
:
None
,
stream
:
false
,
parameters
:
Some
(
params
),
sampling_params
:
None
,
return_logprob
:
false
,
};
let
pd_req
=
req
.to_pd_request
();
let
other
=
pd_req
.other
.as_object
()
.unwrap
();
// Parameters should be preserved even with nested structures
assert
!
(
other
.contains_key
(
"max_new_tokens"
));
}
// ============= Bootstrap Field Tests =============
#[test]
fn
test_bootstrap_fields_none
()
{
let
req
=
GenerateRequest
{
text
:
Some
(
"test"
.to_string
()),
prompt
:
None
,
input_ids
:
None
,
stream
:
false
,
parameters
:
None
,
sampling_params
:
None
,
return_logprob
:
false
,
};
let
pd_req
=
req
.to_pd_request
();
assert_eq!
(
pd_req
.bootstrap_host
,
None
);
assert_eq!
(
pd_req
.bootstrap_port
,
None
);
assert_eq!
(
pd_req
.bootstrap_room
,
None
);
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment