Unverified Commit 0fb70683 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[tests] use proper gemma class and config in lumina2 tests. (#10828)

use proper gemma class and config in lumina2 tests.
parent f8b54cf0
......@@ -2,7 +2,7 @@ import unittest
import numpy as np
import torch
from transformers import AutoTokenizer, GemmaConfig, GemmaForCausalLM
from transformers import AutoTokenizer, Gemma2Config, Gemma2Model
from diffusers import (
AutoencoderKL,
......@@ -81,15 +81,16 @@ class Lumina2Text2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTester
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
torch.manual_seed(0)
config = GemmaConfig(
head_dim=2,
config = Gemma2Config(
head_dim=4,
hidden_size=8,
intermediate_size=37,
num_attention_heads=4,
intermediate_size=8,
num_attention_heads=2,
num_hidden_layers=2,
num_key_value_heads=4,
num_key_value_heads=2,
sliding_window=2,
)
text_encoder = GemmaForCausalLM(config)
text_encoder = Gemma2Model(config)
components = {
"transformer": transformer.eval(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment