Unverified commit 3123f151, authored by Tao He and committed by GitHub
Browse files

Fixes the incorrect argument in the prefix-prefill test cases (#3246)

parent 413366e9
...@@ -18,7 +18,7 @@ CUDA_DEVICES = [ ...@@ -18,7 +18,7 @@ CUDA_DEVICES = [
@pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_HEADS) @pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
...@@ -35,6 +35,13 @@ def test_contexted_kv_attention( ...@@ -35,6 +35,13 @@ def test_contexted_kv_attention(
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.manual_seed(0) torch.cuda.manual_seed(0)
torch.set_default_device(device) torch.set_default_device(device)
# This is needed; otherwise, when we capture the graph, the process for GPU 1 would run on
# both GPU 0 and GPU 1, and things would hang.
#
# see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
torch.cuda.set_device(device)
MAX_SEQ_LEN = 1024 MAX_SEQ_LEN = 1024
MAX_CTX_LEN = 1024 MAX_CTX_LEN = 1024
BS = 10 BS = 10
...@@ -172,5 +179,5 @@ def test_contexted_kv_attention( ...@@ -172,5 +179,5 @@ def test_contexted_kv_attention(
torch.cuda.synchronize() torch.cuda.synchronize()
end_time = time.time() end_time = time.time()
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
output_ref = output_ref.squeeze(0, 2) output_ref = output_ref.reshape(output.shape)
assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) assert torch.allclose(output_ref, output, atol=1e-6, rtol=0)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment