# SPDX-License-Identifier: MIT
import pytest
import torch

from aiter.ops.triton.fla.sglang.chunk_delta_h import (
    chunk_gated_delta_rule_fwd_h,
    prepare_chunk_indices,
)
from aiter.ops.triton.fla.sglang.chunk_o import chunk_fwd_o
from op_tests.triton_tests.utils.chunk_delta_h_sglang_ref import (
    chunk_gated_delta_rule_fwd_h as chunk_gated_delta_rule_fwd_h_ref,
)
from op_tests.triton_tests.utils.chunk_o_sglang_ref import (
    chunk_fwd_o as chunk_fwd_o_ref,
)


def _assert_tensors_equal(
    actual: torch.Tensor, expected: torch.Tensor, name: str
) -> None:
    """Assert exact equality of two tensors, with a readable failure message."""
    # Check shapes first: the max-abs-diff diagnostic below would otherwise
    # raise its own error on mismatched shapes and hide the real failure.
    assert (
        actual.shape == expected.shape
    ), f"{name}: shape mismatch {tuple(actual.shape)} vs {tuple(expected.shape)}"
    # The message expression is only evaluated when the comparison fails.
    assert torch.equal(actual, expected), (
        f"{name}: tensors differ, max abs diff = "
        f"{(actual.float() - expected.float()).abs().max().item()}"
    )


@pytest.mark.parametrize("use_varlen", [False, True])
@pytest.mark.parametrize("state_index_mode", ["identity", "random"])
def test_chunk_gated_fwd_h_and_o_against_ref(
    use_varlen: bool, state_index_mode: str
) -> None:
    """Compare the aiter chunked gated-delta kernels with the reference ones.

    Runs ``chunk_gated_delta_rule_fwd_h`` followed by ``chunk_fwd_o`` on random
    inputs, runs the reference implementations on the same inputs, and requires
    the resulting ``h`` and ``o`` tensors to be exactly equal.

    Args:
        use_varlen: When True, the single batch row is split into three
            sequences of lengths 37, 56 and 37 via ``cu_seqlens``; otherwise
            ``cu_seqlens`` is None and the row is one sequence.
        state_index_mode: How ``initial_state_indices`` is built.
            ``"identity"`` uses ``arange(n_seq)``; ``"random"`` picks ``n_seq``
            distinct rows out of a pool of ``n_seq + 3`` state rows;
            ``"none"`` passes None to the kernel under test and
            ``arange(n_seq)`` to the reference.

    Raises:
        ValueError: If ``state_index_mode`` is not one of the modes above.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA is required")

    torch.manual_seed(42)
    device = torch.device("cuda")

    bsz = 1
    # 130 is not a multiple of chunk_size (64), so the last chunk is partial.
    total_t = 130
    h_dim = 16
    hg = 8
    k_dim = 128
    v_dim = 128
    chunk_size = 64

    # q/k carry `hg` heads while w/u/g carry `h_dim` heads.
    q = torch.randn((bsz, total_t, hg, k_dim), device=device, dtype=torch.float16) * 0.2
    k = torch.randn((bsz, total_t, hg, k_dim), device=device, dtype=torch.float16) * 0.2
    w = (
        torch.randn((bsz, total_t, h_dim, k_dim), device=device, dtype=torch.float16)
        * 0.2
    )
    u = (
        torch.randn((bsz, total_t, h_dim, v_dim), device=device, dtype=torch.float16)
        * 0.2
    )
    g = torch.randn((bsz, total_t, h_dim), device=device, dtype=torch.float32) * 0.05

    if use_varlen:
        cu_seqlens = torch.tensor(
            [0, 37, 93, total_t], device=device, dtype=torch.long
        )
        n_seq = len(cu_seqlens) - 1
    else:
        cu_seqlens = None
        n_seq = bsz

    if state_index_mode == "none":
        # NOTE(review): "none" is not in the parametrize list above, so pytest
        # never exercises this branch. Confirm the kernel accepts
        # initial_state_indices=None before enabling it.
        state_rows = n_seq
        initial_state_indices_cur = None
        initial_state_indices_ref = torch.arange(
            n_seq, device=device, dtype=torch.int32
        )
    elif state_index_mode == "random":
        # Use a state pool larger than the number of sequences and select
        # distinct, shuffled rows from it.
        state_rows = n_seq + 3
        initial_state_indices_cur = torch.randperm(
            state_rows, device=device, dtype=torch.int64
        )[:n_seq].to(torch.int32)
        initial_state_indices_ref = initial_state_indices_cur
    elif state_index_mode == "identity":
        state_rows = n_seq
        initial_state_indices_cur = torch.arange(
            n_seq, device=device, dtype=torch.int32
        )
        initial_state_indices_ref = initial_state_indices_cur
    else:
        # Fail loudly rather than silently running the identity configuration.
        raise ValueError(f"unknown state_index_mode: {state_index_mode!r}")

    initial_state_cur = (
        torch.randn(
            (state_rows, h_dim, v_dim, k_dim), device=device, dtype=torch.float32
        )
        * 0.02
    )
    # Each implementation gets its own copy of the initial state — presumably
    # because the kernels may write the state in place; TODO confirm.
    initial_state_ref = initial_state_cur.clone()

    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, chunk_size)
        if cu_seqlens is not None
        else None
    )

    h_cur, v_new_cur = chunk_gated_delta_rule_fwd_h(
        k=k,
        w=w,
        u=u,
        g=g,
        initial_state=initial_state_cur,
        initial_state_indices=initial_state_indices_cur,
        output_final_state=True,
        chunk_size=chunk_size,
        save_new_value=True,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        use_exp2=False,
        transpose_state_layout=True,
    )
    h_ref, v_new_ref = chunk_gated_delta_rule_fwd_h_ref(
        k=k,
        w=w,
        u=u,
        g=g,
        gk=None,
        initial_state=initial_state_ref,
        initial_state_indices=initial_state_indices_ref,
        save_new_value=True,
        cu_seqlens=cu_seqlens,
    )

    scale = k_dim**-0.5
    # Each output kernel consumes the h / v_new produced by its own pipeline.
    o_cur = chunk_fwd_o(
        q=q,
        k=k,
        v=v_new_cur,
        h=h_cur,
        g=g,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
        chunk_indices=chunk_indices,
        use_exp2=False,
        transpose_state_layout=True,
    )
    o_ref = chunk_fwd_o_ref(
        q=q,
        k=k,
        v=v_new_ref,
        h=h_ref,
        g=g,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
    )

    _assert_tensors_equal(h_cur, h_ref, "h")
    _assert_tensors_equal(o_cur, o_ref, "o")