Unverified Commit c2488003 authored by Jiarui Fang's avatar Jiarui Fang Committed by GitHub
Browse files

[kernel] skip tests of flash_attn and triton when they are not available (#1798)

parent e34e850a
......@@ -61,7 +61,7 @@ class GeminiManager:
self._comp_cuda_demand_time = 0
def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None:
""" Adjust the layout of statefuil tensor according to the information provided
""" Adjust the layout of stateful tensors according to the information provided
by mem_stats_collector, which should belongs to a Sharded Model.
"""
# find stateful tensor in state COMPUTE
......
import torch
import pytest
import torch
from einops import rearrange
from colossalai.kernel.cuda_native.flash_attention import flash_attention, triton_flash_attention, TRITON_AVALIABLE
from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_TRITON, TRITON_AVALIABLE
if HAS_FLASH_ATTN:
from colossalai.kernel.cuda_native.flash_attention import flash_attention
if HAS_TRITON:
from colossalai.kernel.cuda_native.flash_attention import triton_flash_attention
def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
......@@ -14,7 +21,8 @@ def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
ref_out = torch.matmul(p, v)
return ref_out
@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="triton is not available")
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 16, 8)])
def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
torch.manual_seed(20)
......@@ -23,7 +31,7 @@ def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
sm_scale = 0.3
dout = torch.randn_like(q)
ref_out = baseline_attention(Z, N_CTX, H, q, k, v, sm_scale)
ref_out.backward(dout)
ref_dv, v.grad = v.grad.clone(), None
......@@ -51,6 +59,7 @@ def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
raise TypeError("Error type not match!")
@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="triton is not available")
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 16, 8)])
def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
torch.manual_seed(20)
......@@ -59,21 +68,22 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
v = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
sm_scale = 0.3
dout = torch.randn_like(q)
# reference implementation
ref_out = baseline_attention(Z, N_CTX, H, q, k, v, sm_scale)
ref_out.backward(dout)
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None
# flash implementation
q, k, v = map(lambda x: rearrange(x, 'z h n d -> (z n) h d'), [q, k, v])
tri_out = flash_attention(q, k, v, sm_scale, Z, N_CTX)
dout = rearrange(dout, 'z h n d -> (z n) h d').detach()
tri_out.backward(dout, retain_graph=True)
tri_dq, tri_dk, tri_dv, = torch.autograd.grad(tri_out, (q, k, v), dout)
tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z), (tri_out, tri_dq, tri_dk, tri_dv))
tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
(tri_out, tri_dq, tri_dk, tri_dv))
# compare
assert torch.allclose(ref_out, tri_out, atol=1e-3)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment