Commit 9231098f authored by Daniël de Kok, committed by Daniël de Kok

Fix (flash) Gemma prefix and enable tests

parent d32e33bd
@@ -3,7 +3,7 @@ import pytest
 @pytest.fixture(scope="module")
 def flash_gemma_handle(launcher):
-    with launcher("gg-hf/gemma-2b", num_shard=1) as handle:
+    with launcher("google/gemma-2b", num_shard=1) as handle:
         yield handle
@@ -13,7 +13,6 @@ async def flash_gemma(flash_gemma_handle):
     return flash_gemma_handle.client
-@pytest.mark.skip
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_gemma(flash_gemma, response_snapshot):
@@ -25,7 +24,6 @@ async def test_flash_gemma(flash_gemma, response_snapshot):
     assert response == response_snapshot
-@pytest.mark.skip
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_gemma_all_params(flash_gemma, response_snapshot):
@@ -49,7 +47,6 @@ async def test_flash_gemma_all_params(flash_gemma, response_snapshot):
     assert response == response_snapshot
-@pytest.mark.skip
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_gemma_load(flash_gemma, generate_load, response_snapshot):
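For reference, the hunks above make two changes: the fixture now launches the public google/gemma-2b checkpoint instead of gg-hf/gemma-2b, and the unconditional @pytest.mark.skip markers are dropped so the three tests are collected and run again. A minimal, self-contained sketch of the skip-marker behaviour, with placeholder test names that are not part of the repository:

    import pytest


    @pytest.mark.skip
    def test_always_skipped():
        # With an unconditional skip marker the body never runs;
        # pytest reports the test as skipped.
        assert False


    def test_collected_and_run():
        # Without the marker the test is collected and executed normally.
        assert True

Markers such as asyncio and private only classify the tests (for example for selection with pytest -m private); they do not skip anything by themselves.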
@@ -423,7 +423,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
         super().__init__()
         embed_norm = config.hidden_size**0.5
-        if prefix is None:
+        if not prefix:
             prefix = "model"
         else:
             prefix = f"{prefix}.model"
@@ -57,7 +57,7 @@ class FlashGemma(FlashCausalLM):
             weights._set_gptq_params(model_id, revision)
         # TODO hardcoded
-        prefix = "language_model"
+        prefix = ""
         model = FlashGemmaForCausalLM(prefix, config, weights, causal=True)
         torch.distributed.barrier(group=self.process_group)
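The two prefix changes above work together: FlashGemma now passes an empty prefix instead of the hardcoded "language_model", and FlashGemmaForCausalLM treats any falsy prefix (None or "") as the bare "model" weight namespace used by a standalone Gemma checkpoint; the earlier `is None` check would have missed the empty string. A minimal sketch of that resolution logic, using resolve_prefix as a hypothetical helper name rather than anything from the repository:

    def resolve_prefix(prefix):
        # Mirrors the updated check: any falsy prefix (None or "") selects the
        # bare "model" namespace; a non-empty prefix is nested underneath it.
        if not prefix:
            return "model"
        return f"{prefix}.model"


    # Old call site: "language_model" led to lookups under "language_model.model.*",
    # which a standalone google/gemma-2b checkpoint does not contain.
    assert resolve_prefix("language_model") == "language_model.model"

    # New call site: the empty string now resolves to "model", matching the
    # checkpoint layout; a plain `is None` check would not have caught it.
    assert resolve_prefix("") == "model"
    assert resolve_prefix(None) == "model"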