Commit 6ac8f906 authored by wooway777's avatar wooway777
Browse files

issue/919 - ninetoothed flash attention

parent 47843aa6
......@@ -5,6 +5,7 @@
#include "ops/attention.hpp"
#include "ops/causal_softmax.hpp"
#include "ops/embedding.hpp"
#include "ops/flash_attention.hpp"
#include "ops/matmul.hpp"
#include "ops/ones.hpp"
#include "ops/paged_attention.hpp"
......
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {

// Graph-op class for flash attention:
// out = FlashAttention(q, k, v, total_kv_len, scale, is_causal).
INFINICORE_GRAPH_OP_CLASS(FlashAttention, Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, float, bool);

// Allocating variant: creates the output tensor (q's shape with the last
// dimension taken from v) and runs flash attention on q/k/v.
// `total_kv_len` is a tensor carrying the effective KV sequence length;
// `scale` is applied to the attention scores.
Tensor flash_attention(const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal);

// In-place variant: writes the attention result into `out`.
void flash_attention_(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal);

} // namespace infinicore::op
......@@ -10,6 +10,7 @@
#include "infiniop/ops/conv.h"
#include "infiniop/ops/dequantize_awq.h"
#include "infiniop/ops/embedding.h"
#include "infiniop/ops/flash_attention.h"
#include "infiniop/ops/gelu.h"
#include "infiniop/ops/gemm.h"
#include "infiniop/ops/layer_norm.h"
......
#ifndef __INFINIOP_FLASH_ATTENTION_API_H__
#define __INFINIOP_FLASH_ATTENTION_API_H__

#include "../operator_descriptor.h"

// Opaque handle to a flash-attention descriptor (created/destroyed below).
typedef struct InfiniopDescriptor *infiniopFlashAttentionDescriptor_t;

// Creates a descriptor that captures the tensor metadata (shapes, strides,
// dtype) plus the `scale` and `is_causal` flags for later execution.
// `is_causal` is a char used as a boolean flag.
__C __export infiniStatus_t infiniopCreateFlashAttentionDescriptor(
    infiniopHandle_t handle,
    infiniopFlashAttentionDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t q_desc,
    infiniopTensorDescriptor_t k_desc,
    infiniopTensorDescriptor_t v_desc,
    infiniopTensorDescriptor_t total_kv_len,
    float scale,
    char is_causal);

// Queries the scratch-memory requirement for a previously created descriptor.
__C __export infiniStatus_t infiniopGetFlashAttentionWorkspaceSize(
    infiniopFlashAttentionDescriptor_t desc,
    size_t *size);

// Runs flash attention with raw device pointers on the given stream.
// `total_kv_len` points to the tensor holding the effective KV length.
__C __export infiniStatus_t infiniopFlashAttention(
    infiniopFlashAttentionDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *out,
    const void *q,
    const void *k,
    const void *v,
    const void *total_kv_len,
    void *stream);

// Releases a descriptor created by infiniopCreateFlashAttentionDescriptor.
__C __export infiniStatus_t infiniopDestroyFlashAttentionDescriptor(
    infiniopFlashAttentionDescriptor_t desc);

#endif
from .causal_softmax import causal_softmax
from .embedding import embedding
from .flash_attention import flash_attention
from .linear import linear
from .random_sample import random_sample
from .rms_norm import rms_norm
......@@ -9,12 +10,13 @@ from .swiglu import swiglu
__all__ = [
"causal_softmax",
"embedding",
"flash_attention",
"linear",
"random_sample",
"rms_norm",
"RopeAlgo",
"rope",
"silu",
"swiglu",
"linear",
"embedding",
"rope",
"RopeAlgo",
]
import math
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor
def flash_attention(
    query,
    key,
    value,
    total_kv_len,
    attn_mask=None,
    dropout_p=0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Flash attention with a scaled_dot_product_attention-like signature.

    Only the subset supported by the backend is accepted: `attn_mask` must be
    None, `dropout_p` must be 0, and `enable_gqa` must be False.
    `total_kv_len` is a tensor holding the effective KV sequence length.
    When `scale` is None it defaults to 1/sqrt(head_dim), where head_dim is
    the last dimension of `query`.
    """
    assert attn_mask is None and dropout_p == 0 and not enable_gqa

    if scale is None:
        scale = 1 / math.sqrt(query.shape[-1])

    raw = _infinicore.flash_attention(
        query._underlying,
        key._underlying,
        value._underlying,
        total_kv_len._underlying,
        scale,
        is_causal,
    )
    return Tensor(raw)
#include "infinicore/ops/flash_attention.hpp"
#include "../../utils.hpp"
namespace infinicore::op {
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(FlashAttention);

// Builds the op node: checks that out/q/k/v share a device, then dispatches
// to the backend implementation registered for that device type.
// NOTE(review): total_kv_len is absent from the same-device assertion —
// presumably it may live on a different device (e.g. host); confirm.
FlashAttention::FlashAttention(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) {
    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k, v);
    INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(),
                                 out, q, k, v, total_kv_len, scale, is_causal);
}

// Records the op when graph capture is active, otherwise runs it eagerly.
void FlashAttention::execute(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) {
    INFINICORE_GRAPH_OP_RECORD_OR_RUN(FlashAttention, out, q, k, v, total_kv_len, scale, is_causal);
}
// Allocating front-end: the output takes q's shape with its last dimension
// replaced by v's last dimension (the value head size), then delegates to the
// in-place variant.
Tensor flash_attention(const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) {
    Shape out_shape = q->shape();
    const auto last = out_shape.size() - 1;
    out_shape[last] = v->shape()[last];
    auto out = Tensor::empty(out_shape, q->dtype(), q->device());
    flash_attention_(out, q, k, v, total_kv_len, scale, is_causal);
    return out;
}
// In-place front-end: runs flash attention and writes the result into `out`.
void flash_attention_(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) {
    FlashAttention::execute(out, q, k, v, total_kv_len, scale, is_causal);
}

} // namespace infinicore::op
#include "../../utils.hpp"
#include "../infiniop_impl.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/flash_attention.hpp"
#include <infiniop.h>
namespace infinicore::op::flash_attention_impl::infiniop {
// Cache of infiniop FlashAttention descriptors (capacity: 100 entries).
INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, FlashAttention, 100);

// Everything `run` needs, captured once at plan time.
struct PlannedMeta {
    std::shared_ptr<Descriptor> descriptor;                 // cached infiniop descriptor
    graph::GraphTensor workspace, out, q, k, v, total_kv_len; // scratch + I/O tensors
    float scale;
    bool is_causal;
};
// Plan phase: resolve (or create) a cached infiniop descriptor keyed by the
// tensor metadata and scalars, allocate the workspace tensor, and capture
// everything `run` will need. Returns an owning PlannedMeta* (freed by cleanup).
void *plan(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) {
    // Cache key — presumably hash_combine's tensor overload covers
    // shape/stride/dtype; confirm it does not hash by pointer identity only.
    size_t seed = hash_combine(out, q, k, v, total_kv_len, scale, is_causal);
    INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
        Descriptor, descriptor, FlashAttention,
        seed, out->desc(), q->desc(), k->desc(), v->desc(), total_kv_len->desc(), scale, is_causal);
    INFINIOP_WORKSPACE_TENSOR(workspace, FlashAttention, descriptor);
    // Ownership of this allocation transfers to the caller; see cleanup().
    auto planned = new PlannedMeta{
        descriptor,
        graph::GraphTensor(workspace),
        graph::GraphTensor(out),
        graph::GraphTensor(q),
        graph::GraphTensor(k),
        graph::GraphTensor(v),
        graph::GraphTensor(total_kv_len), scale, is_causal};
    return planned;
}
// Execute phase: launch the infiniop flash-attention kernel on the
// current stream using the pointers captured at plan time.
void run(void *planned_meta) {
    auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);
    // NOTE(review): workspace->numel() is passed as the workspace *byte* size —
    // this assumes the workspace tensor is byte-typed; confirm.
    INFINICORE_CHECK_ERROR(infiniopFlashAttention(
        planned->descriptor->desc, planned->workspace->data(), planned->workspace->numel(),
        planned->out->data(), planned->q->data(), planned->k->data(), planned->v->data(), planned->total_kv_len->data(), context::getStream()));
}
// Frees the PlannedMeta allocated by plan() and nulls the caller's pointer.
void cleanup(void **planned_meta_ptr) {
    delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
    *planned_meta_ptr = nullptr;
}

// Register this plan/run/cleanup triple for every device type.
INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(FlashAttention, &plan, &run, &cleanup);
} // namespace infinicore::op::flash_attention_impl::infiniop
......@@ -7,6 +7,7 @@
#include "ops/attention.hpp"
#include "ops/causal_softmax.hpp"
#include "ops/embedding.hpp"
#include "ops/flash_attention.hpp"
#include "ops/linear.hpp"
#include "ops/matmul.hpp"
#include "ops/mul.hpp"
......@@ -29,13 +30,14 @@ inline void bind(py::module &m) {
bind_add_rms_norm(m);
bind_attention(m);
bind_causal_softmax(m);
bind_random_sample(m);
bind_flash_attention(m);
bind_linear(m);
bind_matmul(m);
bind_mul(m);
bind_paged_attention(m);
bind_paged_attention_prefill(m);
bind_paged_caching(m);
bind_random_sample(m);
bind_rearrange(m);
bind_rms_norm(m);
bind_silu(m);
......
#pragma once

#include <pybind11/pybind11.h>

#include "infinicore/ops/flash_attention.hpp"

namespace py = pybind11;

namespace infinicore::ops {

// Exposes op::flash_attention to Python as
// `flash_attention(q, k, v, total_kv_len, scale, is_causal)`.
// Returns a newly allocated output tensor (the allocating C++ overload).
inline void bind_flash_attention(py::module &m) {
    m.def("flash_attention",
          &op::flash_attention,
          py::arg("q"),
          py::arg("k"),
          py::arg("v"),
          py::arg("total_kv_len"),
          py::arg("scale"),
          py::arg("is_causal"));
}

} // namespace infinicore::ops
import ninetoothed
from . import flash_attention
from .flash_attention import CausalVariant
import infiniop.ninetoothed.build
import torch
def build():
    """Pre-build the ninetoothed flash-attention kernels over the constexpr grid.

    Skipped entirely when any visible CUDA device is a MetaX GPU, since these
    kernels are generated through the CUDA caller.
    """
    if torch.cuda.is_available():
        device_names = (
            torch.cuda.get_device_name(i).lower()
            for i in range(torch.cuda.device_count())
        )
        if any("metax" in name for name in device_names):
            return

    # One kernel variant is compiled per combination of these
    # compile-time parameters.
    constexpr_param_grid = {
        "with_kv_cache": (0,),
        "emb_dim": (16, 32, 64, 128, 256),
        "is_causal": (0, 1),
        "with_attn_mask": (0,),
        "causal_variant": (CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT),
        "dtype": (ninetoothed.float16, ninetoothed.bfloat16, ninetoothed.float32),
        "block_size_m": (256,),
        "block_size_n": (64,),
    }

    infiniop.ninetoothed.build.build(
        flash_attention.premake,
        constexpr_param_grid,
        caller="cuda",
        op_name="flash_attention",
        output_dir=infiniop.ninetoothed.build.BUILD_DIRECTORY_PATH,
    )
#ifndef __FLASH_ATTENTION_DESCRIPTOR_H__
#define __FLASH_ATTENTION_DESCRIPTOR_H__
#include "../../../handle.h"
#include "../../../operator.h"
#include "../../../tensor.h"
#include "../../../../../build/ninetoothed/flash_attention.h"
#include "../../../ninetoothed/utils.h"
namespace op::flash_attention::ninetoothed {
/// Flash-attention descriptor for the ninetoothed-generated CUDA kernel.
///
/// Captures tensor shapes/strides and the dtype at construction time;
/// `calculate` wraps the raw device pointers into ninetoothed tensor views
/// and launches the pre-built kernel.
class Descriptor final : public InfiniopDescriptor {
public:
    // NOTE(review): `scale` arrives here as double although the public C API
    // takes float; the value is widened at the dispatch site.
    Descriptor(infiniopHandle_t handle,
               infiniopTensorDescriptor_t out_desc,
               infiniopTensorDescriptor_t q_desc,
               infiniopTensorDescriptor_t k_desc,
               infiniopTensorDescriptor_t v_desc,
               infiniopTensorDescriptor_t total_kv_len,
               double scale,
               char is_causal) : InfiniopDescriptor{handle->device, handle->device_id},
                                 _query_shape{q_desc->shape()},
                                 _query_strides{q_desc->strides()},
                                 _key_shape{k_desc->shape()},
                                 _key_strides{k_desc->strides()},
                                 _value_shape{v_desc->shape()},
                                 _value_strides{v_desc->strides()},
                                 _total_kv_shape{total_kv_len->shape()},
                                 _total_kv_strides{total_kv_len->strides()},
                                 _output_strides{out_desc->strides()},
                                 _dtype{q_desc->dtype()},
                                 _scale{scale},
                                 _is_causal{is_causal} {
    }

    ~Descriptor() = default;

    // The ninetoothed kernel needs no scratch memory.
    size_t get_workspace_size() const {
        return 0;
    }

    // Launches the flash-attention kernel on `stream`.
    // Returns INFINI_STATUS_NOT_IMPLEMENTED when no matching pre-built
    // kernel variant exists for the constexpr parameters below.
    infiniStatus_t calculate(void *workspace,
                             size_t workspace_size,
                             void *out,
                             const void *q,
                             const void *k,
                             const void *v,
                             const void *total_kv_len,
                             void *stream) const {
        // Fix: these placeholder arrays for the unused attn_mask were passed to
        // the launch uninitialized; zero-initialize them so the kernel never
        // sees indeterminate shape/stride values (the mask path is compiled out
        // via with_attn_mask_ == 0 below, but the pointers are still passed).
        uint64_t empty_shape[4] = {};
        int64_t empty_strides[4] = {};
        auto query{::ninetoothed::Tensor{q, _query_shape, _query_strides}};
        auto key{::ninetoothed::Tensor{k, _key_shape, _key_strides}};
        auto value{::ninetoothed::Tensor{v, _value_shape, _value_strides}};
        auto total_kv_length{::ninetoothed::Tensor{total_kv_len, _total_kv_shape, _total_kv_strides}};
        NineToothedTensor attn_mask{nullptr, empty_shape, empty_strides};
        NineToothedTensor is_causal;
        NineToothedTensor scale{const_cast<double *>(&_scale), nullptr, nullptr};
        // NOTE(review): the output view reuses the *query* shape — this assumes
        // v's head dim equals q's head dim (the build grid compiles one shared
        // emb_dim for q/k/v/output); confirm for asymmetric head sizes.
        auto output{::ninetoothed::Tensor{out, _query_shape, _output_strides}};
        NineToothedTensor with_attn_mask;
        NineToothedTensor causal_variant;
        // Compile-time kernel-selection parameters; must match a combination
        // present in the pre-built kernel grid.
        const auto with_kv_cache_{0};
        const auto emb_dim_{_query_shape[3]};
        const auto is_causal_{_is_causal};
        const auto with_attn_mask_{0};
        const auto causal_variant_{2}; // CausalVariant.LOWER_RIGHT
        const auto dtype_{_dtype};
        constexpr auto block_size_m_{256};
        constexpr auto block_size_n_{64};
        if (launch_flash_attention(stream,
                                   query,
                                   key,
                                   value,
                                   total_kv_length,
                                   attn_mask,
                                   is_causal,
                                   scale,
                                   output,
                                   with_attn_mask,
                                   causal_variant,
                                   with_kv_cache_,
                                   emb_dim_,
                                   is_causal_,
                                   with_attn_mask_,
                                   causal_variant_,
                                   dtype_,
                                   block_size_m_,
                                   block_size_n_)) {
            return INFINI_STATUS_NOT_IMPLEMENTED;
        }
        return INFINI_STATUS_SUCCESS;
    }

    // Factory used by the C-API dispatcher.
    static infiniStatus_t create(infiniopHandle_t handle,
                                 Descriptor **desc,
                                 infiniopTensorDescriptor_t out_desc,
                                 infiniopTensorDescriptor_t q_desc,
                                 infiniopTensorDescriptor_t k_desc,
                                 infiniopTensorDescriptor_t v_desc,
                                 infiniopTensorDescriptor_t total_kv_len,
                                 double scale,
                                 char is_causal) {
        *desc = new Descriptor{handle, out_desc, q_desc, k_desc, v_desc, total_kv_len, scale, is_causal};
        return INFINI_STATUS_SUCCESS;
    }

private:
    using Size = ::ninetoothed::Tensor<>::Size;
    using Stride = ::ninetoothed::Tensor<>::Stride;

    std::vector<Size> _query_shape;
    std::vector<Stride> _query_strides;
    std::vector<Size> _key_shape;
    std::vector<Stride> _key_strides;
    std::vector<Size> _value_shape;
    std::vector<Stride> _value_strides;
    std::vector<Size> _total_kv_shape;
    std::vector<Stride> _total_kv_strides;
    std::vector<Stride> _output_strides;
    infiniDtype_t _dtype;
    double _scale;
    char _is_causal;
};
} // namespace op::flash_attention::ninetoothed
#endif // __FLASH_ATTENTION_DESCRIPTOR_H__
import enum
import functools
import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Tensor
# Meta block sizes resolved by the ninetoothed compiler when no explicit
# block_size_m/block_size_n is supplied to arrangement().
BLOCK_SIZE_M = ninetoothed.block_size()
BLOCK_SIZE_N = ninetoothed.block_size()


class CausalVariant(enum.IntEnum):
    """Please refer to `<https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.bias.CausalVariant.html>`_."""

    # enum.auto() gives UPPER_LEFT == 1, LOWER_RIGHT == 2; the kernel's
    # causal branch compares the constexpr causal_variant against 2.
    UPPER_LEFT = enum.auto()
    LOWER_RIGHT = enum.auto()
def arrangement(
    query,
    key,
    value,
    total_kv_len,
    present_key,
    present_value,
    present_key_slot,
    present_value_slot,
    attn_mask,
    is_causal,
    scale,
    output,
    with_attn_mask,
    causal_variant,
    with_kv_cache,
    block_size_m=None,
    block_size_n=None,
):
    """Describe how each logical tensor is tiled for the ninetoothed kernel.

    Returns the arranged tensors in the exact parameter order expected by
    application_with_kv_cache (when with_kv_cache is truthy) or
    application_without_kv_cache (otherwise). Scalars (is_causal, scale,
    with_attn_mask, causal_variant) pass through unchanged.
    """

    def arrange_query_or_output(input):
        # Tile rows into block_size_m chunks; the second tile groups
        # query heads per KV head (GQA ratio = q heads / kv heads).
        arranged = input.tile((1, 1, block_size_m, -1)).tile(
            (1, query.shape[-3] // key.shape[-3], 1, 1)
        )
        arranged.dtype = arranged.dtype.squeeze((0, 2, 3))
        arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1))
        return arranged

    def arrange_key_or_value(input):
        # Tile the KV sequence into block_size_n chunks, then expand so each
        # query block iterates over every KV block.
        arranged = (
            input.tile((1, 1, block_size_n, -1))
            .tile((1, 1, -1, -1))
            .expand((-1, -1, query_arranged.shape[-2], -1))
        )
        arranged.dtype = arranged.dtype.squeeze((0, 1, 3))
        arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1))
        return arranged

    def arrange_total_kv_len(input, shape):
        # Broadcast the 1-D length tensor so every program instance sees it.
        arranged = input.tile((1,))
        arranged = arranged.unsqueeze(1).unsqueeze(2).unsqueeze(3).expand(shape)
        return arranged

    def arrange_present_key_or_present_value(input):
        # KV-cache tensors are tiled (block_size_m x block_size_n); only used
        # by the with_kv_cache variant.
        arranged = input.tile((1, 1, block_size_m, block_size_n))
        arranged.dtype = arranged.dtype.squeeze((0, 1))
        return arranged

    def arrange_attn_mask(input):
        # Mask follows the (query block x KV block) tiling of the scores.
        arranged = input.tile((1, 1, block_size_m, block_size_n)).tile((1, 1, 1, -1))
        arranged.dtype = arranged.dtype.squeeze((0, 1, 2))
        arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1))
        return arranged

    # Fall back to compiler-chosen meta block sizes when not pinned.
    if block_size_m is None:
        block_size_m = BLOCK_SIZE_M
    if block_size_n is None:
        block_size_n = BLOCK_SIZE_N

    query_arranged = arrange_query_or_output(query)
    key_arranged = arrange_key_or_value(key)
    value_arranged = arrange_key_or_value(value)
    total_kv_len_arranged = arrange_total_kv_len(total_kv_len, query_arranged.shape)
    present_key_arranged = arrange_present_key_or_present_value(present_key)
    present_value_arranged = arrange_present_key_or_present_value(present_value)
    present_key_slot_arranged = arrange_present_key_or_present_value(present_key_slot)
    present_value_slot_arranged = arrange_present_key_or_present_value(
        present_value_slot
    )
    attn_mask_arranged = arrange_attn_mask(attn_mask)
    is_causal_arranged = is_causal
    scale_arranged = scale
    output_arranged = arrange_query_or_output(output)
    with_attn_mask_arranged = with_attn_mask
    causal_variant_arranged = causal_variant

    if with_kv_cache:
        return (
            query_arranged,
            key_arranged,
            value_arranged,
            total_kv_len_arranged,
            present_key_arranged,
            present_value_arranged,
            present_key_slot_arranged,
            present_value_slot_arranged,
            attn_mask_arranged,
            is_causal_arranged,
            scale_arranged,
            output_arranged,
            with_attn_mask_arranged,
            causal_variant_arranged,
        )

    return (
        query_arranged,
        key_arranged,
        value_arranged,
        total_kv_len_arranged,
        attn_mask_arranged,
        is_causal_arranged,
        scale_arranged,
        output_arranged,
        with_attn_mask_arranged,
        causal_variant_arranged,
    )
def application_with_kv_cache(
    query,
    key,
    value,
    total_kv_len,
    present_key,
    present_value,
    present_key_slot,
    present_value_slot,
    attn_mask,
    is_causal,
    scale,
    output,
    with_attn_mask,
    causal_variant,
):
    """KV-cache kernel body: write the present K/V into their cache slots,
    then run the regular flash-attention computation.

    NOTE(review): in the ninetoothed DSL these assignments are device stores
    into the slot tensors, not ordinary Python rebinding — hence the noqa
    markers; confirm against the DSL semantics.
    """
    present_key_slot = present_key  # noqa: F841
    present_value_slot = present_value  # noqa: F841
    application_without_kv_cache(
        query,
        key,
        value,
        total_kv_len,
        attn_mask,
        is_causal,
        scale,
        output,
        with_attn_mask,
        causal_variant,
    )
def application_without_kv_cache(
    query,
    key,
    value,
    total_kv_len,
    attn_mask,
    is_causal,
    scale,
    output,
    with_attn_mask,
    causal_variant,
):
    """Online-softmax flash-attention kernel body (no KV cache).

    Iterates over KV blocks, maintaining a running max, a rescaled
    accumulator `acc`, and a running softmax denominator `lse` per query row.
    Everything is done in exp2 space: 1.4426950408889634 == log2(e), folded
    into the query together with `scale` so exp(x) becomes exp2(x * log2(e)).
    """
    actual_kv_len = total_kv_len[0]
    for i in range(query.shape[0]):
        # Pre-scale the query once; cast back to the input dtype for the dots.
        query_i = (1.4426950408889634 * scale * query[i]).to(query[i].dtype)
        acc = ntl.zeros((query_i.shape[-2], query_i.shape[-1]), dtype=ntl.float32)
        # lse starts at 1, but the first iteration's alpha is exp2(-inf) == 0,
        # which wipes the initial value — presumably any finite init works;
        # TODO confirm intent (and behavior when every score is masked).
        lse = ntl.full((query_i.shape[-2],), 1, dtype=ntl.float32)
        max = ntl.full((query_i.shape[-2],), float("-inf"), dtype=ntl.float32)
        # Ceil-divide the *actual* KV length by the block size so padded
        # KV blocks past the real length are never visited.
        for j in range(-(-actual_kv_len // key.dtype.shape[0])):
            qk = ntl.dot(query_i, ntl.trans(key[j]))
            # Mask key positions beyond the real KV length inside the last block.
            key_pos = key[j].offsets(-2)
            qk = ntl.where(key_pos < actual_kv_len, qk, float("-inf"))
            if with_attn_mask:
                qk += attn_mask[j]
            if is_causal:
                query_pos = query[i].offsets(-2)
                if causal_variant == 2:  # CausalVariant.LOWER_RIGHT:
                    # Diagonal aligned to the end of the KV sequence.
                    mask = (
                        query_pos[:, None] + actual_kv_len - query.source.shape[-2]
                        >= key_pos[None, :]
                    )
                else:
                    # UPPER_LEFT: diagonal aligned to position 0.
                    mask = query_pos[:, None] >= key_pos[None, :]
                qk = ntl.where(mask, qk, float("-inf"))
            # Standard online-softmax update: rescale previous state by alpha.
            next_max = ntl.maximum(max, ntl.max(qk, 1))
            stable_qk = ntl.exp2(qk - next_max[:, None])
            alpha = ntl.exp2(max - next_max)
            acc = acc * alpha[:, None] + ntl.dot(stable_qk.to(value[i].dtype), value[j])
            max = next_max
            lse = lse * alpha + ntl.sum(stable_qk, 1)
        # Final normalization by the softmax denominator.
        acc /= lse[:, None]
        output[i] = acc  # noqa: F841
def premake(
    with_kv_cache,
    emb_dim=None,
    is_causal=None,
    with_attn_mask=None,
    causal_variant=None,
    dtype=None,
    block_size_m=None,
    block_size_n=None,
):
    """Assemble (arrangement, application, tensors) for one kernel variant.

    Each keyword is a compile-time (constexpr) parameter of the generated
    kernel; the build script enumerates combinations of these. Returns the
    triple consumed by the ninetoothed build machinery.
    """
    # Pin the constexpr arguments of arrangement() for this variant.
    arrangement_ = functools.partial(
        arrangement,
        with_kv_cache=with_kv_cache,
        block_size_m=block_size_m,
        block_size_n=block_size_n,
    )

    # 4-D symbolic tensors; the last (head) dimension is constexpr and
    # bounded at 128.
    query, key, value, attn_mask, output = (
        Tensor(
            4,
            dtype=dtype,
            shape_options=(None, None, None, {"constexpr": True, "upper_bound": 128}),
        )
        for _ in range(5)
    )
    # NOTE(review): declared int32 here while the infinicore test constructs
    # an int64 length tensor — confirm dtype consistency across the stack.
    total_kv_len = Tensor(1, dtype=ninetoothed.int32)
    present_key, present_value, present_key_slot, present_value_slot = (
        Tensor(4, dtype=dtype) for _ in range(4)
    )
    scale = Tensor(0, dtype=ninetoothed.float64)
    # 0-D constexpr scalars baked into the kernel at compile time.
    is_causal = Tensor(0, constexpr=True, value=is_causal)
    with_attn_mask = Tensor(0, constexpr=True, value=with_attn_mask)
    causal_variant = Tensor(0, constexpr=True, value=causal_variant)

    # When emb_dim is fixed, pin the last dimension of every 4-D tensor.
    if emb_dim is not None:
        for tensor in (query, key, value, attn_mask, output):
            tensor.shape = tensor.shape[:-1] + (emb_dim,)

    if with_kv_cache:
        application = application_with_kv_cache
    else:
        application = application_without_kv_cache

    # The full parameter list; arrangement() drops the cache tensors when
    # with_kv_cache is falsy.
    tensors = (
        query,
        key,
        value,
        total_kv_len,
        present_key,
        present_value,
        present_key_slot,
        present_value_slot,
        attn_mask,
        is_causal,
        scale,
        output,
        with_attn_mask,
        causal_variant,
    )

    return arrangement_, application, tensors
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/flash_attention.h"
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
#include "ninetoothed/descriptor.h"
#endif
#endif
// Dispatches descriptor creation to the device-specific implementation.
// Currently only the ninetoothed backend (NVIDIA, behind ENABLE_NINETOOTHED
// + ENABLE_NVIDIA_API) is wired up; all other devices are unsupported.
__C infiniStatus_t infiniopCreateFlashAttentionDescriptor(
    infiniopHandle_t handle,
    infiniopFlashAttentionDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t q_desc,
    infiniopTensorDescriptor_t k_desc,
    infiniopTensorDescriptor_t v_desc,
    infiniopTensorDescriptor_t total_kv_len,
    float scale,
    char is_causal) {

// Expands to one switch case per enabled backend namespace.
#define CREATE(CASE, NAMESPACE)                                                        \
    case CASE:                                                                         \
        return op::flash_attention::NAMESPACE::Descriptor::create(                     \
            handle,                                                                    \
            reinterpret_cast<op::flash_attention::NAMESPACE::Descriptor **>(desc_ptr), \
            out_desc,                                                                  \
            q_desc,                                                                    \
            k_desc,                                                                    \
            v_desc,                                                                    \
            total_kv_len,                                                              \
            scale,                                                                     \
            is_causal);

    switch (handle->device) {
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
        CREATE(INFINI_DEVICE_NVIDIA, ninetoothed);
#endif
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef CREATE
}
// Reports the workspace requirement of the backend behind `desc`.
__C infiniStatus_t infiniopGetFlashAttentionWorkspaceSize(
    infiniopFlashAttentionDescriptor_t desc,
    size_t *size) {

// Expands to one switch case per enabled backend namespace.
#define GET_SIZE(CASE, NAMESPACE)                                                    \
    case CASE:                                                                       \
        *size = reinterpret_cast<const op::flash_attention::NAMESPACE::Descriptor *>(desc) \
                    ->get_workspace_size();                                          \
        return INFINI_STATUS_SUCCESS;

    switch (desc->device_type) {
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
        GET_SIZE(INFINI_DEVICE_NVIDIA, ninetoothed);
#endif
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef GET_SIZE
}
// Executes flash attention via the backend stored in `desc`.
__C infiniStatus_t infiniopFlashAttention(
    infiniopFlashAttentionDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *out,
    const void *q,
    const void *k,
    const void *v,
    const void *total_kv_len,
    void *stream) {

// Expands to one switch case per enabled backend namespace.
#define CALCULATE(CASE, NAMESPACE)                                                         \
    case CASE:                                                                             \
        return reinterpret_cast<const op::flash_attention::NAMESPACE::Descriptor *>(desc)  \
            ->calculate(workspace, workspace_size, out, q, k, v, total_kv_len, stream);

    switch (desc->device_type) {
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
        CALCULATE(INFINI_DEVICE_NVIDIA, ninetoothed);
#endif
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef CALCULATE
}
// Destroys a descriptor by casting back to its concrete backend type.
__C infiniStatus_t infiniopDestroyFlashAttentionDescriptor(
    infiniopFlashAttentionDescriptor_t desc) {

// Expands to one switch case per enabled backend namespace.
#define DESTROY(CASE, NAMESPACE)                                                     \
    case CASE:                                                                       \
        delete reinterpret_cast<op::flash_attention::NAMESPACE::Descriptor *>(desc); \
        return INFINI_STATUS_SUCCESS;

    switch (desc->device_type) {
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
        DESTROY(INFINI_DEVICE_NVIDIA, ninetoothed);
#endif
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef DESTROY
}
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework import (
BaseOperatorTest,
TensorSpec,
TensorInitializer,
TestCase,
GenericTestRunner,
)
# Test cases format: (q_shape, k_shape, v_shape, attn_mask_or_None, dropout_p, is_causal)
# q/k/v typically have shape (..., seq_len, head_dim) or (batch, seq_len, num_heads, head_dim)
_TEST_CASES_DATA = [
    ((1, 1, 2, 16), (1, 1, 8, 16), (1, 1, 8, 16), None, 0.0, False),
    ((1, 2, 128, 16), (1, 2, 256, 16), (1, 2, 256, 16), None, 0.0, False),
    ((1, 1, 4, 32), (1, 1, 32, 32), (1, 1, 32, 32), None, 0.0, True),
    ((1, 8, 256, 16), (1, 8, 512, 16), (1, 8, 512, 16), None, 0.0, True),
    ((1, 8, 4, 16), (1, 8, 64, 16), (1, 8, 64, 16), None, 0.0, False),
    ((8, 28, 256, 128), (8, 28, 512, 128), (8, 28, 512, 128), None, 0.0, True),
]

# Per-dtype comparison tolerances against the torch reference.
_TOLERANCE_MAP = {
    infinicore.float16: {"atol": 1e-2, "rtol": 1e-2},
    infinicore.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
    infinicore.float32: {"atol": 1e-3, "rtol": 1e-3},
}

# Every case above is run once per dtype.
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
def parse_test_cases():
    """Expand _TEST_CASES_DATA x _TENSOR_DTYPES into TestCase objects.

    For each case a random effective KV length `total_len` in [1, kv_seq_len]
    is drawn; the total_kv_len tensor is then forced to that exact value via a
    RANDINT initializer with low == total_len and high == total_len + 1.
    `total_len` is also passed as the `cheat` positional input so the torch
    reference can slice K/V without reading the device tensor.

    NOTE(review): `random` is unseeded, so the drawn KV lengths differ per
    run — confirm this nondeterminism is intended for coverage.
    """
    import random

    cases = []
    for q_shape, k_shape, v_shape, attn_mask, dropout_p, is_causal in _TEST_CASES_DATA:
        for dtype in _TENSOR_DTYPES:
            tol = _TOLERANCE_MAP[dtype]
            q_spec = TensorSpec.from_tensor(q_shape, None, dtype)
            k_spec = TensorSpec.from_tensor(k_shape, None, dtype)
            v_spec = TensorSpec.from_tensor(v_shape, None, dtype)
            # One length per batch element; k_shape[2] is the KV seq len.
            len_shape = (q_shape[0],)
            total_len = random.randint(1, k_shape[2])
            total_kv_len_spec = TensorSpec.from_tensor(
                len_shape,
                None,
                infinicore.int64,
                init_mode=TensorInitializer.RANDINT,
                low=total_len,
                high=total_len + 1,
            )
            kwargs = {
                "attn_mask": attn_mask,
                "dropout_p": dropout_p,
                "is_causal": is_causal,
            }
            # remove None keys
            kwargs = {k: v for k, v in kwargs.items() if v is not None}
            cases.append(
                TestCase(
                    inputs=[q_spec, k_spec, v_spec, total_kv_len_spec, total_len],
                    kwargs=kwargs,
                    output_spec=None,
                    comparison_target=None,
                    tolerance=tol,
                    description="Flash Attention",
                )
            )
    return cases
def torch_flash_attn(q, k, v, total_kv_len, cheat, **kwargs):
    """Torch reference: slice K/V to the first `cheat` positions (the true KV
    length) and defer to scaled_dot_product_attention.

    `total_kv_len` is unused here — only the infinicore path reads it.
    """
    trimmed_k = k[:, :, :cheat, :]
    trimmed_v = v[:, :, :cheat, :]
    return torch.nn.functional.scaled_dot_product_attention(
        q, trimmed_k, trimmed_v, **kwargs
    )
def infini_flash_attn(q, k, v, total_kv_len, cheat, **kwargs):
    # `cheat` (the true KV length as a plain int) is only needed by the torch
    # reference; the infinicore path reads the length from total_kv_len instead.
    return infinicore.nn.functional.flash_attention(q, k, v, total_kv_len, **kwargs)
class OpTest(BaseOperatorTest):
    """ScaledDotProductAttention operator test with simplified implementation"""

    def __init__(self):
        super().__init__("ScaledDotProductAttention")

    def get_test_cases(self):
        # Cases are regenerated on each call (KV lengths are randomized there).
        return parse_test_cases()

    def torch_operator(self, *args, **kwargs):
        # Reference path: torch SDPA on K/V sliced to the true KV length.
        return torch_flash_attn(*args, **kwargs)

    def infinicore_operator(self, *args, **kwargs):
        # Path under test: infinicore flash attention.
        return infini_flash_attn(*args, **kwargs)
def main():
    """Main entry point: run every flash-attention case and exit with status."""
    GenericTestRunner(OpTest).run_and_exit()


if __name__ == "__main__":
    main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment