import math


def num_splits_heuristic(
    total_mblocks,
    num_SMs,
    num_n_blocks,
    num_m_blocks,
    size_one_kv_head,
    is_causal_or_local,
    max_splits,
    *,
    size_l2=50 * 1024 * 1024,
    efficiency_threshold=0.85,
):
    """
    Determines the optimal number of splits for maximizing GPU occupancy while balancing memory efficiency.

    Parameters:
    - total_mblocks (int): Total number of m_blocks.
    - num_SMs (int): Number of Streaming Multiprocessors (SMs) in the GPU.
    - num_n_blocks (int): Number of n_blocks.
    - num_m_blocks (int): Number of m_blocks.
    - size_one_kv_head (int): Size of one KV head in bytes.
    - is_causal_or_local (bool): Indicates whether the operation is causal or local.
    - max_splits (int): Maximum number of allowed splits.
    - size_l2 (int, keyword-only): Assumed L2 cache size in bytes (default 50 MiB).
      Only consulted when the m_blocks already (almost) fill the SMs.
    - efficiency_threshold (float, keyword-only): Fraction of the best wave
      efficiency that a candidate split count must reach to be chosen
      (default 0.85).

    Returns:
    - int: The optimal number of splits. Returns 1 when there is no work
      (total_mblocks <= 0) or when splitting is not worthwhile.
    """
    # If we have enough m_blocks to almost fill the SMs, prefer 1 split unless memory constraints apply.
    if total_mblocks >= 0.8 * num_SMs:
        # Only split when one KV head does not fit in L2, there are at least
        # twice as many m_blocks as SMs, and the op is neither causal nor local.
        if size_one_kv_head > size_l2 and num_m_blocks >= num_SMs * 2 and not is_causal_or_local:
            # Ceiling division: number of L2-sized chunks needed for one KV head.
            return min((size_one_kv_head + size_l2 - 1) // size_l2, max_splits)
        return 1

    # If num_n_blocks is too small, we don't split.
    if num_n_blocks <= 4:
        return 1

    # No m_blocks means no work to distribute. Without this guard, n_waves
    # would be 0.0 and math.ceil(0.0) == 0 would raise ZeroDivisionError.
    if total_mblocks <= 0:
        return 1

    # Limit max_splits to a reasonable range.
    max_splits = min(max_splits, num_SMs, num_n_blocks)

    # Wave efficiency for each candidate split count: the fraction of the last
    # (possibly partial) wave of SM work that is actually occupied.
    efficiency = []
    for num_splits in range(1, max_splits + 1):
        n_waves = total_mblocks * num_splits / num_SMs
        efficiency.append(n_waves / math.ceil(n_waves))

    if not efficiency:
        # max_splits < 1 leaves nothing to evaluate.
        return 1

    # Pick the smallest split count reaching the requested fraction of the best efficiency.
    target = efficiency_threshold * max(efficiency)
    for num_splits, eff in enumerate(efficiency, start=1):
        if eff >= target:
            return num_splits
    return 1