Unverified Commit 8d09630a authored by gongchensu, committed by GitHub

Merge branch 'demo131' into Issue/862

parents ab52dead 012df56c
#include "../../../../utils.h"
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include "../../../tensor.h"
#include "../cuda/embedding_kernel.cuh"
#include "embedding_nvidia.cuh"
#include <cuda_runtime.h>
template <typename T, typename IndexType>
INFINIOP_CUDA_KERNEL embeddingKernel(
T *__restrict__ output,
const IndexType *__restrict__ indices,
const T *__restrict__ weight,
size_t num_indices,
size_t embedding_dim,
size_t vocab_size) {
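// One thread handles one index: it copies a single embedding row of length embedding_dim from weight into output.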
// Calculate global thread index
size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
if (idx < num_indices) {
// Get the index value
IndexType index_val = __ldg(&indices[idx]);
// Bounds check - handle negative indices gracefully
if (index_val >= 0 && static_cast<size_t>(index_val) < vocab_size) {
// Copy embedding vector from weight to output
const T *src = weight + static_cast<size_t>(index_val) * embedding_dim;
T *dst = output + idx * embedding_dim;
// Choose optimal copy strategy based on type and alignment
if constexpr (std::is_same_v<T, float>) {
// Check alignment for float4 (16 bytes)
bool aligned_16 = is_aligned(src, 16) && is_aligned(dst, 16);
if (aligned_16 && embedding_dim >= 4 && embedding_dim % 4 == 0) {
copyVectorizedFloat4<IndexType>(dst, src, embedding_dim);
} else if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
// Try float2 if not aligned to 16 bytes
copyVectorizedFloat2<IndexType>(dst, src, embedding_dim);
} else {
copyScalar<T, IndexType>(dst, src, embedding_dim);
}
} else if constexpr (std::is_same_v<T, half>) {
// Use half2 for vectorized access
if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
copyVectorizedHalf2<IndexType>(dst, src, embedding_dim);
} else {
copyScalar<T, IndexType>(dst, src, embedding_dim);
}
} else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
// Use bfloat162 for vectorized access
if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
copyVectorizedBFloat162<IndexType>(dst, src, embedding_dim);
} else {
copyScalar<T, IndexType>(dst, src, embedding_dim);
}
} else {
// Fallback to scalar copy with __ldg
copyScalar<T, IndexType>(dst, src, embedding_dim);
}
}
}
}
namespace op::embedding::nvidia {
struct Descriptor::Opaque {
std::shared_ptr<device::nvidia::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t output_desc,
infiniopTensorDescriptor_t input_desc,
infiniopTensorDescriptor_t weight_desc) {
auto input_shape = input_desc->shape();
auto weight_shape = weight_desc->shape();
// Validate shapes
CHECK_OR_RETURN(weight_shape.size() == 2, INFINI_STATUS_BAD_TENSOR_SHAPE);
CHECK_OR_RETURN(output_desc->shape().size() == input_shape.size() + 1, INFINI_STATUS_BAD_TENSOR_SHAPE);
// Check output shape matches input shape + embedding_dim
auto output_shape = output_desc->shape();
size_t embedding_dim = weight_shape[1];
CHECK_OR_RETURN(output_shape.back() == embedding_dim, INFINI_STATUS_BAD_TENSOR_SHAPE);
for (size_t i = 0; i < input_shape.size(); ++i) {
CHECK_OR_RETURN(output_shape[i] == input_shape[i], INFINI_STATUS_BAD_TENSOR_SHAPE);
}
// Validate dtypes
auto input_dtype = input_desc->dtype();
auto weight_dtype = weight_desc->dtype();
CHECK_OR_RETURN(input_dtype == INFINI_DTYPE_I32 || input_dtype == INFINI_DTYPE_I64,
INFINI_STATUS_BAD_TENSOR_DTYPE);
CHECK_OR_RETURN(weight_dtype == INFINI_DTYPE_F32 || weight_dtype == INFINI_DTYPE_F16 || weight_dtype == INFINI_DTYPE_BF16, INFINI_STATUS_BAD_TENSOR_DTYPE);
CHECK_OR_RETURN(output_desc->dtype() == weight_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE);
// Calculate number of indices (supporting batch dimension)
size_t num_indices = 1;
for (auto dim : input_shape) {
num_indices *= dim;
}
size_t vocab_size = weight_shape[0];
*desc_ptr = new Descriptor(
num_indices,
embedding_dim,
vocab_size,
input_dtype,
weight_dtype,
new Opaque{reinterpret_cast<device::nvidia::Handle *>(handle)->internal()},
handle->device,
handle->device_id);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *output,
const void *input,
const void *weight,
void *stream) const {
if (_num_indices == 0) {
return INFINI_STATUS_SUCCESS;
}
auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);
// Dynamic block size optimization based on embedding_dim
// Smaller embedding_dim benefits from larger block size (better occupancy)
// Larger embedding_dim benefits from smaller block size (more registers per thread)
size_t block_size = 256; // Default
if (_embedding_dim <= 64) {
block_size = 512; // Small embedding_dim: use larger block for better occupancy
} else if (_embedding_dim >= 1024) {
block_size = 128; // Large embedding_dim: use smaller block to reduce register pressure
}
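// Ceiling division: launch one thread per index so the grid covers all _num_indices rows.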
size_t grid_size = (_num_indices + block_size - 1) / block_size;
// Launch kernel based on dtypes
if (_input_dtype == INFINI_DTYPE_I32) {
const int32_t *indices_ptr = reinterpret_cast<const int32_t *>(input);
if (_weight_dtype == INFINI_DTYPE_F32) {
embeddingKernel<float, int32_t><<<grid_size, block_size, 0, cuda_stream>>>(
reinterpret_cast<float *>(output),
indices_ptr,
reinterpret_cast<const float *>(weight),
_num_indices,
_embedding_dim,
_vocab_size);
} else if (_weight_dtype == INFINI_DTYPE_F16) {
embeddingKernel<half, int32_t><<<grid_size, block_size, 0, cuda_stream>>>(
reinterpret_cast<half *>(output),
indices_ptr,
reinterpret_cast<const half *>(weight),
_num_indices,
_embedding_dim,
_vocab_size);
} else if (_weight_dtype == INFINI_DTYPE_BF16) {
embeddingKernel<cuda_bfloat16, int32_t><<<grid_size, block_size, 0, cuda_stream>>>(
reinterpret_cast<cuda_bfloat16 *>(output),
indices_ptr,
reinterpret_cast<const cuda_bfloat16 *>(weight),
_num_indices,
_embedding_dim,
_vocab_size);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else if (_input_dtype == INFINI_DTYPE_I64) {
const int64_t *indices_ptr = reinterpret_cast<const int64_t *>(input);
if (_weight_dtype == INFINI_DTYPE_F32) {
embeddingKernel<float, int64_t><<<grid_size, block_size, 0, cuda_stream>>>(
reinterpret_cast<float *>(output),
indices_ptr,
reinterpret_cast<const float *>(weight),
_num_indices,
_embedding_dim,
_vocab_size);
} else if (_weight_dtype == INFINI_DTYPE_F16) {
embeddingKernel<half, int64_t><<<grid_size, block_size, 0, cuda_stream>>>(
reinterpret_cast<half *>(output),
indices_ptr,
reinterpret_cast<const half *>(weight),
_num_indices,
_embedding_dim,
_vocab_size);
} else if (_weight_dtype == INFINI_DTYPE_BF16) {
embeddingKernel<cuda_bfloat16, int64_t><<<grid_size, block_size, 0, cuda_stream>>>(
reinterpret_cast<cuda_bfloat16 *>(output),
indices_ptr,
reinterpret_cast<const cuda_bfloat16 *>(weight),
_num_indices,
_embedding_dim,
_vocab_size);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
// Check for kernel launch errors
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
return INFINI_STATUS_INTERNAL_ERROR;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::embedding::nvidia
#ifndef __EMBEDDING_CUDA_H__
#define __EMBEDDING_CUDA_H__
#include "../embedding.h"
DESCRIPTOR(nvidia)
#endif // __EMBEDDING_CUDA_H__
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/embedding.h"
#ifdef ENABLE_CPU_API
#include "cpu/embedding_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) || defined(ENABLE_ALI_API)
#include "nvidia/embedding_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
#include "metax/embedding_metax.cuh"
#endif
#ifdef ENABLE_MOORE_API
#include "moore/embedding_moore.h"
#endif
__C infiniStatus_t infiniopCreateEmbeddingDescriptor(
infiniopHandle_t handle,
infiniopEmbeddingDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t output_desc,
infiniopTensorDescriptor_t input_desc,
infiniopTensorDescriptor_t weight_desc) {
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::embedding::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::embedding::NAMESPACE::Descriptor **>(desc_ptr), \
output_desc, \
input_desc, \
weight_desc)
switch (handle->device) {
#ifdef ENABLE_CPU_API
CREATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
CREATE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, moore);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef CREATE
}
__C infiniStatus_t infiniopEmbedding(
infiniopEmbeddingDescriptor_t desc,
void *output,
const void *input,
const void *weight,
void *stream) {
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<const op::embedding::NAMESPACE::Descriptor *>(desc) \
->calculate(output, input, weight, stream)
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
CALCULATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
CALCULATE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_MOORE_API
CALCULATE(INFINI_DEVICE_MOORE, moore);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef CALCULATE
}
__C infiniStatus_t infiniopDestroyEmbeddingDescriptor(infiniopEmbeddingDescriptor_t desc) {
#define DESTROY(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<const op::embedding::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
DESTROY(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
DESTROY(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
DESTROY(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_ALI_API
DESTROY(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API
DESTROY(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
DESTROY(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_METAX_API
DESTROY(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_MOORE_API
DESTROY(INFINI_DEVICE_MOORE, moore);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef DESTROY
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
import ninetoothed
from . import flash_attention
from .flash_attention import CausalVariant
import infiniop.ninetoothed.build
import torch
import os
def build():
env_vars_to_check = ["MACA_HOME", "MACA_PATH", "MACA_ROOT"]
if any(var in os.environ for var in env_vars_to_check):
return
with_kv_cache_values = (0,)
emb_dim_values = (16, 32, 64, 128, 256)
is_causal_values = (0, 1)
with_attn_mask_values = (0,)
causal_variant_values = (CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT)
dtype_values = (ninetoothed.float16, ninetoothed.bfloat16, ninetoothed.float32)
block_size_m_values = (256,)
block_size_n_values = (64,)
constexpr_param_grid = {
"with_kv_cache": with_kv_cache_values,
"emb_dim": emb_dim_values,
"is_causal": is_causal_values,
"with_attn_mask": with_attn_mask_values,
"causal_variant": causal_variant_values,
"dtype": dtype_values,
"block_size_m": block_size_m_values,
"block_size_n": block_size_n_values,
}
infiniop.ninetoothed.build.build(
flash_attention.premake,
constexpr_param_grid,
caller="cuda",
op_name="flash_attention",
output_dir=infiniop.ninetoothed.build.BUILD_DIRECTORY_PATH,
)
#ifndef __FLASH_ATTENTION_DESCRIPTOR_H__
#define __FLASH_ATTENTION_DESCRIPTOR_H__
#include "../../../handle.h"
#include "../../../operator.h"
#include "../../../tensor.h"
#include "../../../../../build/ninetoothed/flash_attention.h"
#include "../../../ninetoothed/utils.h"
namespace op::flash_attention::ninetoothed {
class Descriptor final : public InfiniopDescriptor {
public:
Descriptor(infiniopHandle_t handle,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t total_kv_len,
double scale,
char is_causal) : InfiniopDescriptor{handle->device, handle->device_id},
_query_shape{q_desc->shape()},
_query_strides{q_desc->strides()},
_key_shape{k_desc->shape()},
_key_strides{k_desc->strides()},
_value_shape{v_desc->shape()},
_value_strides{v_desc->strides()},
_total_kv_shape{total_kv_len->shape()},
_total_kv_strides{total_kv_len->strides()},
_output_strides{out_desc->strides()},
_dtype{q_desc->dtype()},
_scale{scale},
_is_causal{is_causal} {
}
~Descriptor() = default;
size_t get_workspace_size() const {
return 0;
}
infiniStatus_t calculate(void *workspace,
size_t workspace_size,
void *out,
const void *q,
const void *k,
const void *v,
const void *total_kv_len,
void *stream) const {
uint64_t empty_shape[4];
int64_t empty_strides[4];
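// Placeholder shape/strides for the unused attention-mask tensor (with_attn_mask_ is 0 below).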
auto query{::ninetoothed::Tensor{q, _query_shape, _query_strides}};
auto key{::ninetoothed::Tensor{k, _key_shape, _key_strides}};
auto value{::ninetoothed::Tensor{v, _value_shape, _value_strides}};
auto total_kv_length{::ninetoothed::Tensor{total_kv_len, _total_kv_shape, _total_kv_strides}};
NineToothedTensor attn_mask{nullptr, empty_shape, empty_strides};
NineToothedTensor is_causal;
NineToothedTensor scale{const_cast<double *>(&_scale), nullptr, nullptr};
auto output{::ninetoothed::Tensor{out, _query_shape, _output_strides}};
NineToothedTensor with_attn_mask;
NineToothedTensor causal_variant;
const auto with_kv_cache_{0};
const auto emb_dim_{_query_shape[3]};
const auto is_causal_{_is_causal};
const auto with_attn_mask_{0};
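// 2 corresponds to CausalVariant.LOWER_RIGHT in flash_attention.py.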
const auto causal_variant_{2};
const auto dtype_{_dtype};
constexpr auto block_size_m_{256};
constexpr auto block_size_n_{64};
if (launch_flash_attention(stream,
query,
key,
value,
total_kv_length,
attn_mask,
is_causal,
scale,
output,
with_attn_mask,
causal_variant,
with_kv_cache_,
emb_dim_,
is_causal_,
with_attn_mask_,
causal_variant_,
dtype_,
block_size_m_,
block_size_n_)) {
return INFINI_STATUS_NOT_IMPLEMENTED;
}
return INFINI_STATUS_SUCCESS;
}
static infiniStatus_t create(infiniopHandle_t handle,
Descriptor **desc,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t total_kv_len,
double scale,
char is_causal) {
*desc = new Descriptor{handle, out_desc, q_desc, k_desc, v_desc, total_kv_len, scale, is_causal};
return INFINI_STATUS_SUCCESS;
}
private:
using Size = ::ninetoothed::Tensor<>::Size;
using Stride = ::ninetoothed::Tensor<>::Stride;
std::vector<Size> _query_shape;
std::vector<Stride> _query_strides;
std::vector<Size> _key_shape;
std::vector<Stride> _key_strides;
std::vector<Size> _value_shape;
std::vector<Stride> _value_strides;
std::vector<Size> _total_kv_shape;
std::vector<Stride> _total_kv_strides;
std::vector<Stride> _output_strides;
infiniDtype_t _dtype;
double _scale;
char _is_causal;
};
} // namespace op::flash_attention::ninetoothed
#endif // __FLASH_ATTENTION_DESCRIPTOR_H__
import enum
import functools
import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Tensor
BLOCK_SIZE_M = ninetoothed.block_size()
BLOCK_SIZE_N = ninetoothed.block_size()
class CausalVariant(enum.IntEnum):
"""Please refer to `<https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.bias.CausalVariant.html>`_."""
UPPER_LEFT = enum.auto()
LOWER_RIGHT = enum.auto()
def arrangement(
query,
key,
value,
total_kv_len,
present_key,
present_value,
present_key_slot,
present_value_slot,
attn_mask,
is_causal,
scale,
output,
with_attn_mask,
causal_variant,
with_kv_cache,
block_size_m=None,
block_size_n=None,
):
def arrange_query_or_output(input):
arranged = input.tile((1, 1, block_size_m, -1)).tile(
(1, query.shape[-3] // key.shape[-3], 1, 1)
)
arranged.dtype = arranged.dtype.squeeze((0, 2, 3))
arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1))
return arranged
def arrange_key_or_value(input):
arranged = (
input.tile((1, 1, block_size_n, -1))
.tile((1, 1, -1, -1))
.expand((-1, -1, query_arranged.shape[-2], -1))
)
arranged.dtype = arranged.dtype.squeeze((0, 1, 3))
arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1))
return arranged
def arrange_total_kv_len(input, shape):
arranged = input.tile((1,))
arranged = arranged.unsqueeze(1).unsqueeze(2).unsqueeze(3).expand(shape)
return arranged
def arrange_present_key_or_present_value(input):
arranged = input.tile((1, 1, block_size_m, block_size_n))
arranged.dtype = arranged.dtype.squeeze((0, 1))
return arranged
def arrange_attn_mask(input):
arranged = input.tile((1, 1, block_size_m, block_size_n)).tile((1, 1, 1, -1))
arranged.dtype = arranged.dtype.squeeze((0, 1, 2))
arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1))
return arranged
if block_size_m is None:
block_size_m = BLOCK_SIZE_M
if block_size_n is None:
block_size_n = BLOCK_SIZE_N
query_arranged = arrange_query_or_output(query)
key_arranged = arrange_key_or_value(key)
value_arranged = arrange_key_or_value(value)
total_kv_len_arranged = arrange_total_kv_len(total_kv_len, query_arranged.shape)
present_key_arranged = arrange_present_key_or_present_value(present_key)
present_value_arranged = arrange_present_key_or_present_value(present_value)
present_key_slot_arranged = arrange_present_key_or_present_value(present_key_slot)
present_value_slot_arranged = arrange_present_key_or_present_value(
present_value_slot
)
attn_mask_arranged = arrange_attn_mask(attn_mask)
is_causal_arranged = is_causal
scale_arranged = scale
output_arranged = arrange_query_or_output(output)
with_attn_mask_arranged = with_attn_mask
causal_variant_arranged = causal_variant
if with_kv_cache:
return (
query_arranged,
key_arranged,
value_arranged,
total_kv_len_arranged,
present_key_arranged,
present_value_arranged,
present_key_slot_arranged,
present_value_slot_arranged,
attn_mask_arranged,
is_causal_arranged,
scale_arranged,
output_arranged,
with_attn_mask_arranged,
causal_variant_arranged,
)
return (
query_arranged,
key_arranged,
value_arranged,
total_kv_len_arranged,
attn_mask_arranged,
is_causal_arranged,
scale_arranged,
output_arranged,
with_attn_mask_arranged,
causal_variant_arranged,
)
def application_with_kv_cache(
query,
key,
value,
total_kv_len,
present_key,
present_value,
present_key_slot,
present_value_slot,
attn_mask,
is_causal,
scale,
output,
with_attn_mask,
causal_variant,
):
present_key_slot = present_key # noqa: F841
present_value_slot = present_value # noqa: F841
application_without_kv_cache(
query,
key,
value,
total_kv_len,
attn_mask,
is_causal,
scale,
output,
with_attn_mask,
causal_variant,
)
def application_without_kv_cache(
query,
key,
value,
total_kv_len,
attn_mask,
is_causal,
scale,
output,
with_attn_mask,
causal_variant,
):
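# Online-softmax accumulation over key/value blocks: the query is pre-scaled by
# log2(e) = 1.4426950408889634 so ntl.exp2 can replace exp, and the running max and
# normalizer are rescaled by alpha as each block is processed.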
actual_kv_len = total_kv_len[0]
for i in range(query.shape[0]):
query_i = (1.4426950408889634 * scale * query[i]).to(query[i].dtype)
acc = ntl.zeros((query_i.shape[-2], query_i.shape[-1]), dtype=ntl.float32)
lse = ntl.full((query_i.shape[-2],), 1, dtype=ntl.float32)
max = ntl.full((query_i.shape[-2],), float("-inf"), dtype=ntl.float32)
for j in range(-(-actual_kv_len // key.dtype.shape[0])):
qk = ntl.dot(query_i, ntl.trans(key[j]))
key_pos = key[j].offsets(-2)
qk = ntl.where(key_pos < actual_kv_len, qk, float("-inf"))
if with_attn_mask:
qk += attn_mask[j]
if is_causal:
query_pos = query[i].offsets(-2)
if causal_variant == 2: # CausalVariant.LOWER_RIGHT:
mask = (
query_pos[:, None] + actual_kv_len - query.source.shape[-2]
>= key_pos[None, :]
)
else:
mask = query_pos[:, None] >= key_pos[None, :]
qk = ntl.where(mask, qk, float("-inf"))
next_max = ntl.maximum(max, ntl.max(qk, 1))
stable_qk = ntl.exp2(qk - next_max[:, None])
alpha = ntl.exp2(max - next_max)
acc = acc * alpha[:, None] + ntl.dot(stable_qk.to(value[i].dtype), value[j])
max = next_max
lse = lse * alpha + ntl.sum(stable_qk, 1)
acc /= lse[:, None]
output[i] = acc # noqa: F841
def premake(
with_kv_cache,
emb_dim=None,
is_causal=None,
with_attn_mask=None,
causal_variant=None,
dtype=None,
block_size_m=None,
block_size_n=None,
):
arrangement_ = functools.partial(
arrangement,
with_kv_cache=with_kv_cache,
block_size_m=block_size_m,
block_size_n=block_size_n,
)
query, key, value, attn_mask, output = (
Tensor(
4,
dtype=dtype,
shape_options=(None, None, None, {"constexpr": True, "upper_bound": 128}),
)
for _ in range(5)
)
total_kv_len = Tensor(1, dtype=ninetoothed.int32)
present_key, present_value, present_key_slot, present_value_slot = (
Tensor(4, dtype=dtype) for _ in range(4)
)
scale = Tensor(0, dtype=ninetoothed.float64)
is_causal = Tensor(0, constexpr=True, value=is_causal)
with_attn_mask = Tensor(0, constexpr=True, value=with_attn_mask)
causal_variant = Tensor(0, constexpr=True, value=causal_variant)
if emb_dim is not None:
for tensor in (query, key, value, attn_mask, output):
tensor.shape = tensor.shape[:-1] + (emb_dim,)
if with_kv_cache:
application = application_with_kv_cache
else:
application = application_without_kv_cache
tensors = (
query,
key,
value,
total_kv_len,
present_key,
present_value,
present_key_slot,
present_value_slot,
attn_mask,
is_causal,
scale,
output,
with_attn_mask,
causal_variant,
)
return arrangement_, application, tensors
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/flash_attention.h"
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
#include "ninetoothed/descriptor.h"
#endif
#endif
__C infiniStatus_t infiniopCreateFlashAttentionDescriptor(
infiniopHandle_t handle,
infiniopFlashAttentionDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t total_kv_len,
float scale,
char is_causal) {
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::flash_attention::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::flash_attention::NAMESPACE::Descriptor **>(desc_ptr), \
out_desc, \
q_desc, \
k_desc, \
v_desc, \
total_kv_len, \
scale, \
is_causal);
switch (handle->device) {
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
CREATE(INFINI_DEVICE_NVIDIA, ninetoothed);
#endif
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef CREATE
}
__C infiniStatus_t infiniopGetFlashAttentionWorkspaceSize(
infiniopFlashAttentionDescriptor_t desc,
size_t *size) {
#define GET_SIZE(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<const op::flash_attention::NAMESPACE::Descriptor *>(desc) \
->get_workspace_size(); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) {
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
GET_SIZE(INFINI_DEVICE_NVIDIA, ninetoothed);
#endif
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef GET_SIZE
}
__C infiniStatus_t infiniopFlashAttention(
infiniopFlashAttentionDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *out,
const void *q,
const void *k,
const void *v,
const void *total_kv_len,
void *stream) {
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<const op::flash_attention::NAMESPACE::Descriptor *>(desc) \
->calculate(workspace, workspace_size, out, q, k, v, total_kv_len, stream);
switch (desc->device_type) {
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
CALCULATE(INFINI_DEVICE_NVIDIA, ninetoothed);
#endif
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef CALCULATE
}
__C infiniStatus_t infiniopDestroyFlashAttentionDescriptor(
infiniopFlashAttentionDescriptor_t desc) {
#define DESTROY(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<op::flash_attention::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) {
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
DESTROY(INFINI_DEVICE_NVIDIA, ninetoothed);
#endif
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef DESTROY
}
@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#include "cpu/gelu_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API)
#include "nvidia/gelu_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
@@ -49,6 +49,9 @@ __C infiniStatus_t infiniopCreateGeluDescriptor(
#ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -83,6 +86,10 @@ __C infiniStatus_t infiniopGetGeluWorkspaceSize(infiniopGeluDescriptor_t desc, s
#ifdef ENABLE_KUNLUN_API
GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
@@ -124,6 +131,9 @@ __C infiniStatus_t infiniopGelu(
#ifdef ENABLE_KUNLUN_API
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -160,6 +170,9 @@ infiniopDestroyGeluDescriptor(infiniopGeluDescriptor_t desc) {
#ifdef ENABLE_KUNLUN_API
DELETE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_ALI_API
DELETE(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......
@@ -15,8 +15,8 @@ struct Descriptor::Opaque {
cnnlDestroyTensorDescriptor(a);
cnnlDestroyTensorDescriptor(b);
cnnlDestroyTensorDescriptor(c);
cnnlMatMulDescDestroy(op);
cnnlMatMulAlgoDestroy(algo);
cnnlDestroyMatMulDescriptor(op);
cnnlDestroyMatMulAlgo(algo);
cnnlDestroyMatMulHeuristicResult(algoResult);
}
};
@@ -85,8 +85,8 @@ infiniStatus_t Descriptor::create(
cnnlMatMulDescriptor_t op;
cnnlMatMulAlgo_t algo;
cnnlMatMulHeuristicResult_t algoResult;
CHECK_BANG(cnnlMatMulDescCreate(&op));
CHECK_BANG(cnnlMatMulAlgoCreate(&algo));
CHECK_BANG(cnnlCreateMatMulDescriptor(&op));
CHECK_BANG(cnnlCreateMatMulAlgo(&algo));
CHECK_BANG(cnnlCreateMatMulHeuristicResult(&algoResult));
int32_t use_stride = true;
CHECK_BANG(cnnlSetMatMulDescAttr(
@@ -101,7 +101,7 @@
(cnrtQueue_t) nullptr,
[&](cnnlHandle_t _handle) {
CHECK_BANG(
cnnlGetBatchMatMulAlgoHeuristic(
cnnlGetBatchMatMulExAlgoHeuristic(
_handle,
op, a, b, c,
NULL, 1, &algoResult, &count));
@@ -109,7 +109,7 @@
}));
size_t workspace_size;
CHECK_BANG(cnnlGetBatchMatMulHeuristicResult(algoResult, algo, &workspace_size));
CHECK_BANG(cnnlGetBatchMatMulExHeuristicResult(algoResult, algo, &workspace_size));
*desc_ptr = new Descriptor(
dtype, info, workspace_size,
@@ -135,7 +135,7 @@ infiniStatus_t Descriptor::calculate(
CHECK_STATUS(_opaque->internal->useCnnl(
(cnrtQueue_t)stream,
[&](cnnlHandle_t handle) {
CHECK_BANG(cnnlBatchMatMulBCast_v2(
CHECK_BANG(cnnlBatchMatMulEx(
handle,
_opaque->op,
_opaque->algo,
......
@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#include "cpu/gemm_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) || defined(ENABLE_ALI_API)
#include "nvidia/gemm_nvidia.cuh"
#endif
#ifdef ENABLE_CAMBRICON_API
@@ -51,6 +51,9 @@ __C infiniStatus_t infiniopCreateGemmDescriptor(
#ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia);
#endif
@@ -102,6 +105,9 @@ infiniopGetGemmWorkspaceSize(
#ifdef ENABLE_ILUVATAR_API
GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia);
#endif
@@ -160,6 +166,9 @@ __C infiniStatus_t infiniopGemm(
#ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
@@ -208,6 +217,9 @@ infiniopDestroyGemmDescriptor(infiniopGemmDescriptor_t desc) {
#ifdef ENABLE_ILUVATAR_API
DELETE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_ALI_API
DELETE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API
DELETE(INFINI_DEVICE_QY, nvidia);
#endif
......
import ninetoothed
from . import kv_caching
import infiniop.ninetoothed.build
def build():
dtype_values = (
ninetoothed.float16,
ninetoothed.bfloat16,
ninetoothed.float32,
)
constexpr_param_grid = {
"emb_dim": (1, 16, 32, 64, 128, 256),
"dtype": dtype_values,
"block_size_m": (64,),
"block_size_n": (64,),
}
infiniop.ninetoothed.build.build(
kv_caching.premake,
constexpr_param_grid,
caller="cuda",
op_name="kv_caching",
output_dir=infiniop.ninetoothed.build.BUILD_DIRECTORY_PATH,
)
#ifndef KV_CACHING_H
#define KV_CACHING_H
#include "../../../handle.h"
#include "../../../operator.h"
#include "../../../tensor.h"
#include "../../../../../build/ninetoothed/kv_caching.h"
#include "../../../ninetoothed/utils.h"
namespace op::kv_caching::ninetoothed {
class Descriptor final : public InfiniopDescriptor {
public:
Descriptor(
infiniopHandle_t handle,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t past_kv_lengths_desc) : InfiniopDescriptor{handle->device, handle->device_id},
k_cache_shape_{k_cache_desc->shape()},
k_cache_strides_{k_cache_desc->strides()},
v_cache_shape_{v_cache_desc->shape()},
v_cache_strides_{v_cache_desc->strides()},
k_shape_{k_desc->shape()},
k_strides_{k_desc->strides()},
v_shape_{v_desc->shape()},
v_strides_{v_desc->strides()},
past_kv_lengths_shape_{past_kv_lengths_desc->shape()},
past_kv_lengths_strides_{past_kv_lengths_desc->strides()},
dtype_{k_desc->dtype()} {}
~Descriptor() = default;
size_t get_workspace_size() const { return 0; };
static infiniStatus_t create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t k_cache,
infiniopTensorDescriptor_t v_cache,
infiniopTensorDescriptor_t k,
infiniopTensorDescriptor_t v,
infiniopTensorDescriptor_t past_kv_lengths) {
*desc_ptr = new Descriptor{handle, k_cache, v_cache, k, v, past_kv_lengths};
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t calculate(
void *workspace, size_t workspace_size,
void *k_cache,
void *v_cache,
const void *k,
const void *v,
const void *past_kv_lengths,
void *stream) const {
auto k_cache_nt{::ninetoothed::Tensor{k_cache, k_cache_shape_, k_cache_strides_}};
auto v_cache_nt{::ninetoothed::Tensor{v_cache, v_cache_shape_, v_cache_strides_}};
auto k_nt{::ninetoothed::Tensor{k, k_shape_, k_strides_}};
auto v_nt{::ninetoothed::Tensor{v, v_shape_, v_strides_}};
auto past_kv_lengths_nt{::ninetoothed::Tensor{past_kv_lengths, past_kv_lengths_shape_, past_kv_lengths_strides_}};
if (launch_kv_caching(stream,
k_cache_nt,
v_cache_nt,
k_nt,
v_nt,
past_kv_lengths_nt,
k_shape_[3],
dtype_,
64, 64)) {
return INFINI_STATUS_NOT_IMPLEMENTED;
}
return INFINI_STATUS_SUCCESS;
}
private:
using Size = ::ninetoothed::Tensor<>::Size;
using Stride = ::ninetoothed::Tensor<>::Stride;
std::vector<Size> k_cache_shape_;
std::vector<Stride> k_cache_strides_;
std::vector<Size> v_cache_shape_;
std::vector<Stride> v_cache_strides_;
std::vector<Size> k_shape_;
std::vector<Stride> k_strides_;
std::vector<Size> v_shape_;
std::vector<Stride> v_strides_;
std::vector<Size> past_kv_lengths_shape_;
std::vector<Stride> past_kv_lengths_strides_;
infiniDtype_t dtype_;
};
} // namespace op::kv_caching::ninetoothed
#endif // KV_CACHING_H
import functools
import ninetoothed
from ninetoothed import Tensor
def arrangement(
k_cache,
v_cache,
k,
v,
past_lengths,
block_size_m=ninetoothed.block_size(),
block_size_n=ninetoothed.block_size(),
):
k_cache_arranged = k_cache.tile((1, block_size_m, 1, -1)).tile((1, 1, -1, 1))
v_cache_arranged = v_cache.tile((1, block_size_m, 1, -1)).tile((1, 1, -1, 1))
k_arranged = k.tile((1, block_size_m, 1, -1)).tile((1, 1, -1, 1))
v_arranged = v.tile((1, block_size_m, 1, -1)).tile((1, 1, -1, 1))
past_lengths_arranged = (
past_lengths.tile((1,))
.unsqueeze(1)
.unsqueeze(2)
.unsqueeze(3)
.unsqueeze(4)
.expand((-1, *k_arranged.shape))
)
return (
k_cache_arranged,
v_cache_arranged,
k_arranged,
v_arranged,
past_lengths_arranged,
)
def application(k_cache, v_cache, k, v, past_lengths):
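# Append the incoming K/V rows to the caches starting at the current past length.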
pos = past_lengths
for i in range(k.shape[-2]):
k_cache[0, 0, pos + i, 0] = k[0, 0, i, 0]
v_cache[0, 0, pos + i, 0] = v[0, 0, i, 0]
def premake(emb_dim=None, dtype=None, block_size_m=None, block_size_n=None):
arrangement_ = functools.partial(
arrangement, block_size_m=block_size_m, block_size_n=block_size_n
)
shape_options = (None, None, None, {"constexpr": True, "upper_bound": 256})
tensors = (
Tensor(4, dtype=dtype, shape_options=shape_options),
Tensor(4, dtype=dtype, shape_options=shape_options),
Tensor(4, dtype=dtype, shape_options=shape_options),
Tensor(4, dtype=dtype, shape_options=shape_options),
Tensor(1, dtype=ninetoothed.int64),
)
if emb_dim is not None:
for tensor in tensors:
tensor.shape = tensor.shape[:-1] + (emb_dim,)
return arrangement_, application, tensors
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/kv_caching.h"
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_METAX_API) || defined(ENABLE_MOORE_API)
#include "ninetoothed/kv_caching.h"
#endif
#endif
__C infiniStatus_t infiniopCreateKVCachingDescriptor(
infiniopHandle_t handle,
infiniopKVCachingDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t k_cache,
infiniopTensorDescriptor_t v_cache,
infiniopTensorDescriptor_t k,
infiniopTensorDescriptor_t v,
infiniopTensorDescriptor_t past_kv_lengths) {
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::kv_caching::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::kv_caching::NAMESPACE::Descriptor **>(desc_ptr), \
k_cache, \
v_cache, \
k, \
v, \
past_kv_lengths)
switch (handle->device) {
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
CREATE(INFINI_DEVICE_NVIDIA, ninetoothed);
#endif
#if defined(ENABLE_ILUVATAR_API)
CREATE(INFINI_DEVICE_ILUVATAR, ninetoothed);
#endif
#if defined(ENABLE_METAX_API)
CREATE(INFINI_DEVICE_METAX, ninetoothed);
#endif
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef CREATE
}
__C infiniStatus_t infiniopGetKVCachingWorkspaceSize(
infiniopKVCachingDescriptor_t desc,
size_t *size) {
#define GET_SIZE(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<const op::kv_caching::NAMESPACE::Descriptor *>(desc) \
->get_workspace_size(); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) {
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
GET_SIZE(INFINI_DEVICE_NVIDIA, ninetoothed);
#endif
#if defined(ENABLE_ILUVATAR_API)
GET_SIZE(INFINI_DEVICE_ILUVATAR, ninetoothed);
#endif
#if defined(ENABLE_METAX_API)
GET_SIZE(INFINI_DEVICE_METAX, ninetoothed);
#endif
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef GET_SIZE
}
__C infiniStatus_t infiniopKVCaching(
infiniopKVCachingDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *k_cache,
void *v_cache,
const void *k,
const void *v,
const void *past_kv_lengths,
void *stream) {
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<const op::kv_caching::NAMESPACE::Descriptor *>(desc) \
->calculate(workspace, workspace_size, k_cache, v_cache, k, v, past_kv_lengths, stream)
switch (desc->device_type) {
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
CALCULATE(INFINI_DEVICE_NVIDIA, ninetoothed);
#endif
#if defined(ENABLE_ILUVATAR_API)
CALCULATE(INFINI_DEVICE_ILUVATAR, ninetoothed);
#endif
#if defined(ENABLE_METAX_API)
CALCULATE(INFINI_DEVICE_METAX, ninetoothed);
#endif
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef CALCULATE
}
__C infiniStatus_t infiniopDestroyKVCachingDescriptor(
infiniopKVCachingDescriptor_t desc) {
#define DELETE(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<op::kv_caching::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) {
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
DELETE(INFINI_DEVICE_NVIDIA, ninetoothed);
#endif
#if defined(ENABLE_ILUVATAR_API)
DELETE(INFINI_DEVICE_ILUVATAR, ninetoothed);
#endif
#if defined(ENABLE_METAX_API)
DELETE(INFINI_DEVICE_METAX, ninetoothed);
#endif
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef DELETE
}
@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#include "cpu/layer_norm_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API)
#include "nvidia/layer_norm_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
@@ -46,6 +46,9 @@ __C infiniStatus_t infiniopCreateLayerNormDescriptor(
#ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia);
#endif
@@ -76,6 +79,9 @@ __C infiniStatus_t infiniopGetLayerNormWorkspaceSize(infiniopLayerNormDescriptor
#ifdef ENABLE_ILUVATAR_API
GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia);
#endif
@@ -126,6 +132,9 @@ __C infiniStatus_t infiniopLayerNorm(
#ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
@@ -156,6 +165,9 @@ infiniopDestroyLayerNormDescriptor(infiniopLayerNormDescriptor_t desc) {
#ifdef ENABLE_NVIDIA_API
DELETE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ALI_API
DELETE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API
DELETE(INFINI_DEVICE_QY, nvidia);
#endif
......
@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#include "cpu/logsoftmax_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API)
#include "nvidia/logsoftmax_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
@@ -36,6 +36,9 @@ __C infiniStatus_t infiniopCreateLogSoftmaxDescriptor(
#ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
// CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
@@ -66,6 +69,9 @@ __C infiniStatus_t infiniopGetLogSoftmaxWorkspaceSize(infiniopLogSoftmaxDescript
#ifdef ENABLE_NVIDIA_API
GET(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
// GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
@@ -101,6 +107,9 @@ __C infiniStatus_t infiniopLogSoftmax(
#ifdef ENABLE_NVIDIA_API
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
// CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
@@ -131,6 +140,9 @@ __C infiniStatus_t infiniopDestroyLogSoftmaxDescriptor(infiniopLogSoftmaxDescrip
#ifdef ENABLE_NVIDIA_API
DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_ALI_API
DESTROY(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
// DESTROY(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
......
@@ -2,7 +2,7 @@
#include "../../handle.h"
#include "infiniop/ops/lp_norm.h"
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API)
#include "nvidia/lp_norm_nvidia.cuh"
#endif
@@ -36,6 +36,9 @@ __C infiniStatus_t infiniopCreateLPNormDescriptor(
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -60,6 +63,9 @@ __C infiniStatus_t infiniopGetLPNormWorkspaceSize(infiniopLPNormDescriptor_t des
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -97,6 +103,9 @@ __C infiniStatus_t infiniopLPNorm(
#ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -124,6 +133,9 @@ infiniopDestroyLPNormDescriptor(infiniopLPNormDescriptor_t desc) {
#ifdef ENABLE_QY_API
DELETE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_ALI_API
DELETE(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......
@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#include "cpu/mul_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API)
#include "nvidia/mul_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
@@ -48,6 +48,9 @@ __C infiniStatus_t infiniopCreateMulDescriptor(
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax);
#endif
@@ -85,6 +88,9 @@ __C infiniStatus_t infiniopGetMulWorkspaceSize(infiniopMulDescriptor_t desc, siz
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax);
#endif
@@ -131,6 +137,9 @@ __C infiniStatus_t infiniopMul(
#ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax);
#endif
@@ -170,6 +179,9 @@ infiniopDestroyMulDescriptor(infiniopMulDescriptor_t desc) {
#ifdef ENABLE_QY_API
DELETE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_ALI_API
DELETE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_METAX_API
DELETE(INFINI_DEVICE_METAX, metax);
#endif
......
@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#include "cpu/ones_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API)
#include "nvidia/ones_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
@@ -49,6 +49,10 @@ __C infiniStatus_t infiniopCreateOnesDescriptor(
#ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
@@ -82,6 +86,10 @@ __C infiniStatus_t infiniopGetOnesWorkspaceSize(infiniopOnesDescriptor_t desc, s
#ifdef ENABLE_MOORE_API
GET(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
@@ -123,6 +131,10 @@ __C infiniStatus_t infiniopOnes(
#ifdef ENABLE_MOORE_API
CALCULATE(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
@@ -158,6 +170,10 @@ infiniopDestroyOnesDescriptor(infiniopOnesDescriptor_t desc) {
#ifdef ENABLE_MOORE_API
DELETE(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_ALI_API
DELETE(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
......
#ifndef __PAGED_ATTENTION_KERNEL_CUH__
#define __PAGED_ATTENTION_KERNEL_CUH__
// This kernel is refactored for high performance, adopting parallelism strategies
// from industry-standard implementations such as vLLM. It fixes functional and
// performance issues present in the original draft.
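// Launch layout: blockIdx.y indexes the sequence and blockIdx.x the attention head;
// NUM_THREADS threads cooperate on a single (sequence, head) pair.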
namespace op::paged_attention::cuda {
template <typename Tdata, typename Tcompute, size_t HEAD_SIZE, size_t NUM_THREADS>
__device__ void pagedAttentionKernel(
Tdata *out_,
const Tdata *q_,
const Tdata *k_cache_,
const Tdata *v_cache_,
const int64_t *block_tables_,
const int64_t *seq_lens_,
const float *alibi_slopes_,
const size_t num_kv_heads,
const float scale,
const size_t max_num_blocks_per_seq,
const size_t block_size,
const ptrdiff_t q_stride,
const ptrdiff_t kv_block_stride,
const ptrdiff_t kv_head_stride,
const ptrdiff_t o_stride) {
//================================================================================
// 1. Setup & Query Loading (No changes in this section)
//================================================================================
const int seq_idx = blockIdx.y;
const int head_idx = blockIdx.x;
const int num_heads = gridDim.x;
const int64_t seq_len = seq_lens_[seq_idx];
if (seq_len == 0) {
return;
}
const size_t num_queries_per_kv = num_heads / num_kv_heads;
const size_t kv_head_idx = head_idx / num_queries_per_kv;
const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];
const int64_t *block_table = block_tables_ + seq_idx * max_num_blocks_per_seq;
const Tdata *q_ptr = q_ + seq_idx * q_stride + head_idx * HEAD_SIZE;
Tdata *out_ptr = out_ + seq_idx * o_stride + head_idx * HEAD_SIZE;
extern __shared__ char shared_mem_char[];
Tcompute *shared_mem = reinterpret_cast<Tcompute *>(shared_mem_char);
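// Dynamic shared memory layout: the first HEAD_SIZE Tcompute values hold the query,
// followed by one logit per token of the sequence.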
Tcompute *q_shared = shared_mem;
Tcompute *logits = shared_mem + HEAD_SIZE;
for (size_t i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
q_shared[i] = static_cast<Tcompute>(q_ptr[i]);
}
__syncthreads();
//================================================================================
// 2. Compute QK Dot Product & Find Max Logit
//================================================================================
for (size_t token_idx = threadIdx.x; token_idx < seq_len; token_idx += NUM_THREADS) {
const int64_t block_idx = token_idx / block_size;
const int64_t token_in_block_idx = token_idx % block_size;
const int64_t physical_block_num = block_table[block_idx];
const Tdata *k_vec_ptr = k_cache_ + physical_block_num * kv_block_stride + kv_head_idx * kv_head_stride + token_in_block_idx * HEAD_SIZE;
Tcompute qk = 0.0f;
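// Dot product unrolled by a factor of 8; assumes HEAD_SIZE is a multiple of 8.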
#pragma unroll
for (size_t i = 0; i < HEAD_SIZE / 8; ++i) {
const size_t offset = i * 8;
// Manually unroll eight multiply-accumulate steps
qk += q_shared[offset + 0] * static_cast<Tcompute>(k_vec_ptr[offset + 0]);
qk += q_shared[offset + 1] * static_cast<Tcompute>(k_vec_ptr[offset + 1]);
qk += q_shared[offset + 2] * static_cast<Tcompute>(k_vec_ptr[offset + 2]);
qk += q_shared[offset + 3] * static_cast<Tcompute>(k_vec_ptr[offset + 3]);
qk += q_shared[offset + 4] * static_cast<Tcompute>(k_vec_ptr[offset + 4]);
qk += q_shared[offset + 5] * static_cast<Tcompute>(k_vec_ptr[offset + 5]);
qk += q_shared[offset + 6] * static_cast<Tcompute>(k_vec_ptr[offset + 6]);
qk += q_shared[offset + 7] * static_cast<Tcompute>(k_vec_ptr[offset + 7]);
}
qk *= scale;
if (alibi_slope != 0.0f) {
qk += alibi_slope * static_cast<float>(static_cast<int64_t>(token_idx) - seq_len + 1); // signed difference: earlier tokens get a negative relative position
}
logits[token_idx] = qk;
}
__syncthreads();
__shared__ Tcompute global_qk_max;
Tcompute global_qk_max_0 = op::common_cuda::reduce_op::max<NUM_THREADS, Tcompute>(logits, seq_len);
if (threadIdx.x == 0) {
global_qk_max = global_qk_max_0;
}
__syncthreads();
//================================================================================
// 3. Compute Softmax (No changes in this section)
//================================================================================
for (size_t i = threadIdx.x; i < seq_len; i += NUM_THREADS) {
Tcompute val = expf(logits[i] - global_qk_max); // subtract the global maximum for numerical stability
logits[i] = val;
}
__syncthreads();
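// Normalize the probabilities: reduce the sum of exponentials across the block and
// add a small epsilon so the division is safe even if every logit underflows.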
__shared__ Tcompute inv_sum;
Tcompute exp_sum_0 = op::common_cuda::reduce_op::sum<NUM_THREADS, Tcompute, Tcompute>(logits, seq_len);
if (threadIdx.x == 0) {
inv_sum = 1.0f / (exp_sum_0 + 1e-6f);
}
__syncthreads();
for (size_t i = threadIdx.x; i < seq_len; i += NUM_THREADS) {
logits[i] *= inv_sum;
}
__syncthreads();
//================================================================================
// 4. Aggregate Values (V) weighted by probabilities
//================================================================================
for (size_t h_dim = threadIdx.x; h_dim < HEAD_SIZE; h_dim += NUM_THREADS) {
Tcompute acc = 0.0f;
for (size_t token_idx = 0; token_idx < seq_len; ++token_idx) {
const size_t block_idx = token_idx / block_size;
const size_t token_in_block_idx = token_idx % block_size;
const int64_t physical_block_num = block_table[block_idx];
const Tcompute prob = logits[token_idx];
const Tdata *v_vec_ptr = v_cache_
+ physical_block_num * kv_block_stride
+ kv_head_idx * kv_head_stride
+ token_in_block_idx * HEAD_SIZE;
const Tdata v_val = v_vec_ptr[h_dim];
acc += prob * static_cast<Tcompute>(v_val);
}
out_ptr[h_dim] = static_cast<Tdata>(acc);
}
}
} // namespace op::paged_attention::cuda
#endif // __PAGED_ATTENTION_KERNEL_CUH__