Commit be0967c1 authored by zhuwenwen's avatar zhuwenwen
Browse files

update tests

parent e7c1b7f3
...@@ -29,7 +29,7 @@ NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing ...@@ -29,7 +29,7 @@ NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
HEAD_SIZES = [64, 112] HEAD_SIZES = [64, 112]
BLOCK_SIZES = [16, 32] BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True] USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8"] KV_CACHE_DTYPE = ["auto", "fp8"] if not is_hip() else ["auto"]
SEEDS = [0] SEEDS = [0]
CUDA_DEVICES = ['cuda:0'] CUDA_DEVICES = ['cuda:0']
BLOCKSPARSE_LOCAL_BLOCKS = [16] BLOCKSPARSE_LOCAL_BLOCKS = [16]
......
...@@ -379,37 +379,37 @@ def test_swap_blocks( ...@@ -379,37 +379,37 @@ def test_swap_blocks(
dist_value_caches[0][dst].cpu()) dist_value_caches[0][dst].cpu())
@pytest.mark.parametrize("num_heads", NUM_HEADS) # @pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES) # @pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES) # @pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) # @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES) # @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS) # @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES) # @pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode() # @torch.inference_mode()
def test_fp8_e4m3_conversion( # def test_fp8_e4m3_conversion(
num_heads: int, # num_heads: int,
head_size: int, # head_size: int,
block_size: int, # block_size: int,
num_blocks: int, # num_blocks: int,
dtype: torch.dtype, # dtype: torch.dtype,
seed: int, # seed: int,
device: str, # device: str,
) -> None: # ) -> None:
random.seed(seed) # random.seed(seed)
torch.random.manual_seed(seed) # torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed) # torch.cuda.manual_seed(seed)
low = -224.0 # low = -224.0
high = 224.0 # high = 224.0
shape = (num_blocks, num_heads, head_size, block_size) # shape = (num_blocks, num_heads, head_size, block_size)
cache = torch.empty(shape, dtype=dtype, device=device) # cache = torch.empty(shape, dtype=dtype, device=device)
cache.uniform_(low, high) # cache.uniform_(low, high)
cache_fp8 = torch.empty_like(cache, dtype=torch.uint8) # cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
ops.convert_fp8(cache_fp8, cache) # ops.convert_fp8(cache_fp8, cache)
converted_cache = torch.empty_like(cache) # converted_cache = torch.empty_like(cache)
ops.convert_fp8(converted_cache, cache_fp8) # ops.convert_fp8(converted_cache, cache_fp8)
assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1) # assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment