Commit 4f6f63cd, authored by Biswa Panda, committed by GitHub
Browse files

feat: add rust based tokenizer

parent 53163693
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// #[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
// #[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
/// Kind of tool call carried by a tool-call response.
///
/// Variants are serialized in `snake_case`, so `Function` is written as
/// `"function"`.
///
/// `Eq` is derived alongside `PartialEq` because equality on this fieldless
/// enum is total; this lets the type be used wherever a full equivalence
/// relation is required.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolCallType {
    /// The tool call is a function invocation.
    Function,
}
// #[cfg_attr(feature = "pyo3_macros", pyo3::pyclass)]
// #[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
/// A function invocation: the function's name together with its arguments.
///
/// Both fields are plain strings and are serialized and deserialized under
/// their own field names.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct CalledFunction {
    /// Name of the function being called.
    pub name: String,

    /// Arguments for the call, carried as a single string.
    // NOTE(review): presumably a JSON-encoded object; confirm against the producer.
    pub arguments: String,
}
// #[cfg_attr(feature = "pyo3_macros", pyo3::pyclass)]
// #[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
/// A single tool call, serialize-only.
///
/// Serialized keys are `id`, `type` and `function`, in that order.
#[derive(Debug, Clone, serde::Serialize)]
pub struct ToolCallResponse {
    /// Identifier of this tool call.
    pub id: String,

    /// Kind of tool call. Named `tp` because `type` is a Rust keyword; it is
    /// written out under the key `type`.
    #[serde(rename = "type")]
    pub tp: ToolCallType,

    /// The function name and arguments for this call.
    pub function: CalledFunction,
}
...@@ -40,7 +40,7 @@ use super::{ ...@@ -40,7 +40,7 @@ use super::{
validate_logit_bias, ContentProvider, OpenAISamplingOptionsProvider, validate_logit_bias, ContentProvider, OpenAISamplingOptionsProvider,
OpenAIStopConditionsProvider, OpenAIStopConditionsProvider,
}; };
// use crate::AnnotationsProvider; use triton_distributed::protocols::annotated::AnnotationsProvider;
/// Request object which is used to generate chat completions. /// Request object which is used to generate chat completions.
#[derive(Serialize, Deserialize, Builder, Validate, Debug, Clone)] #[derive(Serialize, Deserialize, Builder, Validate, Debug, Clone)]
...@@ -791,21 +791,21 @@ impl NvExtProvider for ChatCompletionRequest { ...@@ -791,21 +791,21 @@ impl NvExtProvider for ChatCompletionRequest {
} }
} }
// impl AnnotationsProvider for ChatCompletionRequest { impl AnnotationsProvider for ChatCompletionRequest {
// fn annotations(&self) -> Option<Vec<String>> { fn annotations(&self) -> Option<Vec<String>> {
// self.nvext self.nvext
// .as_ref() .as_ref()
// .and_then(|nvext| nvext.annotations.clone()) .and_then(|nvext| nvext.annotations.clone())
// } }
// fn has_annotation(&self, annotation: &str) -> bool { fn has_annotation(&self, annotation: &str) -> bool {
// self.nvext self.nvext
// .as_ref() .as_ref()
// .and_then(|nvext| nvext.annotations.as_ref()) .and_then(|nvext| nvext.annotations.as_ref())
// .map(|annotations| annotations.contains(&annotation.to_string())) .map(|annotations| annotations.contains(&annotation.to_string()))
// .unwrap_or(false) .unwrap_or(false)
// } }
// } }
impl OpenAISamplingOptionsProvider for ChatCompletionRequest { impl OpenAISamplingOptionsProvider for ChatCompletionRequest {
fn get_temperature(&self) -> Option<f32> { fn get_temperature(&self) -> Option<f32> {
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
{
"bos_token_id": 1,
"eos_token_id": 2,
"pad_token_id": 0,
"max_length": 2048,
"transformers_version": "4.31.0.dev0"
}
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment