Unverified Commit 0fb70683 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[tests] use proper gemma class and config in lumina2 tests. (#10828)

use proper gemma class and config in lumina2 tests.
parent f8b54cf0
...@@ -2,7 +2,7 @@ import unittest ...@@ -2,7 +2,7 @@ import unittest
import numpy as np import numpy as np
import torch import torch
from transformers import AutoTokenizer, GemmaConfig, GemmaForCausalLM from transformers import AutoTokenizer, Gemma2Config, Gemma2Model
from diffusers import ( from diffusers import (
AutoencoderKL, AutoencoderKL,
...@@ -81,15 +81,16 @@ class Lumina2Text2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTester ...@@ -81,15 +81,16 @@ class Lumina2Text2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTester
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/dummy-gemma") tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
torch.manual_seed(0) torch.manual_seed(0)
config = GemmaConfig( config = Gemma2Config(
head_dim=2, head_dim=4,
hidden_size=8, hidden_size=8,
intermediate_size=37, intermediate_size=8,
num_attention_heads=4, num_attention_heads=2,
num_hidden_layers=2, num_hidden_layers=2,
num_key_value_heads=4, num_key_value_heads=2,
sliding_window=2,
) )
text_encoder = GemmaForCausalLM(config) text_encoder = Gemma2Model(config)
components = { components = {
"transformer": transformer.eval(), "transformer": transformer.eval(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment