Unverified Commit c86f58d3 authored by OlivierDehaene, committed by GitHub

feat: add support for Gemma (#1583)

parent fa8a8e05
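
This change wires the Gemma architecture into the Flash Attention serving path: a new `FlashGemma` model class, a `gemma` branch in `get_model`, and skipped-by-default integration tests with JSON snapshots. For a quick sanity check once a server is up, the new model type can be queried with the `text_generation` Python client; a minimal sketch, assuming a server is already serving the `gg-hf/gemma-2b` checkpoint used by the tests at `http://127.0.0.1:8080`:

from text_generation import Client

# Hypothetical local endpoint; adjust to wherever the launcher exposes the server.
client = Client("http://127.0.0.1:8080")

response = client.generate(
    "Test request",              # same prompt as the integration tests below
    max_new_tokens=10,
    decoder_input_details=True,  # also return prefill token logprobs
)
print(response.details.generated_tokens)  # expected: 10
print(response.generated_text)
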
......@@ -40,6 +40,9 @@ class ResponseComparator(JSONSnapshotExtension):
        exclude=None,
        matcher=None,
    ):
        if isinstance(data, Response):
            data = data.dict()

        if isinstance(data, List):
            data = [d.dict() for d in data]
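
The added branches normalize both response shapes the tests below produce before the snapshot comparison runs: a single Response from generate() and a list of Response objects from generate_load(). A minimal standalone sketch of the same normalization, assuming the pydantic Response model from text_generation.types:

from typing import List, Union
from text_generation.types import Response

def to_snapshot_payload(data: Union[Response, List[Response]]):
    # Mirrors the serialize() override: pydantic models become plain dicts
    # so they can be compared against the stored JSON snapshots.
    if isinstance(data, Response):
        return data.dict()
    if isinstance(data, list):
        return [d.dict() for d in data]
    return data
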
......
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2,
"logprob": null,
"text": "<bos>"
},
{
"id": 2015,
"logprob": -10.0,
"text": "Test"
},
{
"id": 3853,
"logprob": -10.875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 1736,
"logprob": -2.09375,
"special": false,
"text": " form"
},
{
"id": 109,
"logprob": -1.8671875,
"special": false,
"text": "\n\n"
},
{
"id": 651,
"logprob": -2.4375,
"special": false,
"text": "The"
},
{
"id": 2121,
"logprob": -1.8203125,
"special": false,
"text": " test"
},
{
"id": 3853,
"logprob": -0.23242188,
"special": false,
"text": " request"
},
{
"id": 1736,
"logprob": -0.08544922,
"special": false,
"text": " form"
},
{
"id": 603,
"logprob": -0.9375,
"special": false,
"text": " is"
},
{
"id": 1671,
"logprob": -1.671875,
"special": false,
"text": " used"
},
{
"id": 577,
"logprob": -0.40429688,
"special": false,
"text": " to"
},
{
"id": 3853,
"logprob": -1.1875,
"special": false,
"text": " request"
}
],
"top_tokens": null
},
"generated_text": " form\n\nThe test request form is used to request"
}
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2,
"logprob": null,
"text": "<bos>"
},
{
"id": 2015,
"logprob": -10.0,
"text": "Test"
},
{
"id": 3853,
"logprob": -10.875,
"text": " request"
}
],
"seed": 0,
"tokens": [
{
"id": 7539,
"logprob": -0.73046875,
"special": false,
"text": " forms"
},
{
"id": 708,
"logprob": 0.0,
"special": false,
"text": " are"
},
{
"id": 671,
"logprob": -1.703125,
"special": false,
"text": " an"
},
{
"id": 8727,
"logprob": 0.0,
"special": false,
"text": " essential"
},
{
"id": 1702,
"logprob": 0.0,
"special": false,
"text": " part"
},
{
"id": 576,
"logprob": 0.0,
"special": false,
"text": " of"
},
{
"id": 573,
"logprob": 0.0,
"special": false,
"text": " the"
},
{
"id": 11859,
"logprob": -1.6953125,
"special": false,
"text": " lab"
},
{
"id": 2185,
"logprob": -1.3125,
"special": false,
"text": " process"
},
{
"id": 578,
"logprob": -1.5,
"special": false,
"text": " and"
}
],
"top_tokens": null
},
"generated_text": "Test request forms are an essential part of the lab process and"
}
[
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2,
"logprob": null,
"text": "<bos>"
},
{
"id": 2015,
"logprob": -10.0,
"text": "Test"
},
{
"id": 3853,
"logprob": -10.875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 1736,
"logprob": -2.09375,
"special": false,
"text": " form"
},
{
"id": 109,
"logprob": -1.9140625,
"special": false,
"text": "\n\n"
},
{
"id": 651,
"logprob": -2.453125,
"special": false,
"text": "The"
},
{
"id": 2121,
"logprob": -1.8984375,
"special": false,
"text": " test"
},
{
"id": 3853,
"logprob": -0.23535156,
"special": false,
"text": " request"
},
{
"id": 1736,
"logprob": -0.091308594,
"special": false,
"text": " form"
},
{
"id": 603,
"logprob": -0.96875,
"special": false,
"text": " is"
},
{
"id": 1671,
"logprob": -1.6484375,
"special": false,
"text": " used"
},
{
"id": 577,
"logprob": -0.43164062,
"special": false,
"text": " to"
},
{
"id": 3853,
"logprob": -1.2421875,
"special": false,
"text": " request"
}
],
"top_tokens": null
},
"generated_text": " form\n\nThe test request form is used to request"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2,
"logprob": null,
"text": "<bos>"
},
{
"id": 2015,
"logprob": -10.0,
"text": "Test"
},
{
"id": 3853,
"logprob": -10.875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 1736,
"logprob": -2.09375,
"special": false,
"text": " form"
},
{
"id": 109,
"logprob": -1.9140625,
"special": false,
"text": "\n\n"
},
{
"id": 651,
"logprob": -2.453125,
"special": false,
"text": "The"
},
{
"id": 2121,
"logprob": -1.8984375,
"special": false,
"text": " test"
},
{
"id": 3853,
"logprob": -0.23535156,
"special": false,
"text": " request"
},
{
"id": 1736,
"logprob": -0.091308594,
"special": false,
"text": " form"
},
{
"id": 603,
"logprob": -0.96875,
"special": false,
"text": " is"
},
{
"id": 1671,
"logprob": -1.6484375,
"special": false,
"text": " used"
},
{
"id": 577,
"logprob": -0.43164062,
"special": false,
"text": " to"
},
{
"id": 3853,
"logprob": -1.2421875,
"special": false,
"text": " request"
}
],
"top_tokens": null
},
"generated_text": " form\n\nThe test request form is used to request"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2,
"logprob": null,
"text": "<bos>"
},
{
"id": 2015,
"logprob": -10.0,
"text": "Test"
},
{
"id": 3853,
"logprob": -10.875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 1736,
"logprob": -2.09375,
"special": false,
"text": " form"
},
{
"id": 109,
"logprob": -1.9140625,
"special": false,
"text": "\n\n"
},
{
"id": 651,
"logprob": -2.453125,
"special": false,
"text": "The"
},
{
"id": 2121,
"logprob": -1.8984375,
"special": false,
"text": " test"
},
{
"id": 3853,
"logprob": -0.23535156,
"special": false,
"text": " request"
},
{
"id": 1736,
"logprob": -0.091308594,
"special": false,
"text": " form"
},
{
"id": 603,
"logprob": -0.96875,
"special": false,
"text": " is"
},
{
"id": 1671,
"logprob": -1.6484375,
"special": false,
"text": " used"
},
{
"id": 577,
"logprob": -0.43164062,
"special": false,
"text": " to"
},
{
"id": 3853,
"logprob": -1.2421875,
"special": false,
"text": " request"
}
],
"top_tokens": null
},
"generated_text": " form\n\nThe test request form is used to request"
},
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 2,
"logprob": null,
"text": "<bos>"
},
{
"id": 2015,
"logprob": -10.0,
"text": "Test"
},
{
"id": 3853,
"logprob": -10.875,
"text": " request"
}
],
"seed": null,
"tokens": [
{
"id": 1736,
"logprob": -2.09375,
"special": false,
"text": " form"
},
{
"id": 109,
"logprob": -1.9140625,
"special": false,
"text": "\n\n"
},
{
"id": 651,
"logprob": -2.453125,
"special": false,
"text": "The"
},
{
"id": 2121,
"logprob": -1.8984375,
"special": false,
"text": " test"
},
{
"id": 3853,
"logprob": -0.23535156,
"special": false,
"text": " request"
},
{
"id": 1736,
"logprob": -0.091308594,
"special": false,
"text": " form"
},
{
"id": 603,
"logprob": -0.96875,
"special": false,
"text": " is"
},
{
"id": 1671,
"logprob": -1.6484375,
"special": false,
"text": " used"
},
{
"id": 577,
"logprob": -0.43164062,
"special": false,
"text": " to"
},
{
"id": 3853,
"logprob": -1.2421875,
"special": false,
"text": " request"
}
],
"top_tokens": null
},
"generated_text": " form\n\nThe test request form is used to request"
}
]
import pytest


@pytest.fixture(scope="module")
def flash_gemma_handle(launcher):
    with launcher("gg-hf/gemma-2b", num_shard=1) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_gemma(flash_gemma_handle):
    await flash_gemma_handle.health(300)
    return flash_gemma_handle.client


@pytest.mark.skip
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma(flash_gemma, response_snapshot):
    response = await flash_gemma.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.skip
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma_all_params(flash_gemma, response_snapshot):
    response = await flash_gemma.generate(
        "Test request",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["test"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.skip
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma_load(flash_gemma, generate_load, response_snapshot):
    responses = await generate_load(flash_gemma, "Test request", max_new_tokens=10, n=4)

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])

    assert responses == response_snapshot
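
The generate_load fixture comes from the shared integration-test conftest and is not part of this diff; it is assumed to fire n identical requests concurrently and return the responses, roughly along these lines:

import asyncio

# Rough stand-in for the conftest's generate_load helper (an assumption, not the
# actual implementation): n concurrent generate() calls against the same client.
async def generate_load(client, prompt: str, max_new_tokens: int, n: int):
    futures = [
        client.generate(prompt, max_new_tokens=max_new_tokens, decoder_input_details=True)
        for _ in range(n)
    ]
    return await asyncio.gather(*futures)

With default greedy decoding all four responses are identical, which is what the four-entry snapshot above captures.
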
......@@ -52,6 +52,9 @@ try:
    from text_generation_server.models.flash_llama import (
        FlashLlama,
    )
    from text_generation_server.models.flash_gemma import (
        FlashGemma,
    )
    from text_generation_server.models.flash_santacoder import (
        FlashSantacoderSharded,
    )
......@@ -312,6 +315,28 @@ def get_model(
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
    if model_type == "gemma":
        if FLASH_ATTENTION:
            return FlashGemma(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                use_medusa=use_medusa,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma")
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
        if sharded:
......
import torch
import torch.distributed

from opentelemetry import trace
from typing import Optional

from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
    GemmaTokenizerFast,
    FlashGemmaForCausalLM,
    GemmaConfig,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)

tracer = trace.get_tracer(__name__)


class FlashGemma(FlashCausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
        use_medusa: Optional[str] = None,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            # Gemma checkpoints are distributed in bfloat16; default to it.
            dtype = torch.bfloat16 if dtype is None else dtype
        else:
            raise NotImplementedError("FlashGemma is only available on GPU")

        tokenizer = GemmaTokenizerFast.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
            use_fast=True,
            from_slow=False,
        )

        config = GemmaConfig.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize

        torch.distributed.barrier(group=self.process_group)

        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
        if config.quantize in ["gptq", "awq"]:
            weights._set_gptq_params(model_id, revision)

        model = FlashGemmaForCausalLM(config, weights)
        if use_medusa:
            # Optionally replace the LM head with a Medusa speculative-decoding
            # head loaded from a local directory or the Hugging Face Hub.
            from text_generation_server.utils.medusa import MedusaModel
            from huggingface_hub import hf_hub_download
            import json
            import os
            from pathlib import Path

            is_local_model = (
                Path(use_medusa).exists() and Path(use_medusa).is_dir()
            ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None

            if not is_local_model:
                medusa_config = hf_hub_download(
                    use_medusa, revision=revision, filename="config.json"
                )
                medusa_head = hf_hub_download(
                    use_medusa, revision=revision, filename="medusa_lm_head.pt"
                )
            else:
                medusa_config = str(Path(use_medusa) / "config.json")
                medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")

            with open(medusa_config, "r") as f:
                config = json.load(f)
            medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
            weights = Weights(
                [medusa_sf], device, dtype, process_group=self.process_group
            )
            lm_head = model.lm_head
            model.lm_head = MedusaModel(config, weights, lm_head)

        torch.distributed.barrier(group=self.process_group)
        super(FlashGemma, self).__init__(
            model=model,
            tokenizer=tokenizer,
            num_layers=len(model.model.layers),
            num_kv_heads=model.model.num_key_value_heads,
            head_size=model.model.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )
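
In normal serving this class is constructed through get_model above; for reference, a minimal direct-instantiation sketch, assuming a CUDA device and access to the checkpoint weights:

import torch

# Hypothetical direct use; parameters mirror the __init__ signature above.
model = FlashGemma(
    "gg-hf/gemma-2b",       # checkpoint id used by the integration tests
    revision=None,
    quantize=None,
    dtype=torch.bfloat16,   # same default __init__ picks when dtype is None
    trust_remote_code=False,
    use_medusa=None,        # or a Medusa repo/directory for speculative decoding
)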