You need to sign in or sign up before continuing.
Unverified Commit fc9c3153 authored by Daniël de Kok's avatar Daniël de Kok Committed by GitHub
Browse files

Add pytest release marker (#2114)

* Add pytest release marker

Annotate a test with `@pytest.mark.release` and it only gets run
with `pytest integration-tests --release`.

* Mark many models as `release` to speed up CI
parent e563983d
...@@ -156,6 +156,8 @@ jobs: ...@@ -156,6 +156,8 @@ jobs:
needs: build-and-push needs: build-and-push
runs-on: ["self-hosted", "${{ needs.build-and-push.outputs.runs_on }}", "multi-gpu"] runs-on: ["self-hosted", "${{ needs.build-and-push.outputs.runs_on }}", "multi-gpu"]
if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest' if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest'
env:
PYTEST_FLAGS: ${{ github.ref == 'refs/heads/main' && '--release' || '' }}
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v4 uses: actions/checkout@v4
...@@ -180,4 +182,4 @@ jobs: ...@@ -180,4 +182,4 @@ jobs:
export DOCKER_DEVICES=${{ needs.build-and-push.outputs.docker_devices }} export DOCKER_DEVICES=${{ needs.build-and-push.outputs.docker_devices }}
export HF_TOKEN=${{ secrets.HF_TOKEN }} export HF_TOKEN=${{ secrets.HF_TOKEN }}
echo $DOCKER_IMAGE echo $DOCKER_IMAGE
pytest -s -vv integration-tests pytest -s -vv integration-tests ${PYTEST_FLAGS}
...@@ -37,6 +37,26 @@ DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data") ...@@ -37,6 +37,26 @@ DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data")
DOCKER_DEVICES = os.getenv("DOCKER_DEVICES") DOCKER_DEVICES = os.getenv("DOCKER_DEVICES")
def pytest_addoption(parser):
parser.addoption(
"--release", action="store_true", default=False, help="run release tests"
)
def pytest_configure(config):
config.addinivalue_line("markers", "release: mark test as a release-only test")
def pytest_collection_modifyitems(config, items):
if config.getoption("--release"):
# --release given in cli: do not skip release tests
return
skip_release = pytest.mark.skip(reason="need --release option to run")
for item in items:
if "release" in item.keywords:
item.add_marker(skip_release)
class ResponseComparator(JSONSnapshotExtension): class ResponseComparator(JSONSnapshotExtension):
rtol = 0.2 rtol = 0.2
ignore_logprob = False ignore_logprob = False
......
...@@ -13,6 +13,7 @@ async def bloom_560(bloom_560_handle): ...@@ -13,6 +13,7 @@ async def bloom_560(bloom_560_handle):
return bloom_560_handle.client return bloom_560_handle.client
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bloom_560m(bloom_560, response_snapshot): async def test_bloom_560m(bloom_560, response_snapshot):
response = await bloom_560.generate( response = await bloom_560.generate(
...@@ -27,6 +28,7 @@ async def test_bloom_560m(bloom_560, response_snapshot): ...@@ -27,6 +28,7 @@ async def test_bloom_560m(bloom_560, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bloom_560m_all_params(bloom_560, response_snapshot): async def test_bloom_560m_all_params(bloom_560, response_snapshot):
response = await bloom_560.generate( response = await bloom_560.generate(
...@@ -49,6 +51,7 @@ async def test_bloom_560m_all_params(bloom_560, response_snapshot): ...@@ -49,6 +51,7 @@ async def test_bloom_560m_all_params(bloom_560, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bloom_560m_load(bloom_560, generate_load, response_snapshot): async def test_bloom_560m_load(bloom_560, generate_load, response_snapshot):
responses = await generate_load( responses = await generate_load(
......
...@@ -13,6 +13,7 @@ async def bloom_560m_sharded(bloom_560m_sharded_handle): ...@@ -13,6 +13,7 @@ async def bloom_560m_sharded(bloom_560m_sharded_handle):
return bloom_560m_sharded_handle.client return bloom_560m_sharded_handle.client
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bloom_560m_sharded(bloom_560m_sharded, response_snapshot): async def test_bloom_560m_sharded(bloom_560m_sharded, response_snapshot):
response = await bloom_560m_sharded.generate( response = await bloom_560m_sharded.generate(
...@@ -27,6 +28,7 @@ async def test_bloom_560m_sharded(bloom_560m_sharded, response_snapshot): ...@@ -27,6 +28,7 @@ async def test_bloom_560m_sharded(bloom_560m_sharded, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bloom_560m_sharded_load( async def test_bloom_560m_sharded_load(
bloom_560m_sharded, generate_load, response_snapshot bloom_560m_sharded, generate_load, response_snapshot
......
...@@ -26,6 +26,7 @@ async def flash_llama_completion(flash_llama_completion_handle): ...@@ -26,6 +26,7 @@ async def flash_llama_completion(flash_llama_completion_handle):
# method for it. Instead, we use the `requests` library to make the HTTP request directly. # method for it. Instead, we use the `requests` library to make the HTTP request directly.
@pytest.mark.release
def test_flash_llama_completion_single_prompt( def test_flash_llama_completion_single_prompt(
flash_llama_completion, response_snapshot flash_llama_completion, response_snapshot
): ):
...@@ -46,6 +47,7 @@ def test_flash_llama_completion_single_prompt( ...@@ -46,6 +47,7 @@ def test_flash_llama_completion_single_prompt(
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot): def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
response = requests.post( response = requests.post(
f"{flash_llama_completion.base_url}/v1/completions", f"{flash_llama_completion.base_url}/v1/completions",
...@@ -68,6 +70,7 @@ def test_flash_llama_completion_many_prompts(flash_llama_completion, response_sn ...@@ -68,6 +70,7 @@ def test_flash_llama_completion_many_prompts(flash_llama_completion, response_sn
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
async def test_flash_llama_completion_many_prompts_stream( async def test_flash_llama_completion_many_prompts_stream(
flash_llama_completion, response_snapshot flash_llama_completion, response_snapshot
): ):
......
...@@ -17,6 +17,7 @@ async def flash_llama_awq(flash_llama_awq_handle): ...@@ -17,6 +17,7 @@ async def flash_llama_awq(flash_llama_awq_handle):
return flash_llama_awq_handle.client return flash_llama_awq_handle.client
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_llama_awq(flash_llama_awq, response_snapshot): async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
response = await flash_llama_awq.generate( response = await flash_llama_awq.generate(
...@@ -31,6 +32,7 @@ async def test_flash_llama_awq(flash_llama_awq, response_snapshot): ...@@ -31,6 +32,7 @@ async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot): async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
response = await flash_llama_awq.generate( response = await flash_llama_awq.generate(
...@@ -52,6 +54,7 @@ async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot): ...@@ -52,6 +54,7 @@ async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_llama_awq_load(flash_llama_awq, generate_load, response_snapshot): async def test_flash_llama_awq_load(flash_llama_awq, generate_load, response_snapshot):
responses = await generate_load( responses = await generate_load(
......
...@@ -17,6 +17,7 @@ async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded): ...@@ -17,6 +17,7 @@ async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded):
return flash_llama_awq_handle_sharded.client return flash_llama_awq_handle_sharded.client
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot): async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot):
response = await flash_llama_awq_sharded.generate( response = await flash_llama_awq_sharded.generate(
...@@ -31,6 +32,7 @@ async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapsho ...@@ -31,6 +32,7 @@ async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapsho
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_llama_awq_load_sharded( async def test_flash_llama_awq_load_sharded(
flash_llama_awq_sharded, generate_load, response_snapshot flash_llama_awq_sharded, generate_load, response_snapshot
......
...@@ -13,6 +13,7 @@ async def flash_falcon(flash_falcon_handle): ...@@ -13,6 +13,7 @@ async def flash_falcon(flash_falcon_handle):
return flash_falcon_handle.client return flash_falcon_handle.client
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_falcon(flash_falcon, response_snapshot): async def test_flash_falcon(flash_falcon, response_snapshot):
...@@ -26,6 +27,7 @@ async def test_flash_falcon(flash_falcon, response_snapshot): ...@@ -26,6 +27,7 @@ async def test_flash_falcon(flash_falcon, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_falcon_all_params(flash_falcon, response_snapshot): async def test_flash_falcon_all_params(flash_falcon, response_snapshot):
...@@ -49,6 +51,7 @@ async def test_flash_falcon_all_params(flash_falcon, response_snapshot): ...@@ -49,6 +51,7 @@ async def test_flash_falcon_all_params(flash_falcon, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_falcon_load(flash_falcon, generate_load, response_snapshot): async def test_flash_falcon_load(flash_falcon, generate_load, response_snapshot):
......
...@@ -13,6 +13,7 @@ async def flash_gemma(flash_gemma_handle): ...@@ -13,6 +13,7 @@ async def flash_gemma(flash_gemma_handle):
return flash_gemma_handle.client return flash_gemma_handle.client
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_gemma(flash_gemma, response_snapshot): async def test_flash_gemma(flash_gemma, response_snapshot):
...@@ -24,6 +25,7 @@ async def test_flash_gemma(flash_gemma, response_snapshot): ...@@ -24,6 +25,7 @@ async def test_flash_gemma(flash_gemma, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_gemma_all_params(flash_gemma, response_snapshot): async def test_flash_gemma_all_params(flash_gemma, response_snapshot):
...@@ -47,6 +49,7 @@ async def test_flash_gemma_all_params(flash_gemma, response_snapshot): ...@@ -47,6 +49,7 @@ async def test_flash_gemma_all_params(flash_gemma, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_gemma_load(flash_gemma, generate_load, response_snapshot): async def test_flash_gemma_load(flash_gemma, generate_load, response_snapshot):
......
...@@ -13,6 +13,7 @@ async def flash_gemma_gptq(flash_gemma_gptq_handle): ...@@ -13,6 +13,7 @@ async def flash_gemma_gptq(flash_gemma_gptq_handle):
return flash_gemma_gptq_handle.client return flash_gemma_gptq_handle.client
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_gemma_gptq(flash_gemma_gptq, ignore_logprob_response_snapshot): async def test_flash_gemma_gptq(flash_gemma_gptq, ignore_logprob_response_snapshot):
...@@ -24,6 +25,7 @@ async def test_flash_gemma_gptq(flash_gemma_gptq, ignore_logprob_response_snapsh ...@@ -24,6 +25,7 @@ async def test_flash_gemma_gptq(flash_gemma_gptq, ignore_logprob_response_snapsh
assert response == ignore_logprob_response_snapshot assert response == ignore_logprob_response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_gemma_gptq_all_params( async def test_flash_gemma_gptq_all_params(
...@@ -49,6 +51,7 @@ async def test_flash_gemma_gptq_all_params( ...@@ -49,6 +51,7 @@ async def test_flash_gemma_gptq_all_params(
assert response == ignore_logprob_response_snapshot assert response == ignore_logprob_response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_gemma_gptq_load( async def test_flash_gemma_gptq_load(
......
...@@ -13,6 +13,7 @@ async def flash_gpt2(flash_gpt2_handle): ...@@ -13,6 +13,7 @@ async def flash_gpt2(flash_gpt2_handle):
return flash_gpt2_handle.client return flash_gpt2_handle.client
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_gpt2(flash_gpt2, response_snapshot): async def test_flash_gpt2(flash_gpt2, response_snapshot):
response = await flash_gpt2.generate( response = await flash_gpt2.generate(
...@@ -25,6 +26,7 @@ async def test_flash_gpt2(flash_gpt2, response_snapshot): ...@@ -25,6 +26,7 @@ async def test_flash_gpt2(flash_gpt2, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_gpt2_load(flash_gpt2, generate_load, response_snapshot): async def test_flash_gpt2_load(flash_gpt2, generate_load, response_snapshot):
responses = await generate_load( responses = await generate_load(
......
...@@ -21,6 +21,7 @@ async def flash_llama_exl2(flash_llama_exl2_handle): ...@@ -21,6 +21,7 @@ async def flash_llama_exl2(flash_llama_exl2_handle):
return flash_llama_exl2_handle.client return flash_llama_exl2_handle.client
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapshot): async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapshot):
...@@ -32,6 +33,7 @@ async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapsh ...@@ -32,6 +33,7 @@ async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapsh
assert response == ignore_logprob_response_snapshot assert response == ignore_logprob_response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_exl2_all_params( async def test_flash_llama_exl2_all_params(
...@@ -58,6 +60,7 @@ async def test_flash_llama_exl2_all_params( ...@@ -58,6 +60,7 @@ async def test_flash_llama_exl2_all_params(
assert response == ignore_logprob_response_snapshot assert response == ignore_logprob_response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_exl2_load( async def test_flash_llama_exl2_load(
......
...@@ -13,6 +13,7 @@ async def flash_llama_gptq(flash_llama_gptq_handle): ...@@ -13,6 +13,7 @@ async def flash_llama_gptq(flash_llama_gptq_handle):
return flash_llama_gptq_handle.client return flash_llama_gptq_handle.client
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_gptq(flash_llama_gptq, response_snapshot): async def test_flash_llama_gptq(flash_llama_gptq, response_snapshot):
...@@ -24,6 +25,7 @@ async def test_flash_llama_gptq(flash_llama_gptq, response_snapshot): ...@@ -24,6 +25,7 @@ async def test_flash_llama_gptq(flash_llama_gptq, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot): async def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot):
...@@ -46,6 +48,7 @@ async def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot): ...@@ -46,6 +48,7 @@ async def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_gptq_load( async def test_flash_llama_gptq_load(
......
...@@ -15,6 +15,7 @@ async def flash_llama_gptq_marlin(flash_llama_gptq_marlin_handle): ...@@ -15,6 +15,7 @@ async def flash_llama_gptq_marlin(flash_llama_gptq_marlin_handle):
return flash_llama_gptq_marlin_handle.client return flash_llama_gptq_marlin_handle.client
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapshot): async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapshot):
...@@ -26,6 +27,7 @@ async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapsho ...@@ -26,6 +27,7 @@ async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapsho
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_gptq_marlin_all_params( async def test_flash_llama_gptq_marlin_all_params(
...@@ -50,6 +52,7 @@ async def test_flash_llama_gptq_marlin_all_params( ...@@ -50,6 +52,7 @@ async def test_flash_llama_gptq_marlin_all_params(
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_gptq_marlin_load( async def test_flash_llama_gptq_marlin_load(
......
...@@ -15,6 +15,7 @@ async def flash_llama_marlin(flash_llama_marlin_handle): ...@@ -15,6 +15,7 @@ async def flash_llama_marlin(flash_llama_marlin_handle):
return flash_llama_marlin_handle.client return flash_llama_marlin_handle.client
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot): async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot):
...@@ -26,6 +27,7 @@ async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot): ...@@ -26,6 +27,7 @@ async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_marlin_all_params(flash_llama_marlin, response_snapshot): async def test_flash_llama_marlin_all_params(flash_llama_marlin, response_snapshot):
...@@ -48,6 +50,7 @@ async def test_flash_llama_marlin_all_params(flash_llama_marlin, response_snapsh ...@@ -48,6 +50,7 @@ async def test_flash_llama_marlin_all_params(flash_llama_marlin, response_snapsh
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_marlin_load( async def test_flash_llama_marlin_load(
......
...@@ -13,6 +13,7 @@ async def flash_neox(flash_neox_handle): ...@@ -13,6 +13,7 @@ async def flash_neox(flash_neox_handle):
return flash_neox_handle.client return flash_neox_handle.client
@pytest.mark.release
@pytest.mark.skip @pytest.mark.skip
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_neox(flash_neox, response_snapshot): async def test_flash_neox(flash_neox, response_snapshot):
...@@ -26,6 +27,7 @@ async def test_flash_neox(flash_neox, response_snapshot): ...@@ -26,6 +27,7 @@ async def test_flash_neox(flash_neox, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.skip @pytest.mark.skip
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_neox_load(flash_neox, generate_load, response_snapshot): async def test_flash_neox_load(flash_neox, generate_load, response_snapshot):
......
...@@ -13,6 +13,7 @@ async def flash_neox_sharded(flash_neox_sharded_handle): ...@@ -13,6 +13,7 @@ async def flash_neox_sharded(flash_neox_sharded_handle):
return flash_neox_sharded_handle.client return flash_neox_sharded_handle.client
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_neox(flash_neox_sharded, response_snapshot): async def test_flash_neox(flash_neox_sharded, response_snapshot):
response = await flash_neox_sharded.generate( response = await flash_neox_sharded.generate(
...@@ -25,6 +26,7 @@ async def test_flash_neox(flash_neox_sharded, response_snapshot): ...@@ -25,6 +26,7 @@ async def test_flash_neox(flash_neox_sharded, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_neox_load(flash_neox_sharded, generate_load, response_snapshot): async def test_flash_neox_load(flash_neox_sharded, generate_load, response_snapshot):
responses = await generate_load( responses = await generate_load(
......
...@@ -34,6 +34,7 @@ def get_cow_beach(): ...@@ -34,6 +34,7 @@ def get_cow_beach():
return f"data:image/png;base64,{encoded_string.decode('utf-8')}" return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot): async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
...@@ -45,6 +46,7 @@ async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot): ...@@ -45,6 +46,7 @@ async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_pali_gemma_two_images(flash_pali_gemma, response_snapshot): async def test_flash_pali_gemma_two_images(flash_pali_gemma, response_snapshot):
......
...@@ -13,6 +13,7 @@ async def flash_phi(flash_phi_handle): ...@@ -13,6 +13,7 @@ async def flash_phi(flash_phi_handle):
return flash_phi_handle.client return flash_phi_handle.client
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_phi(flash_phi, response_snapshot): async def test_flash_phi(flash_phi, response_snapshot):
response = await flash_phi.generate( response = await flash_phi.generate(
...@@ -24,6 +25,7 @@ async def test_flash_phi(flash_phi, response_snapshot): ...@@ -24,6 +25,7 @@ async def test_flash_phi(flash_phi, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_phi_all_params(flash_phi, response_snapshot): async def test_flash_phi_all_params(flash_phi, response_snapshot):
response = await flash_phi.generate( response = await flash_phi.generate(
...@@ -47,6 +49,7 @@ async def test_flash_phi_all_params(flash_phi, response_snapshot): ...@@ -47,6 +49,7 @@ async def test_flash_phi_all_params(flash_phi, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_phi_load(flash_phi, generate_load, response_snapshot): async def test_flash_phi_load(flash_phi, generate_load, response_snapshot):
responses = await generate_load(flash_phi, "Test request", max_new_tokens=10, n=4) responses = await generate_load(flash_phi, "Test request", max_new_tokens=10, n=4)
......
...@@ -13,6 +13,7 @@ async def flash_qwen2(flash_qwen2_handle): ...@@ -13,6 +13,7 @@ async def flash_qwen2(flash_qwen2_handle):
return flash_qwen2_handle.client return flash_qwen2_handle.client
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_qwen2(flash_qwen2, response_snapshot): async def test_flash_qwen2(flash_qwen2, response_snapshot):
response = await flash_qwen2.generate( response = await flash_qwen2.generate(
...@@ -24,6 +25,7 @@ async def test_flash_qwen2(flash_qwen2, response_snapshot): ...@@ -24,6 +25,7 @@ async def test_flash_qwen2(flash_qwen2, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_qwen2_all_params(flash_qwen2, response_snapshot): async def test_flash_qwen2_all_params(flash_qwen2, response_snapshot):
response = await flash_qwen2.generate( response = await flash_qwen2.generate(
...@@ -46,6 +48,7 @@ async def test_flash_qwen2_all_params(flash_qwen2, response_snapshot): ...@@ -46,6 +48,7 @@ async def test_flash_qwen2_all_params(flash_qwen2, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_qwen2_load(flash_qwen2, generate_load, response_snapshot): async def test_flash_qwen2_load(flash_qwen2, generate_load, response_snapshot):
responses = await generate_load(flash_qwen2, "Test request", max_new_tokens=10, n=4) responses = await generate_load(flash_qwen2, "Test request", max_new_tokens=10, n=4)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment