Unverified commit c1930022, authored by Aidan Cooper, committed by GitHub

Add support for VertexAI safety settings (#624)

parent fe3be159
```diff
 import os
 import warnings
-from typing import List, Optional, Union
-import numpy as np
+from typing import Optional
 from sglang.backend.base_backend import BaseBackend
 from sglang.lang.chat_template import get_chat_template
@@ -16,12 +14,15 @@ try:
         GenerativeModel,
         Image,
     )
+    from vertexai.generative_models._generative_models import SafetySettingsType
 except ImportError as e:
     GenerativeModel = e


 class VertexAI(BaseBackend):
-    def __init__(self, model_name):
+    def __init__(
+        self, model_name, safety_settings: Optional[SafetySettingsType] = None
+    ):
         super().__init__()

         if isinstance(GenerativeModel, Exception):
@@ -33,6 +34,7 @@ class VertexAI(BaseBackend):

         self.model_name = model_name
         self.chat_template = get_chat_template("default")
+        self.safety_settings = safety_settings

     def get_chat_template(self):
         return self.chat_template
@@ -54,6 +56,7 @@ class VertexAI(BaseBackend):
         ret = GenerativeModel(self.model_name).generate_content(
             prompt,
             generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
+            safety_settings=self.safety_settings,
         )

         comp = ret.text
@@ -78,6 +81,7 @@ class VertexAI(BaseBackend):
             prompt,
             stream=True,
             generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
+            safety_settings=self.safety_settings,
         )
         for ret in generator:
             yield ret.text, {}
```
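For reference, a minimal usage sketch of the new parameter (not part of this commit): it assumes the `gemini-pro` model name and the `HarmCategory`/`HarmBlockThreshold` enums exposed by the `vertexai` SDK; exact enum members and import paths may differ across SDK versions.

```python
# Minimal sketch: constructing the sglang VertexAI backend with custom safety
# settings. The model name and the HarmCategory/HarmBlockThreshold members below
# are illustrative assumptions and may need adjusting for your vertexai SDK version.
import sglang as sgl
from vertexai.generative_models import HarmBlockThreshold, HarmCategory

# SafetySettingsType accepts a mapping of harm category -> block threshold.
safety_settings = {
    HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
    HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
}

# The new keyword argument is stored on the backend and forwarded to every
# generate_content() call, for both streaming and non-streaming generation.
backend = sgl.VertexAI("gemini-pro", safety_settings=safety_settings)
sgl.set_default_backend(backend)
```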