Commit f6496d44 authored by PanZezhong, committed by wooway777

issue/1033: support the flash_attn library via an ATen adaptor

parent 8d99a8f5
#pragma once
#include "../tensor.hpp"
#include <ATen/ATen.h>
namespace infinicore::adaptor {
inline at::ScalarType to_at_dtype(DataType dtype) {
switch (dtype) {
case DataType::F32:
return at::kFloat;
case DataType::F16:
return at::kHalf;
case DataType::BF16:
return at::kBFloat16;
case DataType::I32:
return at::kInt;
case DataType::I64:
return at::kLong;
default:
throw std::runtime_error("Unsupported dtype for ATen");
}
}
inline at::Device to_at_device(const Device &device) {
if (device.getType() == Device::Type::NVIDIA) {
return at::Device(at::kCUDA, device.getIndex());
} else if (device.getType() == Device::Type::CPU) {
return at::Device(at::kCPU);
} else {
throw std::runtime_error("Unsupported device type for ATen");
}
}
at::Tensor to_aten_tensor(const infinicore::Tensor &t);
} // namespace infinicore::adaptor
\ No newline at end of file
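For orientation (not part of the commit): a minimal usage sketch of the adaptor declared above. The `example` helper is hypothetical; only the adaptor API is real. The at::Tensor produced by to_aten_tensor is a non-owning view, so the infinicore tensor must stay alive while ATen code uses it.

#include "infinicore/adaptor/aten_adaptor.hpp"

// Hedged sketch: wraps an infinicore tensor as an ATen view without copying.
at::Tensor example(const infinicore::Tensor &t) {
    // dtype/device are mapped via to_at_dtype / to_at_device; ownership stays with `t`.
    return infinicore::adaptor::to_aten_tensor(t);
}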
#pragma once
#include "aten_adaptor.hpp"
namespace flash {
std::vector<at::Tensor>
mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
std::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout,
const float softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool return_softmax,
std::optional<at::Generator> gen_);
std::vector<at::Tensor>
mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
std::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
std::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
std::optional<const at::Tensor> &leftpad_k_, // batch_size
std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
int max_seqlen_q,
const int max_seqlen_k,
const float p_dropout,
const float softmax_scale,
const bool zero_tensors,
bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool return_softmax,
std::optional<at::Generator> gen_);
std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads x multiple_of(head_size_og, 8)
const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &softmax_lse, // b x h x seqlen_q
std::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
std::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
std::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout, // probability to drop
const float softmax_scale,
const bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool deterministic,
std::optional<at::Generator> gen_,
std::optional<at::Tensor> &rng_state);
std::vector<at::Tensor>
mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads x head_size
const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &out, // total_q x num_heads x head_size
const at::Tensor &softmax_lse, // h x total_q, softmax logsumexp
std::optional<at::Tensor> &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
std::optional<at::Tensor> &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
std::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
const int max_seqlen_q,
const int max_seqlen_k, // max sequence length to choose the kernel
const float p_dropout, // probability to drop
const float softmax_scale,
const bool zero_tensors,
const bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool deterministic,
std::optional<at::Generator> gen_,
std::optional<at::Tensor> &rng_state);
std::vector<at::Tensor>
mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
std::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
std::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
std::optional<const at::Tensor> &seqlens_k_, // batch_size
std::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
std::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
std::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
std::optional<const at::Tensor> &leftpad_k_, // batch_size
std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
std::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
const float softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
int num_splits);
} // namespace flash
\ No newline at end of file
#pragma once
#include "../device.hpp"
#include "common/op.hpp"
#include <optional>
namespace infinicore::op {
INFINICORE_GRAPH_OP_CLASS(
MultiheadAttentionVarlen,
Tensor,
const Tensor &,
const Tensor &,
const Tensor &,
const Tensor &,
const Tensor &,
const Tensor &,
int,
int,
std::optional<Tensor>,
float);
Tensor mha_varlen(const Tensor &q,
const Tensor &k,
const Tensor &v,
const Tensor &cum_seqlens_q,
const Tensor &cum_seqlens_k,
const Tensor &block_table,
int max_seqlen_q,
int max_seqlen_k,
std::optional<Tensor> alibi_slopes,
float scale);
void mha_varlen_(Tensor out,
const Tensor &q,
const Tensor &k,
const Tensor &v,
const Tensor &cum_seqlens_q,
const Tensor &cum_seqlens_k,
const Tensor &block_table,
int max_seqlen_q,
int max_seqlen_k,
std::optional<Tensor> alibi_slopes,
float scale);
} // namespace infinicore::op
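A hedged sketch of the two entry points declared above; the `run_both` helper is hypothetical and tensor construction is assumed to happen elsewhere (e.g. via Tensor::empty, as in the implementation later in this commit). The plain function allocates its output, the trailing-underscore variant writes into a caller-provided tensor.

#include "infinicore/ops/mha_varlen.hpp"

void run_both(const infinicore::Tensor &q, const infinicore::Tensor &k,
              const infinicore::Tensor &v, const infinicore::Tensor &cu_q,
              const infinicore::Tensor &cu_k, const infinicore::Tensor &table,
              int max_q, int max_k, float scale) {
    // Out-of-place: the output takes q's shape, dtype and device.
    auto out = infinicore::op::mha_varlen(q, k, v, cu_q, cu_k, table,
                                          max_q, max_k, std::nullopt, scale);
    // In-place: reuse an existing output tensor.
    infinicore::op::mha_varlen_(out, q, k, v, cu_q, cu_k, table,
                                max_q, max_k, std::nullopt, scale);
}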
@@ -52,6 +52,7 @@ from infinicore.ops.add_rms_norm import add_rms_norm
from infinicore.ops.attention import attention
from infinicore.ops.kv_caching import kv_caching
from infinicore.ops.matmul import matmul
from infinicore.ops.mha_varlen import mha_varlen
from infinicore.ops.mul import mul
from infinicore.ops.narrow import narrow
from infinicore.ops.paged_attention import paged_attention
@@ -134,6 +135,7 @@ __all__ = [
"from_list",
"from_numpy",
"from_torch",
"mha_varlen",
"paged_caching",
"paged_attention",
"paged_attention_prefill",
......
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor
def mha_varlen(
q: Tensor,
k: Tensor,
v: Tensor,
cum_seqlens_q: Tensor,
cum_seqlens_k: Tensor,
block_table: Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
alibi_slopes: Tensor | None = None,
scale: float = 1.0,
*,
out: Tensor | None = None,
):
    """Variable-length multi-head attention (FlashAttention varlen forward).

    cum_seqlens_q / cum_seqlens_k are (batch_size + 1) prefix sums of the
    per-sequence lengths. If `out` is given, the result is written into it.
    """
if out is None:
return Tensor(
_infinicore.mha_varlen(
q._underlying,
k._underlying,
v._underlying,
cum_seqlens_q._underlying,
cum_seqlens_k._underlying,
block_table._underlying,
max_seqlen_q,
max_seqlen_k,
alibi_slopes._underlying if alibi_slopes is not None else None,
scale,
)
)
_infinicore.mha_varlen_(
out._underlying,
q._underlying,
k._underlying,
v._underlying,
cum_seqlens_q._underlying,
cum_seqlens_k._underlying,
block_table._underlying,
max_seqlen_q,
max_seqlen_k,
alibi_slopes._underlying if alibi_slopes is not None else None,
scale,
)
return out
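A hedged end-to-end sketch of the Python entry point above. `from_torch` is exported from the package `__init__` shown earlier, but its exact signature is an assumption here; shapes follow the paged varlen convention documented in the flash adaptor header.

# Hypothetical example, not part of the commit.
import torch
import infinicore

head_size = 128
# Two packed sequences of lengths 3 and 5 -> total_q = 8 query tokens.
q = infinicore.from_torch(torch.randn(8, 4, head_size, dtype=torch.float16, device="cuda"))
# Paged KV cache: num_blocks x page_block_size x num_kv_heads x head_size.
k = infinicore.from_torch(torch.randn(16, 256, 4, head_size, dtype=torch.float16, device="cuda"))
v = infinicore.from_torch(torch.randn(16, 256, 4, head_size, dtype=torch.float16, device="cuda"))
cu_q = infinicore.from_torch(torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda"))
cu_k = infinicore.from_torch(torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda"))
block_table = infinicore.from_torch(torch.tensor([[0], [1]], dtype=torch.int32, device="cuda"))

out = infinicore.mha_varlen(
    q, k, v, cu_q, cu_k, block_table,
    max_seqlen_q=5, max_seqlen_k=5,
    alibi_slopes=None, scale=head_size**-0.5,
)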
#include "infinicore/adaptor/aten_adaptor.hpp"
namespace infinicore::adaptor {
at::Tensor to_aten_tensor(const infinicore::Tensor &t) {
// Expose the infinicore tensor's memory as a non-owning at::Tensor view.
void *data_ptr = (void *)(t->data());
auto sizes = std::vector<int64_t>(
t->shape().begin(),
t->shape().end());
auto strides = t->strides();
auto dtype = to_at_dtype(t->dtype());
auto device = to_at_device(t->device());
// No-op deleter: ownership stays with the infinicore tensor, which must
// outlive the returned view.
auto deleter_ = [](void * /*unused*/) mutable {
};
at::TensorOptions options = at::TensorOptions()
.dtype(dtype)
.device(device)
.requires_grad(false);
return at::from_blob(
data_ptr,
sizes,
strides,
deleter_,
options);
}
} // namespace infinicore::adaptor
\ No newline at end of file
#include "infinicore/ops/mha_varlen.hpp"
#include "../../utils.hpp"
namespace infinicore::op {
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(MultiheadAttentionVarlen);
MultiheadAttentionVarlen::MultiheadAttentionVarlen(Tensor out,
const Tensor &q,
const Tensor &k,
const Tensor &v,
const Tensor &cum_seqlens_q,
const Tensor &cum_seqlens_kv,
const Tensor &block_table,
int max_seqlen_q,
int max_seqlen_k,
std::optional<Tensor> alibi_slopes,
float scale) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k, v, cum_seqlens_q, cum_seqlens_kv, block_table);
INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(),
out, q, k, v, cum_seqlens_q, cum_seqlens_kv, block_table, max_seqlen_q, max_seqlen_k, alibi_slopes, scale);
}
void MultiheadAttentionVarlen::execute(Tensor out,
const Tensor &q,
const Tensor &k,
const Tensor &v,
const Tensor &cum_seqlens_q,
const Tensor &cum_seqlens_kv,
const Tensor &block_table,
int max_seqlen_q,
int max_seqlen_k,
std::optional<Tensor> alibi_slopes,
float scale) {
INFINICORE_GRAPH_OP_RECORD_OR_RUN(
MultiheadAttentionVarlen,
out, q, k, v, cum_seqlens_q, cum_seqlens_kv, block_table, max_seqlen_q, max_seqlen_k, alibi_slopes, scale);
}
Tensor mha_varlen(
const Tensor &q,
const Tensor &k,
const Tensor &v,
const Tensor &cum_seqlens_q,
const Tensor &cum_seqlens_kv,
const Tensor &block_table,
int max_seqlen_q,
int max_seqlen_k,
std::optional<Tensor> alibi_slopes,
float scale) {
auto out = Tensor::empty(q->shape(), q->dtype(), q->device());
mha_varlen_(out, q, k, v, cum_seqlens_q, cum_seqlens_kv, block_table, max_seqlen_q, max_seqlen_k, alibi_slopes, scale);
return out;
}
void mha_varlen_(Tensor out,
const Tensor &q,
const Tensor &k,
const Tensor &v,
const Tensor &cum_seqlens_q,
const Tensor &cum_seqlens_kv,
const Tensor &block_table,
int max_seqlen_q,
int max_seqlen_k,
std::optional<Tensor> alibi_slopes,
float scale) {
MultiheadAttentionVarlen::execute(out, q, k, v, cum_seqlens_q, cum_seqlens_kv, block_table, max_seqlen_q, max_seqlen_k, alibi_slopes, scale);
}
} // namespace infinicore::op
#include "infinicore/ops/mha_varlen.hpp"
#include "infinicore/adaptor/flash_attention_adaptor.hpp"
namespace infinicore::op::mha_varlen_impl::flashattn {
struct PlannedMeta {
graph::GraphTensor out, q, k, v, cum_seqlens_q, cum_seqlens_k, block_table;
int max_seqlen_q, max_seqlen_k;
std::optional<graph::GraphTensor> alibi_slopes;
float scale;
};
void *plan(Tensor out,
const Tensor &q,
const Tensor &k,
const Tensor &v,
const Tensor &cum_seqlens_q,
const Tensor &cum_seqlens_k,
const Tensor &block_table,
int max_seqlen_q,
int max_seqlen_k,
std::optional<Tensor> alibi_slopes,
float scale) {
return new PlannedMeta{
graph::GraphTensor(out),
graph::GraphTensor(q),
graph::GraphTensor(k),
graph::GraphTensor(v),
graph::GraphTensor(cum_seqlens_q),
graph::GraphTensor(cum_seqlens_k),
graph::GraphTensor(block_table),
max_seqlen_q,
max_seqlen_k,
alibi_slopes ? std::optional<graph::GraphTensor>(graph::GraphTensor(*alibi_slopes)) : std::nullopt,
scale};
}
void run(void *planned_meta) {
auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);
auto q = infinicore::adaptor::to_aten_tensor(p->q);
auto k = infinicore::adaptor::to_aten_tensor(p->k);
auto v = infinicore::adaptor::to_aten_tensor(p->v);
auto out = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->out));
auto cu_seqlens_q = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_q);
auto cu_seqlens_kv = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_k);
std::optional<at::Tensor> seqused_k = std::nullopt;
std::optional<const at::Tensor> leftpad_k = std::nullopt;
auto block_table = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->block_table));
auto max_seqlen_q = p->max_seqlen_q;
auto max_seqlen_k = p->max_seqlen_k;
auto alibi_slopes = p->alibi_slopes ? std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes)) : std::nullopt;
auto scale = p->scale;
// Forward into the FlashAttention varlen kernel; `out` is written in place and
// the returned softmax_lse is not needed here.
flash::mha_varlen_fwd(
q,
k,
v,
out,
cu_seqlens_q,
cu_seqlens_kv,
seqused_k,
leftpad_k,
block_table,
alibi_slopes,
max_seqlen_q,
max_seqlen_k,
0.0,           // p_dropout
scale,         // softmax_scale
false,         // zero_tensors
true,          // is_causal
-1,            // window_size_left
-1,            // window_size_right
0.0,           // softcap
false,         // return_softmax
std::nullopt); // gen_
}
void cleanup(void **planned_meta_ptr) {
delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
*planned_meta_ptr = nullptr;
}
INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(MultiheadAttentionVarlen, &plan, &run, &cleanup);
} // namespace infinicore::op::mha_varlen_impl::flashattn
@@ -13,6 +13,7 @@
#include "ops/linear_w8a8i8.hpp"
#include "ops/matmul.hpp"
#include "ops/mul.hpp"
#include "ops/mha_varlen.hpp"
#include "ops/paged_attention.hpp"
#include "ops/paged_attention_prefill.hpp"
#include "ops/paged_caching.hpp"
@@ -38,6 +39,7 @@ inline void bind(py::module &m) {
bind_linear(m);
bind_matmul(m);
bind_mul(m);
bind_mha_varlen(m);
bind_paged_attention(m);
bind_paged_attention_prefill(m);
bind_paged_caching(m);
......
#pragma once
#include <pybind11/pybind11.h>
#include "infinicore/ops/mha_varlen.hpp"
namespace py = pybind11;
namespace infinicore::ops {
Tensor py_mha_varlen(Tensor q,
Tensor k,
Tensor v,
Tensor cum_seqlens_q,
Tensor cum_seqlens_k,
Tensor block_table,
int max_seqlen_q,
int max_seqlen_k,
pybind11::object alibi_slopes,
float scale) {
std::optional<Tensor> alibi_slopes_tensor = std::nullopt;
if (!alibi_slopes.is_none()) {
alibi_slopes_tensor = alibi_slopes.cast<Tensor>();
}
return op::mha_varlen(
q,
k,
v,
cum_seqlens_q,
cum_seqlens_k,
block_table,
max_seqlen_q,
max_seqlen_k,
alibi_slopes_tensor,
scale);
}
void py_mha_varlen_(Tensor out,
Tensor q,
Tensor k,
Tensor v,
Tensor cum_seqlens_q,
Tensor cum_seqlens_k,
Tensor block_table,
int max_seqlen_q,
int max_seqlen_k,
pybind11::object alibi_slopes,
float scale) {
std::optional<Tensor> alibi_slopes_tensor = std::nullopt;
if (!alibi_slopes.is_none()) {
alibi_slopes_tensor = alibi_slopes.cast<Tensor>();
}
op::mha_varlen_(
out,
q,
k,
v,
cum_seqlens_q,
cum_seqlens_k,
block_table,
max_seqlen_q,
max_seqlen_k,
alibi_slopes_tensor,
scale);
}
inline void bind_mha_varlen(py::module &m) {
m.def(
"mha_varlen",
&ops::py_mha_varlen,
py::arg("q"),
py::arg("k"),
py::arg("v"),
py::arg("cum_seqlens_q"),
py::arg("cum_seqlens_k"),
py::arg("block_table"),
py::arg("max_seqlen_q"),
py::arg("max_seqlen_k"),
py::arg("alibi_slopes"),
py::arg("scale"),
R"doc(Variable-length multi-head attention.)doc");
m.def(
"mha_varlen_",
&ops::py_mha_varlen_,
py::arg("out"),
py::arg("q"),
py::arg("k"),
py::arg("v"),
py::arg("cum_seqlens_q"),
py::arg("cum_seqlens_k"),
py::arg("block_table"),
py::arg("max_seqlen_q"),
py::arg("max_seqlen_k"),
py::arg("alibi_slopes"),
py::arg("scale"),
R"doc(In-place variable-length multi-head attention.)doc");
}
} // namespace infinicore::ops
import os
import sys
import infinicore
import torch
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from framework import (
BaseOperatorTest,
GenericTestRunner,
TensorInitializer,
TensorSpec,
TestCase,
)
# Test Cases: (num_seqs, num_heads, num_kv_heads, head_size, block_size, max_step_len, num_rounds)
_TEST_CASES_DATA = [
(1, 1, 1, 128, 256, 16, 1),
(1, 4, 4, 128, 256, 16, 4),
(2, 8, 8, 128, 256, 16, 2),
]
_TOLERANCE_MAP = {
infinicore.float16: {"atol": 1e-2, "rtol": 1e-2},
infinicore.bfloat16: {"atol": 2e-2, "rtol": 2e-2},
}
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16]
class SimpleCacheManager:
def __init__(self, num_blocks, block_size):
self.num_blocks = num_blocks
self.block_size = block_size
self.free_blocks = list(range(num_blocks))
self.request_to_blocks = {}
self.request_to_len = {}
def allocate_slots(self, request_id, num_new_tokens):
if request_id not in self.request_to_len:
self.request_to_len[request_id] = 0
self.request_to_blocks[request_id] = []
start_pos = self.request_to_len[request_id]
new_total_len = start_pos + num_new_tokens
needed_blocks = (new_total_len + self.block_size - 1) // self.block_size
added_blocks = needed_blocks - len(self.request_to_blocks[request_id])
for _ in range(added_blocks):
self.request_to_blocks[request_id].append(self.free_blocks.pop(0))
self.request_to_len[request_id] = new_total_len
return self.request_to_blocks[request_id], new_total_len
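# Worked example of the allocator above (illustrative values, kept as comments so
# the test module stays import-safe): with block_size=4, a request growing from
# 3 to 7 tokens needs ceil(7/4) = 2 blocks, so the second call hands out exactly
# one additional block.
#
#   mgr = SimpleCacheManager(num_blocks=8, block_size=4)
#   mgr.allocate_slots(request_id=0, num_new_tokens=3)  # -> ([0], 3)
#   mgr.allocate_slots(request_id=0, num_new_tokens=4)  # -> ([0, 1], 7)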
def parse_test_cases():
test_cases = []
for (
num_seqs,
num_heads,
num_kv_heads,
head_size,
block_size,
max_step_len,
num_rounds,
) in _TEST_CASES_DATA:
scale = head_size**-0.5
num_blocks = 512
manager = SimpleCacheManager(num_blocks, block_size)
kv_lens = torch.zeros(num_seqs, dtype=torch.int32)
persistent_k = torch.zeros((num_blocks, num_kv_heads, block_size, head_size))
persistent_v = torch.zeros((num_blocks, num_kv_heads, block_size, head_size))
for r in range(num_rounds):
q_lens = torch.randint(1, max_step_len + 1, (num_seqs,), dtype=torch.int32)
kv_lens = kv_lens + q_lens
total_q_tokens = q_lens.sum().item()
cum_seqlens_q = torch.zeros(num_seqs + 1, dtype=torch.int32)
cum_seqlens_q[1:] = torch.cumsum(q_lens, dim=0)
cum_seqlens_k = torch.zeros(num_seqs + 1, dtype=torch.int32)
cum_seqlens_k[1:] = torch.cumsum(kv_lens, dim=0)
query_base = torch.randn((total_q_tokens, num_heads, head_size))
round_block_tables_list = []
for i in range(num_seqs):
p_blocks, total_len = manager.allocate_slots(i, q_lens[i].item())
round_block_tables_list.append(p_blocks)
h_len = kv_lens[i].item() - q_lens[i].item()
for t in range(q_lens[i].item()):
logical_pos = h_len + t
b_id = p_blocks[logical_pos // block_size]
off = logical_pos % block_size
persistent_k[b_id, :, off, :] = torch.randn(num_kv_heads, head_size)
persistent_v[b_id, :, off, :] = torch.randn(num_kv_heads, head_size)
max_blks = max(len(t) for t in round_block_tables_list)
padded_tables = torch.tensor(
[t + [0] * (max_blks - len(t)) for t in round_block_tables_list]
)
for dtype in _TENSOR_DTYPES:
tolerance = _TOLERANCE_MAP.get(dtype)
test_cases.append(
TestCase(
inputs=[
TensorSpec.from_tensor(
query_base.shape,
init_mode=TensorInitializer.MANUAL,
set_tensor=query_base.clone(),
dtype=dtype,
),
TensorSpec.from_tensor(
persistent_k.shape,
init_mode=TensorInitializer.MANUAL,
set_tensor=persistent_k.clone(),
dtype=dtype,
),
TensorSpec.from_tensor(
persistent_v.shape,
init_mode=TensorInitializer.MANUAL,
set_tensor=persistent_v.clone(),
dtype=dtype,
),
TensorSpec.from_tensor(
padded_tables.shape,
init_mode=TensorInitializer.MANUAL,
set_tensor=padded_tables.clone(),
dtype=infinicore.int32,
),
# TensorSpec.from_tensor(
# kv_lens.shape,
# init_mode=TensorInitializer.MANUAL,
# set_tensor=kv_lens.clone(),
# dtype=infinicore.int64,
# ),
TensorSpec.from_tensor(
cum_seqlens_q.shape,
init_mode=TensorInitializer.MANUAL,
set_tensor=cum_seqlens_q.clone(),
dtype=infinicore.int32,
),
TensorSpec.from_tensor(
cum_seqlens_k.shape,
init_mode=TensorInitializer.MANUAL,
set_tensor=cum_seqlens_k.clone(),
dtype=infinicore.int32,
),
],
kwargs={
"scale": scale,
"max_seqlen_q": max_step_len + num_rounds,
"max_seqlen_k": max_step_len + num_rounds,
},
tolerance=tolerance,
description=f"MHA_Varlen_Round_{r}_{str(dtype).split('.')[-1]}",
)
)
return test_cases
def ref_paged_attention_multi_turn(
query, k_cache, v_cache, block_tables, cum_seqlens_q, cum_seqlens_k, scale
):
output = torch.zeros_like(query)
num_seqs = len(cum_seqlens_q) - 1
block_size = k_cache.shape[2]
for i in range(num_seqs):
q_start, q_end = cum_seqlens_q[i].item(), cum_seqlens_q[i + 1].item()
cur_q = query[q_start:q_end]
q_len = q_end - q_start
h_len = (cum_seqlens_k[i + 1].item() - cum_seqlens_k[i].item()) - q_len
total_len = h_len + q_len
table = block_tables[i]
keys, values = [], []
for j in range(total_len):
b_id = table[j // block_size].item()
off = j % block_size
keys.append(k_cache[b_id, :, off, :])
values.append(v_cache[b_id, :, off, :])
K = torch.stack(keys, dim=0)
V = torch.stack(values, dim=0)
scores = torch.einsum("qhd,khd->hqk", cur_q.float(), K.float()) * scale
# Causal mask: query token t may attend to all history plus itself,
# i.e. key positions [0, h_len + t].
mask = torch.full((q_len, total_len), float("-inf"), device=query.device)
for t in range(q_len):
mask[t, : h_len + t + 1] = 0.0
attn = torch.softmax(scores + mask.unsqueeze(0), dim=-1).to(query.dtype)
output[q_start:q_end] = torch.einsum("hqk,khd->qhd", attn, V)
return output
class OpTest(BaseOperatorTest):
def __init__(self):
super().__init__("PagedAttentionPrefill")
def get_test_cases(self):
return parse_test_cases()
def torch_operator(
self,
query,
k_cache,
v_cache,
block_tables,
cum_seqlens_q,
cum_seqlens_k,
scale=1.0,
max_seqlen_q=0,
max_seqlen_k=0,
):
return ref_paged_attention_multi_turn(
query, k_cache, v_cache, block_tables, cum_seqlens_q, cum_seqlens_k, scale
)
def infinicore_operator(
self,
query,
k_cache,
v_cache,
block_tables,
cum_seqlens_q,
cum_seqlens_k,
scale=1.0,
max_seqlen_q=0,
max_seqlen_k=0,
):
out = infinicore.mha_varlen(
query,
k_cache.permute([0, 2, 1, 3]),
v_cache.permute([0, 2, 1, 3]),
cum_seqlens_q,
cum_seqlens_k,
block_tables,
max_seqlen_q,
max_seqlen_k,
alibi_slopes=None,
scale=scale,
)
infinicore.sync_stream()
return out
def main():
"""Main entry point"""
runner = GenericTestRunner(OpTest)
runner.run_and_exit()
if __name__ == "__main__":
main()
@@ -439,8 +439,35 @@ target("infinicore_cpp_api")
add_linkdirs(INFINI_ROOT.."/lib")
add_links("infiniop", "infinirt", "infiniccl")
-- ==============================
-- LibTorch integration
-- ==============================
local LIBTORCH_ROOT = ("/home/panzezhong/.conda/envs/myenv/lib/python3.13/site-packages/torch")
-- headers
add_includedirs(
path.join(LIBTORCH_ROOT, "include"),
path.join(LIBTORCH_ROOT, "include/torch/csrc/api/include"),
{ public = true }
)
-- libraries
add_linkdirs(path.join(LIBTORCH_ROOT, "lib"))
-- core ATen / Torch libs
add_links(
"torch",
"c10",
"torch_cuda",
"c10_cuda"
)
-- Flash attention lib
add_linkdirs("/home/panzezhong/Projects/InfiniCore/third_party/flash-attention/csrc/build")
add_links("flash_attn")
-- Add InfiniCore C++ source files (needed for RoPE and other nn modules)
add_files("src/infinicore/*.cc")
add_files("src/infinicore/adaptor/*.cc")
add_files("src/infinicore/context/*.cc")
add_files("src/infinicore/context/*/*.cc")
add_files("src/infinicore/tensor/*.cc")
......