"vscode:/vscode.git/clone" did not exist on "cea0e4b9ead83adec919bb21c62360491317d351"
Unverified Commit c1930022 authored by Aidan Cooper's avatar Aidan Cooper Committed by GitHub
Browse files

Add support for VertexAI safety settings (#624)

parent fe3be159
import os
import warnings
from typing import List, Optional, Union
import numpy as np
from typing import Optional
from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
......@@ -16,12 +14,15 @@ try:
GenerativeModel,
Image,
)
from vertexai.generative_models._generative_models import SafetySettingsType
except ImportError as e:
GenerativeModel = e
class VertexAI(BaseBackend):
def __init__(self, model_name):
def __init__(
self, model_name, safety_settings: Optional[SafetySettingsType] = None
):
super().__init__()
if isinstance(GenerativeModel, Exception):
......@@ -33,6 +34,7 @@ class VertexAI(BaseBackend):
self.model_name = model_name
self.chat_template = get_chat_template("default")
self.safety_settings = safety_settings
def get_chat_template(self):
return self.chat_template
......@@ -54,6 +56,7 @@ class VertexAI(BaseBackend):
ret = GenerativeModel(self.model_name).generate_content(
prompt,
generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
safety_settings=self.safety_settings,
)
comp = ret.text
......@@ -78,6 +81,7 @@ class VertexAI(BaseBackend):
prompt,
stream=True,
generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
safety_settings=self.safety_settings,
)
for ret in generator:
yield ret.text, {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment