Unverified Commit fc9c3153 authored by Daniël de Kok's avatar Daniël de Kok Committed by GitHub
Browse files

Add pytest release marker (#2114)

* Add pytest release marker

Tests annotated with `@pytest.mark.release` are skipped by default and only
run when pytest is invoked with `--release`, e.g. `pytest integration-tests --release`.

* Mark many models as `release` to speed up CI
parent e563983d
...@@ -156,6 +156,8 @@ jobs: ...@@ -156,6 +156,8 @@ jobs:
needs: build-and-push needs: build-and-push
runs-on: ["self-hosted", "${{ needs.build-and-push.outputs.runs_on }}", "multi-gpu"] runs-on: ["self-hosted", "${{ needs.build-and-push.outputs.runs_on }}", "multi-gpu"]
if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest' if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest'
env:
PYTEST_FLAGS: ${{ github.ref == 'refs/heads/main' && '--release' || '' }}
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v4 uses: actions/checkout@v4
...@@ -180,4 +182,4 @@ jobs: ...@@ -180,4 +182,4 @@ jobs:
export DOCKER_DEVICES=${{ needs.build-and-push.outputs.docker_devices }} export DOCKER_DEVICES=${{ needs.build-and-push.outputs.docker_devices }}
export HF_TOKEN=${{ secrets.HF_TOKEN }} export HF_TOKEN=${{ secrets.HF_TOKEN }}
echo $DOCKER_IMAGE echo $DOCKER_IMAGE
pytest -s -vv integration-tests pytest -s -vv integration-tests ${PYTEST_FLAGS}
...@@ -37,6 +37,26 @@ DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data") ...@@ -37,6 +37,26 @@ DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", "/data")
DOCKER_DEVICES = os.getenv("DOCKER_DEVICES") DOCKER_DEVICES = os.getenv("DOCKER_DEVICES")
def pytest_addoption(parser):
    """Register the ``--release`` command-line flag.

    The flag is a boolean switch (``store_true``) that defaults to ``False``.
    ``pytest_collection_modifyitems`` reads it to decide whether tests marked
    ``release`` are run or skipped.

    Args:
        parser: The option parser passed to this hook; only its
            ``addoption`` method is used.
    """
    parser.addoption(
        "--release", action="store_true", default=False, help="run release tests"
    )
def pytest_configure(config):
    """Declare the ``release`` marker in the ``markers`` ini setting.

    Args:
        config: The configuration object passed to this hook; only its
            ``addinivalue_line`` method is used.
    """
    config.addinivalue_line("markers", "release: mark test as a release-only test")
def pytest_collection_modifyitems(config, items):
    """Skip every collected test tagged ``release`` unless ``--release`` was given.

    Args:
        config: The configuration object; queried for the ``--release`` option.
        items: The collected test items. Items whose ``keywords`` contain
            ``"release"`` get a ``skip`` marker added in place.
    """
    if config.getoption("--release"):
        # --release given in cli: do not skip release tests
        return

    # One shared skip marker is enough; it is attached to each matching item.
    skip_release = pytest.mark.skip(reason="need --release option to run")
    for item in items:
        if "release" in item.keywords:
            item.add_marker(skip_release)
class ResponseComparator(JSONSnapshotExtension): class ResponseComparator(JSONSnapshotExtension):
rtol = 0.2 rtol = 0.2
ignore_logprob = False ignore_logprob = False
......
...@@ -13,6 +13,7 @@ async def bloom_560(bloom_560_handle): ...@@ -13,6 +13,7 @@ async def bloom_560(bloom_560_handle):
return bloom_560_handle.client return bloom_560_handle.client
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bloom_560m(bloom_560, response_snapshot): async def test_bloom_560m(bloom_560, response_snapshot):
response = await bloom_560.generate( response = await bloom_560.generate(
...@@ -27,6 +28,7 @@ async def test_bloom_560m(bloom_560, response_snapshot): ...@@ -27,6 +28,7 @@ async def test_bloom_560m(bloom_560, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bloom_560m_all_params(bloom_560, response_snapshot): async def test_bloom_560m_all_params(bloom_560, response_snapshot):
response = await bloom_560.generate( response = await bloom_560.generate(
...@@ -49,6 +51,7 @@ async def test_bloom_560m_all_params(bloom_560, response_snapshot): ...@@ -49,6 +51,7 @@ async def test_bloom_560m_all_params(bloom_560, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bloom_560m_load(bloom_560, generate_load, response_snapshot): async def test_bloom_560m_load(bloom_560, generate_load, response_snapshot):
responses = await generate_load( responses = await generate_load(
......
...@@ -13,6 +13,7 @@ async def bloom_560m_sharded(bloom_560m_sharded_handle): ...@@ -13,6 +13,7 @@ async def bloom_560m_sharded(bloom_560m_sharded_handle):
return bloom_560m_sharded_handle.client return bloom_560m_sharded_handle.client
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bloom_560m_sharded(bloom_560m_sharded, response_snapshot): async def test_bloom_560m_sharded(bloom_560m_sharded, response_snapshot):
response = await bloom_560m_sharded.generate( response = await bloom_560m_sharded.generate(
...@@ -27,6 +28,7 @@ async def test_bloom_560m_sharded(bloom_560m_sharded, response_snapshot): ...@@ -27,6 +28,7 @@ async def test_bloom_560m_sharded(bloom_560m_sharded, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bloom_560m_sharded_load( async def test_bloom_560m_sharded_load(
bloom_560m_sharded, generate_load, response_snapshot bloom_560m_sharded, generate_load, response_snapshot
......
...@@ -26,6 +26,7 @@ async def flash_llama_completion(flash_llama_completion_handle): ...@@ -26,6 +26,7 @@ async def flash_llama_completion(flash_llama_completion_handle):
# method for it. Instead, we use the `requests` library to make the HTTP request directly. # method for it. Instead, we use the `requests` library to make the HTTP request directly.
@pytest.mark.release
def test_flash_llama_completion_single_prompt( def test_flash_llama_completion_single_prompt(
flash_llama_completion, response_snapshot flash_llama_completion, response_snapshot
): ):
...@@ -46,6 +47,7 @@ def test_flash_llama_completion_single_prompt( ...@@ -46,6 +47,7 @@ def test_flash_llama_completion_single_prompt(
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot): def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
response = requests.post( response = requests.post(
f"{flash_llama_completion.base_url}/v1/completions", f"{flash_llama_completion.base_url}/v1/completions",
...@@ -68,6 +70,7 @@ def test_flash_llama_completion_many_prompts(flash_llama_completion, response_sn ...@@ -68,6 +70,7 @@ def test_flash_llama_completion_many_prompts(flash_llama_completion, response_sn
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
async def test_flash_llama_completion_many_prompts_stream( async def test_flash_llama_completion_many_prompts_stream(
flash_llama_completion, response_snapshot flash_llama_completion, response_snapshot
): ):
......
...@@ -17,6 +17,7 @@ async def flash_llama_awq(flash_llama_awq_handle): ...@@ -17,6 +17,7 @@ async def flash_llama_awq(flash_llama_awq_handle):
return flash_llama_awq_handle.client return flash_llama_awq_handle.client
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_llama_awq(flash_llama_awq, response_snapshot): async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
response = await flash_llama_awq.generate( response = await flash_llama_awq.generate(
...@@ -31,6 +32,7 @@ async def test_flash_llama_awq(flash_llama_awq, response_snapshot): ...@@ -31,6 +32,7 @@ async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot): async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
response = await flash_llama_awq.generate( response = await flash_llama_awq.generate(
...@@ -52,6 +54,7 @@ async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot): ...@@ -52,6 +54,7 @@ async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_llama_awq_load(flash_llama_awq, generate_load, response_snapshot): async def test_flash_llama_awq_load(flash_llama_awq, generate_load, response_snapshot):
responses = await generate_load( responses = await generate_load(
......
...@@ -17,6 +17,7 @@ async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded): ...@@ -17,6 +17,7 @@ async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded):
return flash_llama_awq_handle_sharded.client return flash_llama_awq_handle_sharded.client
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot): async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot):
response = await flash_llama_awq_sharded.generate( response = await flash_llama_awq_sharded.generate(
...@@ -31,6 +32,7 @@ async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapsho ...@@ -31,6 +32,7 @@ async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapsho
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_llama_awq_load_sharded( async def test_flash_llama_awq_load_sharded(
flash_llama_awq_sharded, generate_load, response_snapshot flash_llama_awq_sharded, generate_load, response_snapshot
......
...@@ -13,6 +13,7 @@ async def flash_falcon(flash_falcon_handle): ...@@ -13,6 +13,7 @@ async def flash_falcon(flash_falcon_handle):
return flash_falcon_handle.client return flash_falcon_handle.client
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_falcon(flash_falcon, response_snapshot): async def test_flash_falcon(flash_falcon, response_snapshot):
...@@ -26,6 +27,7 @@ async def test_flash_falcon(flash_falcon, response_snapshot): ...@@ -26,6 +27,7 @@ async def test_flash_falcon(flash_falcon, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_falcon_all_params(flash_falcon, response_snapshot): async def test_flash_falcon_all_params(flash_falcon, response_snapshot):
...@@ -49,6 +51,7 @@ async def test_flash_falcon_all_params(flash_falcon, response_snapshot): ...@@ -49,6 +51,7 @@ async def test_flash_falcon_all_params(flash_falcon, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_falcon_load(flash_falcon, generate_load, response_snapshot): async def test_flash_falcon_load(flash_falcon, generate_load, response_snapshot):
......
...@@ -13,6 +13,7 @@ async def flash_gemma(flash_gemma_handle): ...@@ -13,6 +13,7 @@ async def flash_gemma(flash_gemma_handle):
return flash_gemma_handle.client return flash_gemma_handle.client
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_gemma(flash_gemma, response_snapshot): async def test_flash_gemma(flash_gemma, response_snapshot):
...@@ -24,6 +25,7 @@ async def test_flash_gemma(flash_gemma, response_snapshot): ...@@ -24,6 +25,7 @@ async def test_flash_gemma(flash_gemma, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_gemma_all_params(flash_gemma, response_snapshot): async def test_flash_gemma_all_params(flash_gemma, response_snapshot):
...@@ -47,6 +49,7 @@ async def test_flash_gemma_all_params(flash_gemma, response_snapshot): ...@@ -47,6 +49,7 @@ async def test_flash_gemma_all_params(flash_gemma, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_gemma_load(flash_gemma, generate_load, response_snapshot): async def test_flash_gemma_load(flash_gemma, generate_load, response_snapshot):
......
...@@ -13,6 +13,7 @@ async def flash_gemma_gptq(flash_gemma_gptq_handle): ...@@ -13,6 +13,7 @@ async def flash_gemma_gptq(flash_gemma_gptq_handle):
return flash_gemma_gptq_handle.client return flash_gemma_gptq_handle.client
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_gemma_gptq(flash_gemma_gptq, ignore_logprob_response_snapshot): async def test_flash_gemma_gptq(flash_gemma_gptq, ignore_logprob_response_snapshot):
...@@ -24,6 +25,7 @@ async def test_flash_gemma_gptq(flash_gemma_gptq, ignore_logprob_response_snapsh ...@@ -24,6 +25,7 @@ async def test_flash_gemma_gptq(flash_gemma_gptq, ignore_logprob_response_snapsh
assert response == ignore_logprob_response_snapshot assert response == ignore_logprob_response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_gemma_gptq_all_params( async def test_flash_gemma_gptq_all_params(
...@@ -49,6 +51,7 @@ async def test_flash_gemma_gptq_all_params( ...@@ -49,6 +51,7 @@ async def test_flash_gemma_gptq_all_params(
assert response == ignore_logprob_response_snapshot assert response == ignore_logprob_response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_gemma_gptq_load( async def test_flash_gemma_gptq_load(
......
...@@ -13,6 +13,7 @@ async def flash_gpt2(flash_gpt2_handle): ...@@ -13,6 +13,7 @@ async def flash_gpt2(flash_gpt2_handle):
return flash_gpt2_handle.client return flash_gpt2_handle.client
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_gpt2(flash_gpt2, response_snapshot): async def test_flash_gpt2(flash_gpt2, response_snapshot):
response = await flash_gpt2.generate( response = await flash_gpt2.generate(
...@@ -25,6 +26,7 @@ async def test_flash_gpt2(flash_gpt2, response_snapshot): ...@@ -25,6 +26,7 @@ async def test_flash_gpt2(flash_gpt2, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_gpt2_load(flash_gpt2, generate_load, response_snapshot): async def test_flash_gpt2_load(flash_gpt2, generate_load, response_snapshot):
responses = await generate_load( responses = await generate_load(
......
...@@ -21,6 +21,7 @@ async def flash_llama_exl2(flash_llama_exl2_handle): ...@@ -21,6 +21,7 @@ async def flash_llama_exl2(flash_llama_exl2_handle):
return flash_llama_exl2_handle.client return flash_llama_exl2_handle.client
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapshot): async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapshot):
...@@ -32,6 +33,7 @@ async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapsh ...@@ -32,6 +33,7 @@ async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapsh
assert response == ignore_logprob_response_snapshot assert response == ignore_logprob_response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_exl2_all_params( async def test_flash_llama_exl2_all_params(
...@@ -58,6 +60,7 @@ async def test_flash_llama_exl2_all_params( ...@@ -58,6 +60,7 @@ async def test_flash_llama_exl2_all_params(
assert response == ignore_logprob_response_snapshot assert response == ignore_logprob_response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_exl2_load( async def test_flash_llama_exl2_load(
......
...@@ -13,6 +13,7 @@ async def flash_llama_gptq(flash_llama_gptq_handle): ...@@ -13,6 +13,7 @@ async def flash_llama_gptq(flash_llama_gptq_handle):
return flash_llama_gptq_handle.client return flash_llama_gptq_handle.client
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_gptq(flash_llama_gptq, response_snapshot): async def test_flash_llama_gptq(flash_llama_gptq, response_snapshot):
...@@ -24,6 +25,7 @@ async def test_flash_llama_gptq(flash_llama_gptq, response_snapshot): ...@@ -24,6 +25,7 @@ async def test_flash_llama_gptq(flash_llama_gptq, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot): async def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot):
...@@ -46,6 +48,7 @@ async def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot): ...@@ -46,6 +48,7 @@ async def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_gptq_load( async def test_flash_llama_gptq_load(
......
...@@ -15,6 +15,7 @@ async def flash_llama_gptq_marlin(flash_llama_gptq_marlin_handle): ...@@ -15,6 +15,7 @@ async def flash_llama_gptq_marlin(flash_llama_gptq_marlin_handle):
return flash_llama_gptq_marlin_handle.client return flash_llama_gptq_marlin_handle.client
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapshot): async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapshot):
...@@ -26,6 +27,7 @@ async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapsho ...@@ -26,6 +27,7 @@ async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapsho
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_gptq_marlin_all_params( async def test_flash_llama_gptq_marlin_all_params(
...@@ -50,6 +52,7 @@ async def test_flash_llama_gptq_marlin_all_params( ...@@ -50,6 +52,7 @@ async def test_flash_llama_gptq_marlin_all_params(
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_gptq_marlin_load( async def test_flash_llama_gptq_marlin_load(
......
...@@ -15,6 +15,7 @@ async def flash_llama_marlin(flash_llama_marlin_handle): ...@@ -15,6 +15,7 @@ async def flash_llama_marlin(flash_llama_marlin_handle):
return flash_llama_marlin_handle.client return flash_llama_marlin_handle.client
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot): async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot):
...@@ -26,6 +27,7 @@ async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot): ...@@ -26,6 +27,7 @@ async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_marlin_all_params(flash_llama_marlin, response_snapshot): async def test_flash_llama_marlin_all_params(flash_llama_marlin, response_snapshot):
...@@ -48,6 +50,7 @@ async def test_flash_llama_marlin_all_params(flash_llama_marlin, response_snapsh ...@@ -48,6 +50,7 @@ async def test_flash_llama_marlin_all_params(flash_llama_marlin, response_snapsh
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_llama_marlin_load( async def test_flash_llama_marlin_load(
......
...@@ -13,6 +13,7 @@ async def flash_neox(flash_neox_handle): ...@@ -13,6 +13,7 @@ async def flash_neox(flash_neox_handle):
return flash_neox_handle.client return flash_neox_handle.client
@pytest.mark.release
@pytest.mark.skip @pytest.mark.skip
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_neox(flash_neox, response_snapshot): async def test_flash_neox(flash_neox, response_snapshot):
...@@ -26,6 +27,7 @@ async def test_flash_neox(flash_neox, response_snapshot): ...@@ -26,6 +27,7 @@ async def test_flash_neox(flash_neox, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.skip @pytest.mark.skip
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_neox_load(flash_neox, generate_load, response_snapshot): async def test_flash_neox_load(flash_neox, generate_load, response_snapshot):
......
...@@ -13,6 +13,7 @@ async def flash_neox_sharded(flash_neox_sharded_handle): ...@@ -13,6 +13,7 @@ async def flash_neox_sharded(flash_neox_sharded_handle):
return flash_neox_sharded_handle.client return flash_neox_sharded_handle.client
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_neox(flash_neox_sharded, response_snapshot): async def test_flash_neox(flash_neox_sharded, response_snapshot):
response = await flash_neox_sharded.generate( response = await flash_neox_sharded.generate(
...@@ -25,6 +26,7 @@ async def test_flash_neox(flash_neox_sharded, response_snapshot): ...@@ -25,6 +26,7 @@ async def test_flash_neox(flash_neox_sharded, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_neox_load(flash_neox_sharded, generate_load, response_snapshot): async def test_flash_neox_load(flash_neox_sharded, generate_load, response_snapshot):
responses = await generate_load( responses = await generate_load(
......
...@@ -34,6 +34,7 @@ def get_cow_beach(): ...@@ -34,6 +34,7 @@ def get_cow_beach():
return f"data:image/png;base64,{encoded_string.decode('utf-8')}" return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot): async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
...@@ -45,6 +46,7 @@ async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot): ...@@ -45,6 +46,7 @@ async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_flash_pali_gemma_two_images(flash_pali_gemma, response_snapshot): async def test_flash_pali_gemma_two_images(flash_pali_gemma, response_snapshot):
......
...@@ -13,6 +13,7 @@ async def flash_phi(flash_phi_handle): ...@@ -13,6 +13,7 @@ async def flash_phi(flash_phi_handle):
return flash_phi_handle.client return flash_phi_handle.client
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_phi(flash_phi, response_snapshot): async def test_flash_phi(flash_phi, response_snapshot):
response = await flash_phi.generate( response = await flash_phi.generate(
...@@ -24,6 +25,7 @@ async def test_flash_phi(flash_phi, response_snapshot): ...@@ -24,6 +25,7 @@ async def test_flash_phi(flash_phi, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_phi_all_params(flash_phi, response_snapshot): async def test_flash_phi_all_params(flash_phi, response_snapshot):
response = await flash_phi.generate( response = await flash_phi.generate(
...@@ -47,6 +49,7 @@ async def test_flash_phi_all_params(flash_phi, response_snapshot): ...@@ -47,6 +49,7 @@ async def test_flash_phi_all_params(flash_phi, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_phi_load(flash_phi, generate_load, response_snapshot): async def test_flash_phi_load(flash_phi, generate_load, response_snapshot):
responses = await generate_load(flash_phi, "Test request", max_new_tokens=10, n=4) responses = await generate_load(flash_phi, "Test request", max_new_tokens=10, n=4)
......
...@@ -13,6 +13,7 @@ async def flash_qwen2(flash_qwen2_handle): ...@@ -13,6 +13,7 @@ async def flash_qwen2(flash_qwen2_handle):
return flash_qwen2_handle.client return flash_qwen2_handle.client
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_qwen2(flash_qwen2, response_snapshot): async def test_flash_qwen2(flash_qwen2, response_snapshot):
response = await flash_qwen2.generate( response = await flash_qwen2.generate(
...@@ -24,6 +25,7 @@ async def test_flash_qwen2(flash_qwen2, response_snapshot): ...@@ -24,6 +25,7 @@ async def test_flash_qwen2(flash_qwen2, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_qwen2_all_params(flash_qwen2, response_snapshot): async def test_flash_qwen2_all_params(flash_qwen2, response_snapshot):
response = await flash_qwen2.generate( response = await flash_qwen2.generate(
...@@ -46,6 +48,7 @@ async def test_flash_qwen2_all_params(flash_qwen2, response_snapshot): ...@@ -46,6 +48,7 @@ async def test_flash_qwen2_all_params(flash_qwen2, response_snapshot):
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_flash_qwen2_load(flash_qwen2, generate_load, response_snapshot): async def test_flash_qwen2_load(flash_qwen2, generate_load, response_snapshot):
responses = await generate_load(flash_qwen2, "Test request", max_new_tokens=10, n=4) responses = await generate_load(flash_qwen2, "Test request", max_new_tokens=10, n=4)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment