Unverified Commit 784139b9 authored by thatPepe's avatar thatPepe Committed by GitHub
Browse files

Merge pull request #990 from InfiniTensor/demo131

Demo-131 Cuda graph with optimized paged attention
parents 3c8fb3c0 1d6527cb
import ninetoothed
from . import flash_attention
from .flash_attention import CausalVariant
import infiniop.ninetoothed.build
import torch
import os
def build():
    """Ahead-of-time compile every flash-attention kernel specialization.

    Skips compilation entirely on MACA (MetaX) platforms, which are detected
    through their characteristic environment variables, because these CUDA
    kernels are not built there.
    """
    # Any of these variables being set indicates a MACA toolchain environment.
    if any(name in os.environ for name in ("MACA_HOME", "MACA_PATH", "MACA_ROOT")):
        return
    # Each key maps to the tuple of compile-time values to specialize over;
    # the build produces one kernel per point in the Cartesian product.
    constexpr_param_grid = {
        "with_kv_cache": (0,),
        "emb_dim": (16, 32, 64, 128, 256),
        "is_causal": (0, 1),
        "with_attn_mask": (0,),
        "causal_variant": (CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT),
        "dtype": (ninetoothed.float16, ninetoothed.bfloat16, ninetoothed.float32),
        "block_size_m": (256,),
        "block_size_n": (64,),
    }
    infiniop.ninetoothed.build.build(
        flash_attention.premake,
        constexpr_param_grid,
        caller="cuda",
        op_name="flash_attention",
        output_dir=infiniop.ninetoothed.build.BUILD_DIRECTORY_PATH,
    )
#ifndef __FLASH_ATTENTION_DESCRIPTOR_H__
#define __FLASH_ATTENTION_DESCRIPTOR_H__
#include "../../../handle.h"
#include "../../../operator.h"
#include "../../../tensor.h"
#include "../../../../../build/ninetoothed/flash_attention.h"
#include "../../../ninetoothed/utils.h"
namespace op::flash_attention::ninetoothed {
class Descriptor final : public InfiniopDescriptor {
public:
    /// Captures operand shapes/strides/dtype plus the softmax scale and the
    /// causal flag at descriptor-creation time; the actual kernel launch
    /// happens later in `calculate`.
    Descriptor(infiniopHandle_t handle,
               infiniopTensorDescriptor_t out_desc,
               infiniopTensorDescriptor_t q_desc,
               infiniopTensorDescriptor_t k_desc,
               infiniopTensorDescriptor_t v_desc,
               infiniopTensorDescriptor_t total_kv_len,
               double scale,
               char is_causal) : InfiniopDescriptor{handle->device, handle->device_id},
                                 _query_shape{q_desc->shape()},
                                 _query_strides{q_desc->strides()},
                                 _key_shape{k_desc->shape()},
                                 _key_strides{k_desc->strides()},
                                 _value_shape{v_desc->shape()},
                                 _value_strides{v_desc->strides()},
                                 _total_kv_shape{total_kv_len->shape()},
                                 _total_kv_strides{total_kv_len->strides()},
                                 _output_strides{out_desc->strides()},
                                 _dtype{q_desc->dtype()},
                                 _scale{scale},
                                 _is_causal{is_causal} {
    }

    ~Descriptor() = default;

    /// The NineToothed-generated kernel needs no auxiliary workspace.
    size_t get_workspace_size() const {
        return 0;
    }

    /// Launches the prebuilt flash-attention kernel on `stream`.
    /// `workspace`/`workspace_size` are accepted for interface uniformity
    /// with the other backends and are ignored here.
    infiniStatus_t calculate(void *workspace,
                             size_t workspace_size,
                             void *out,
                             const void *q,
                             const void *k,
                             const void *v,
                             const void *total_kv_len,
                             void *stream) const {
        // Placeholder metadata for the unused attention mask
        // (with_attn_mask_ is hard-wired to 0 below).  Zero-initialize:
        // the previous code left these arrays uninitialized, so the mask
        // tensor carried indeterminate shape/stride values.
        uint64_t empty_shape[4]{};
        int64_t empty_strides[4]{};
        auto query{::ninetoothed::Tensor{q, _query_shape, _query_strides}};
        auto key{::ninetoothed::Tensor{k, _key_shape, _key_strides}};
        auto value{::ninetoothed::Tensor{v, _value_shape, _value_strides}};
        auto total_kv_length{::ninetoothed::Tensor{total_kv_len, _total_kv_shape, _total_kv_strides}};
        NineToothedTensor attn_mask{nullptr, empty_shape, empty_strides};
        NineToothedTensor is_causal;
        NineToothedTensor scale{const_cast<double *>(&_scale), nullptr, nullptr};
        auto output{::ninetoothed::Tensor{out, _query_shape, _output_strides}};
        NineToothedTensor with_attn_mask;
        NineToothedTensor causal_variant;
        // Compile-time (constexpr) parameters: these select which prebuilt
        // kernel specialization gets launched, so they must match the
        // parameter grid used by the AOT build script.
        const auto with_kv_cache_{0};
        const auto emb_dim_{_query_shape[3]};
        const auto is_causal_{_is_causal};
        const auto with_attn_mask_{0};
        const auto causal_variant_{2}; // CausalVariant.LOWER_RIGHT
        const auto dtype_{_dtype};
        constexpr auto block_size_m_{256};
        constexpr auto block_size_n_{64};
        if (launch_flash_attention(stream,
                                   query,
                                   key,
                                   value,
                                   total_kv_length,
                                   attn_mask,
                                   is_causal,
                                   scale,
                                   output,
                                   with_attn_mask,
                                   causal_variant,
                                   with_kv_cache_,
                                   emb_dim_,
                                   is_causal_,
                                   with_attn_mask_,
                                   causal_variant_,
                                   dtype_,
                                   block_size_m_,
                                   block_size_n_)) {
            // A nonzero return means no kernel was built for this
            // specialization.
            return INFINI_STATUS_NOT_IMPLEMENTED;
        }
        return INFINI_STATUS_SUCCESS;
    }

    /// Allocates a new descriptor; ownership transfers to the caller
    /// (released via the public destroy entry point).
    static infiniStatus_t create(infiniopHandle_t handle,
                                 Descriptor **desc,
                                 infiniopTensorDescriptor_t out_desc,
                                 infiniopTensorDescriptor_t q_desc,
                                 infiniopTensorDescriptor_t k_desc,
                                 infiniopTensorDescriptor_t v_desc,
                                 infiniopTensorDescriptor_t total_kv_len,
                                 double scale,
                                 char is_causal) {
        *desc = new Descriptor{handle, out_desc, q_desc, k_desc, v_desc, total_kv_len, scale, is_causal};
        return INFINI_STATUS_SUCCESS;
    }

private:
    using Size = ::ninetoothed::Tensor<>::Size;
    using Stride = ::ninetoothed::Tensor<>::Stride;
    std::vector<Size> _query_shape;
    std::vector<Stride> _query_strides;
    std::vector<Size> _key_shape;
    std::vector<Stride> _key_strides;
    std::vector<Size> _value_shape;
    std::vector<Stride> _value_strides;
    std::vector<Size> _total_kv_shape;
    std::vector<Stride> _total_kv_strides;
    std::vector<Stride> _output_strides;
    infiniDtype_t _dtype;
    double _scale;
    char _is_causal;
};
} // namespace op::flash_attention::ninetoothed
#endif // __FLASH_ATTENTION_DESCRIPTOR_H__
import enum
import functools
import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Tensor
# Symbolic (auto-tunable) tile sizes, used as defaults when the caller does
# not pin block_size_m / block_size_n explicitly in `arrangement`.
BLOCK_SIZE_M = ninetoothed.block_size()
BLOCK_SIZE_N = ninetoothed.block_size()
class CausalVariant(enum.IntEnum):
    """Causal attention-mask alignment variants.

    Please refer to `<https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.bias.CausalVariant.html>`_.
    """

    # Values spelled out explicitly (matching ``enum.auto()`` numbering).
    UPPER_LEFT = 1
    LOWER_RIGHT = 2
def arrangement(
    query,
    key,
    value,
    total_kv_len,
    present_key,
    present_value,
    present_key_slot,
    present_value_slot,
    attn_mask,
    is_causal,
    scale,
    output,
    with_attn_mask,
    causal_variant,
    with_kv_cache,
    block_size_m=None,
    block_size_n=None,
):
    """Describe how each operand is tiled across the launch grid.

    Returns the arranged tensors in the exact positional order expected by
    the matching application function: a 14-tuple when ``with_kv_cache`` is
    truthy, otherwise a 10-tuple without the present-key/value operands.
    """

    def arrange_query_or_output(input):
        # Tile the sequence dimension into blocks of `block_size_m`, keeping
        # the whole embedding dimension, then group the query heads that share
        # one KV head (query.shape[-3] // key.shape[-3] queries per KV head —
        # presumably grouped-query attention; confirm against callers).
        arranged = input.tile((1, 1, block_size_m, -1)).tile(
            (1, query.shape[-3] // key.shape[-3], 1, 1)
        )
        arranged.dtype = arranged.dtype.squeeze((0, 2, 3))
        arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1))
        return arranged

    def arrange_key_or_value(input):
        # Tile K/V rows into blocks of `block_size_n`, keep all blocks
        # together per program, and broadcast over the grouped query heads.
        # NOTE: reads `query_arranged` from the enclosing scope, so this may
        # only be called after `query_arranged` is bound below.
        arranged = (
            input.tile((1, 1, block_size_n, -1))
            .tile((1, 1, -1, -1))
            .expand((-1, -1, query_arranged.shape[-2], -1))
        )
        arranged.dtype = arranged.dtype.squeeze((0, 1, 3))
        arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1))
        return arranged

    def arrange_total_kv_len(input, shape):
        # Broadcast the 1-D length tensor so every program sees it.
        arranged = input.tile((1,))
        arranged = arranged.unsqueeze(1).unsqueeze(2).unsqueeze(3).expand(shape)
        return arranged

    def arrange_present_key_or_present_value(input):
        arranged = input.tile((1, 1, block_size_m, block_size_n))
        arranged.dtype = arranged.dtype.squeeze((0, 1))
        return arranged

    def arrange_attn_mask(input):
        # Mask is tiled to match the (query block, key block) pairing.
        arranged = input.tile((1, 1, block_size_m, block_size_n)).tile((1, 1, 1, -1))
        arranged.dtype = arranged.dtype.squeeze((0, 1, 2))
        arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1))
        return arranged

    # Fall back to the module-level symbolic block sizes when not pinned.
    if block_size_m is None:
        block_size_m = BLOCK_SIZE_M
    if block_size_n is None:
        block_size_n = BLOCK_SIZE_N
    query_arranged = arrange_query_or_output(query)
    key_arranged = arrange_key_or_value(key)
    value_arranged = arrange_key_or_value(value)
    total_kv_len_arranged = arrange_total_kv_len(total_kv_len, query_arranged.shape)
    present_key_arranged = arrange_present_key_or_present_value(present_key)
    present_value_arranged = arrange_present_key_or_present_value(present_value)
    present_key_slot_arranged = arrange_present_key_or_present_value(present_key_slot)
    present_value_slot_arranged = arrange_present_key_or_present_value(
        present_value_slot
    )
    attn_mask_arranged = arrange_attn_mask(attn_mask)
    # Scalars / flags pass through unchanged.
    is_causal_arranged = is_causal
    scale_arranged = scale
    output_arranged = arrange_query_or_output(output)
    with_attn_mask_arranged = with_attn_mask
    causal_variant_arranged = causal_variant
    if with_kv_cache:
        return (
            query_arranged,
            key_arranged,
            value_arranged,
            total_kv_len_arranged,
            present_key_arranged,
            present_value_arranged,
            present_key_slot_arranged,
            present_value_slot_arranged,
            attn_mask_arranged,
            is_causal_arranged,
            scale_arranged,
            output_arranged,
            with_attn_mask_arranged,
            causal_variant_arranged,
        )
    return (
        query_arranged,
        key_arranged,
        value_arranged,
        total_kv_len_arranged,
        attn_mask_arranged,
        is_causal_arranged,
        scale_arranged,
        output_arranged,
        with_attn_mask_arranged,
        causal_variant_arranged,
    )
def application_with_kv_cache(
    query,
    key,
    value,
    total_kv_len,
    present_key,
    present_value,
    present_key_slot,
    present_value_slot,
    attn_mask,
    is_causal,
    scale,
    output,
    with_attn_mask,
    causal_variant,
):
    """KV-cache variant: stage the incoming K/V blocks into their cache
    slots, then delegate to the regular attention application."""
    # NOTE(review): under NineToothed tracing these parameter assignments
    # presumably materialize as stores into the cache slot tensors rather
    # than plain Python rebinding (hence the `noqa: F841`) — confirm against
    # the ninetoothed code-generation semantics.
    present_key_slot = present_key  # noqa: F841
    present_value_slot = present_value  # noqa: F841
    application_without_kv_cache(
        query,
        key,
        value,
        total_kv_len,
        attn_mask,
        is_causal,
        scale,
        output,
        with_attn_mask,
        causal_variant,
    )
def application_without_kv_cache(
    query,
    key,
    value,
    total_kv_len,
    attn_mask,
    is_causal,
    scale,
    output,
    with_attn_mask,
    causal_variant,
):
    """Online-softmax (FlashAttention-style) attention for one output tile.

    Iterates query blocks in the outer loop and key/value blocks in the inner
    loop, maintaining a running maximum and normalizer so the softmax never
    materializes the full score matrix.
    """
    # After the broadcast in `arrangement`, element 0 holds the valid KV length.
    actual_kv_len = total_kv_len[0]
    for i in range(query.shape[0]):
        # Fold the softmax scale in up front.  1.4426950408889634 == log2(e),
        # so exp(x) can later be computed as exp2(x * log2(e)).
        query_i = (1.4426950408889634 * scale * query[i]).to(query[i].dtype)
        acc = ntl.zeros((query_i.shape[-2], query_i.shape[-1]), dtype=ntl.float32)
        # lse starts at 1, not 0: on the first real iteration alpha is
        # exp2(-inf - next_max) == 0, so this initial value is rescaled away,
        # and a nonzero start keeps the final division safe if the inner loop
        # never runs.
        lse = ntl.full((query_i.shape[-2],), 1, dtype=ntl.float32)
        max = ntl.full((query_i.shape[-2],), float("-inf"), dtype=ntl.float32)
        # Ceil-divide the valid KV length by the key block size.
        for j in range(-(-actual_kv_len // key.dtype.shape[0])):
            qk = ntl.dot(query_i, ntl.trans(key[j]))
            key_pos = key[j].offsets(-2)
            # Mask out key positions past the valid KV length (block padding).
            qk = ntl.where(key_pos < actual_kv_len, qk, float("-inf"))
            if with_attn_mask:
                qk += attn_mask[j]
            if is_causal:
                query_pos = query[i].offsets(-2)
                if causal_variant == 2:  # CausalVariant.LOWER_RIGHT:
                    # Align the causal diagonal with the bottom-right corner:
                    # offset query positions by (kv_len - query_len).
                    mask = (
                        query_pos[:, None] + actual_kv_len - query.source.shape[-2]
                        >= key_pos[None, :]
                    )
                else:
                    # UPPER_LEFT: plain lower-triangular mask.
                    mask = query_pos[:, None] >= key_pos[None, :]
                qk = ntl.where(mask, qk, float("-inf"))
            # Online-softmax update in base-2 exponent space: rescale the
            # running accumulator and normalizer by alpha when the max grows.
            next_max = ntl.maximum(max, ntl.max(qk, 1))
            stable_qk = ntl.exp2(qk - next_max[:, None])
            alpha = ntl.exp2(max - next_max)
            acc = acc * alpha[:, None] + ntl.dot(stable_qk.to(value[i].dtype), value[j])
            max = next_max
            lse = lse * alpha + ntl.sum(stable_qk, 1)
        # Final normalization by the accumulated softmax denominator.
        acc /= lse[:, None]
        output[i] = acc  # noqa: F841
def premake(
    with_kv_cache,
    emb_dim=None,
    is_causal=None,
    with_attn_mask=None,
    causal_variant=None,
    dtype=None,
    block_size_m=None,
    block_size_n=None,
):
    """Build the (arrangement, application, tensors) triple for one
    compile-time parameter combination, as consumed by the AOT build.

    All keyword parameters are baked into the generated kernel as constexpr
    values; ``None`` leaves the corresponding value symbolic.
    """
    arrangement_ = functools.partial(
        arrangement,
        with_kv_cache=with_kv_cache,
        block_size_m=block_size_m,
        block_size_n=block_size_n,
    )
    # 4-D operands with a compile-time last (embedding) dimension, capped at
    # 128 elements.
    query, key, value, attn_mask, output = (
        Tensor(
            4,
            dtype=dtype,
            shape_options=(None, None, None, {"constexpr": True, "upper_bound": 128}),
        )
        for _ in range(5)
    )
    total_kv_len = Tensor(1, dtype=ninetoothed.int32)
    present_key, present_value, present_key_slot, present_value_slot = (
        Tensor(4, dtype=dtype) for _ in range(4)
    )
    # Runtime scalars / compile-time flags passed as 0-D tensors.
    scale = Tensor(0, dtype=ninetoothed.float64)
    is_causal = Tensor(0, constexpr=True, value=is_causal)
    with_attn_mask = Tensor(0, constexpr=True, value=with_attn_mask)
    causal_variant = Tensor(0, constexpr=True, value=causal_variant)
    if emb_dim is not None:
        # Pin the embedding dimension on the 4-D operands only — total_kv_len
        # (1-D) and the cache tensors are deliberately excluded.
        for tensor in (query, key, value, attn_mask, output):
            tensor.shape = tensor.shape[:-1] + (emb_dim,)
    if with_kv_cache:
        application = application_with_kv_cache
    else:
        application = application_without_kv_cache
    # Positional order here must match the chosen application's signature.
    tensors = (
        query,
        key,
        value,
        total_kv_len,
        present_key,
        present_value,
        present_key_slot,
        present_value_slot,
        attn_mask,
        is_causal,
        scale,
        output,
        with_attn_mask,
        causal_variant,
    )
    return arrangement_, application, tensors
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/flash_attention.h"
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
#include "ninetoothed/descriptor.h"
#endif
#endif
/// Creates a flash-attention descriptor for the device held by `handle`.
/// Only the NineToothed CUDA (NVIDIA) backend is currently wired up.
/// Note: `scale` is implicitly widened from float to double when forwarded.
__C infiniStatus_t infiniopCreateFlashAttentionDescriptor(
    infiniopHandle_t handle,
    infiniopFlashAttentionDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t q_desc,
    infiniopTensorDescriptor_t k_desc,
    infiniopTensorDescriptor_t v_desc,
    infiniopTensorDescriptor_t total_kv_len,
    float scale,
    char is_causal) {

// Expands to one switch case that forwards all arguments to the
// backend-specific Descriptor::create.
#define CREATE(CASE, NAMESPACE) \
    case CASE: \
        return op::flash_attention::NAMESPACE::Descriptor::create( \
            handle, \
            reinterpret_cast<op::flash_attention::NAMESPACE::Descriptor **>(desc_ptr), \
            out_desc, \
            q_desc, \
            k_desc, \
            v_desc, \
            total_kv_len, \
            scale, \
            is_causal);

    switch (handle->device) {
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
        CREATE(INFINI_DEVICE_NVIDIA, ninetoothed);
#endif
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef CREATE
}
/// Queries the scratch-workspace size required by `desc` (0 for the
/// NineToothed backend) and writes it to `*size`.
__C infiniStatus_t infiniopGetFlashAttentionWorkspaceSize(
    infiniopFlashAttentionDescriptor_t desc,
    size_t *size) {

// Dispatches on the device type recorded in the descriptor.
#define GET_SIZE(CASE, NAMESPACE) \
    case CASE: \
        *size = reinterpret_cast<const op::flash_attention::NAMESPACE::Descriptor *>(desc) \
                    ->get_workspace_size(); \
        return INFINI_STATUS_SUCCESS;

    switch (desc->device_type) {
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
        GET_SIZE(INFINI_DEVICE_NVIDIA, ninetoothed);
#endif
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef GET_SIZE
}
/// Executes flash attention described by `desc` on `stream`, writing the
/// attention result to `out`.
__C infiniStatus_t infiniopFlashAttention(
    infiniopFlashAttentionDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *out,
    const void *q,
    const void *k,
    const void *v,
    const void *total_kv_len,
    void *stream) {

// Dispatches the call to the backend descriptor's calculate().
#define CALCULATE(CASE, NAMESPACE) \
    case CASE: \
        return reinterpret_cast<const op::flash_attention::NAMESPACE::Descriptor *>(desc) \
            ->calculate(workspace, workspace_size, out, q, k, v, total_kv_len, stream);

    switch (desc->device_type) {
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
        CALCULATE(INFINI_DEVICE_NVIDIA, ninetoothed);
#endif
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef CALCULATE
}
/// Releases a descriptor previously produced by
/// infiniopCreateFlashAttentionDescriptor.
__C infiniStatus_t infiniopDestroyFlashAttentionDescriptor(
    infiniopFlashAttentionDescriptor_t desc) {

// The cast back to the concrete type is required so the right destructor runs.
#define DESTROY(CASE, NAMESPACE) \
    case CASE: \
        delete reinterpret_cast<op::flash_attention::NAMESPACE::Descriptor *>(desc); \
        return INFINI_STATUS_SUCCESS;

    switch (desc->device_type) {
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
        DESTROY(INFINI_DEVICE_NVIDIA, ninetoothed);
#endif
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef DESTROY
}
......@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#include "cpu/gelu_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API)
#include "nvidia/gelu_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
......@@ -49,6 +49,9 @@ __C infiniStatus_t infiniopCreateGeluDescriptor(
#ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -83,6 +86,10 @@ __C infiniStatus_t infiniopGetGeluWorkspaceSize(infiniopGeluDescriptor_t desc, s
#ifdef ENABLE_KUNLUN_API
GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
......@@ -124,6 +131,9 @@ __C infiniStatus_t infiniopGelu(
#ifdef ENABLE_KUNLUN_API
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -160,6 +170,9 @@ infiniopDestroyGeluDescriptor(infiniopGeluDescriptor_t desc) {
#ifdef ENABLE_KUNLUN_API
DELETE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_ALI_API
DELETE(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......
......@@ -15,8 +15,8 @@ struct Descriptor::Opaque {
cnnlDestroyTensorDescriptor(a);
cnnlDestroyTensorDescriptor(b);
cnnlDestroyTensorDescriptor(c);
cnnlMatMulDescDestroy(op);
cnnlMatMulAlgoDestroy(algo);
cnnlDestroyMatMulDescriptor(op);
cnnlDestroyMatMulAlgo(algo);
cnnlDestroyMatMulHeuristicResult(algoResult);
}
};
......@@ -85,8 +85,8 @@ infiniStatus_t Descriptor::create(
cnnlMatMulDescriptor_t op;
cnnlMatMulAlgo_t algo;
cnnlMatMulHeuristicResult_t algoResult;
CHECK_BANG(cnnlMatMulDescCreate(&op));
CHECK_BANG(cnnlMatMulAlgoCreate(&algo));
CHECK_BANG(cnnlCreateMatMulDescriptor(&op));
CHECK_BANG(cnnlCreateMatMulAlgo(&algo));
CHECK_BANG(cnnlCreateMatMulHeuristicResult(&algoResult));
int32_t use_stride = true;
CHECK_BANG(cnnlSetMatMulDescAttr(
......@@ -101,7 +101,7 @@ infiniStatus_t Descriptor::create(
(cnrtQueue_t) nullptr,
[&](cnnlHandle_t _handle) {
CHECK_BANG(
cnnlGetBatchMatMulAlgoHeuristic(
cnnlGetBatchMatMulExAlgoHeuristic(
_handle,
op, a, b, c,
NULL, 1, &algoResult, &count));
......@@ -109,7 +109,7 @@ infiniStatus_t Descriptor::create(
}));
size_t workspace_size;
CHECK_BANG(cnnlGetBatchMatMulHeuristicResult(algoResult, algo, &workspace_size));
CHECK_BANG(cnnlGetBatchMatMulExHeuristicResult(algoResult, algo, &workspace_size));
*desc_ptr = new Descriptor(
dtype, info, workspace_size,
......@@ -135,7 +135,7 @@ infiniStatus_t Descriptor::calculate(
CHECK_STATUS(_opaque->internal->useCnnl(
(cnrtQueue_t)stream,
[&](cnnlHandle_t handle) {
CHECK_BANG(cnnlBatchMatMulBCast_v2(
CHECK_BANG(cnnlBatchMatMulEx(
handle,
_opaque->op,
_opaque->algo,
......
......@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#include "cpu/gemm_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) || defined(ENABLE_ALI_API)
#include "nvidia/gemm_nvidia.cuh"
#endif
#ifdef ENABLE_CAMBRICON_API
......@@ -51,6 +51,9 @@ __C infiniStatus_t infiniopCreateGemmDescriptor(
#ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia);
#endif
......@@ -102,6 +105,9 @@ infiniopGetGemmWorkspaceSize(
#ifdef ENABLE_ILUVATAR_API
GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia);
#endif
......@@ -160,6 +166,9 @@ __C infiniStatus_t infiniopGemm(
#ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
......@@ -208,6 +217,9 @@ infiniopDestroyGemmDescriptor(infiniopGemmDescriptor_t desc) {
#ifdef ENABLE_ILUVATAR_API
DELETE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_ALI_API
DELETE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API
DELETE(INFINI_DEVICE_QY, nvidia);
#endif
......
import ninetoothed
from . import kv_caching
import infiniop.ninetoothed.build
def build():
    """Ahead-of-time compile the KV-caching kernels for every supported
    (emb_dim, dtype) combination."""
    # One kernel is built per point in the Cartesian product of these values.
    constexpr_param_grid = {
        "emb_dim": (1, 16, 32, 64, 128, 256),
        "dtype": (
            ninetoothed.float16,
            ninetoothed.bfloat16,
            ninetoothed.float32,
        ),
        "block_size_m": (64,),
        "block_size_n": (64,),
    }
    infiniop.ninetoothed.build.build(
        kv_caching.premake,
        constexpr_param_grid,
        caller="cuda",
        op_name="kv_caching",
        output_dir=infiniop.ninetoothed.build.BUILD_DIRECTORY_PATH,
    )
#ifndef KV_CACHING_H
#define KV_CACHING_H
#include "../../../handle.h"
#include "../../../operator.h"
#include "../../../tensor.h"
#include "../../../../../build/ninetoothed/kv_caching.h"
#include "../../../ninetoothed/utils.h"
namespace op::kv_caching::ninetoothed {
class Descriptor final : public InfiniopDescriptor {
public:
    /// Snapshots the shapes/strides/dtype of every operand so `calculate`
    /// can rebuild NineToothed tensor views around the raw device pointers.
    Descriptor(
        infiniopHandle_t handle,
        infiniopTensorDescriptor_t k_cache_desc,
        infiniopTensorDescriptor_t v_cache_desc,
        infiniopTensorDescriptor_t k_desc,
        infiniopTensorDescriptor_t v_desc,
        infiniopTensorDescriptor_t past_kv_lengths_desc) : InfiniopDescriptor{handle->device, handle->device_id},
                                                           k_cache_shape_{k_cache_desc->shape()},
                                                           k_cache_strides_{k_cache_desc->strides()},
                                                           v_cache_shape_{v_cache_desc->shape()},
                                                           v_cache_strides_{v_cache_desc->strides()},
                                                           k_shape_{k_desc->shape()},
                                                           k_strides_{k_desc->strides()},
                                                           v_shape_{v_desc->shape()},
                                                           v_strides_{v_desc->strides()},
                                                           past_kv_lengths_shape_{past_kv_lengths_desc->shape()},
                                                           past_kv_lengths_strides_{past_kv_lengths_desc->strides()},
                                                           dtype_{k_desc->dtype()} {}

    ~Descriptor() = default;

    /// The NineToothed-generated kernel needs no auxiliary workspace.
    size_t get_workspace_size() const { return 0; };

    /// Allocates a new descriptor; ownership transfers to the caller.
    static infiniStatus_t create(
        infiniopHandle_t handle,
        Descriptor **desc_ptr,
        infiniopTensorDescriptor_t k_cache,
        infiniopTensorDescriptor_t v_cache,
        infiniopTensorDescriptor_t k,
        infiniopTensorDescriptor_t v,
        infiniopTensorDescriptor_t past_kv_lengths) {
        *desc_ptr = new Descriptor{handle, k_cache, v_cache, k, v, past_kv_lengths};
        return INFINI_STATUS_SUCCESS;
    }

    /// Appends the new K/V blocks into the caches on `stream`.
    /// `workspace`/`workspace_size` are accepted for interface uniformity
    /// and ignored.
    infiniStatus_t calculate(
        void *workspace, size_t workspace_size,
        void *k_cache,
        void *v_cache,
        const void *k,
        const void *v,
        const void *past_kv_lengths,
        void *stream) const {
        auto k_cache_nt{::ninetoothed::Tensor{k_cache, k_cache_shape_, k_cache_strides_}};
        auto v_cache_nt{::ninetoothed::Tensor{v_cache, v_cache_shape_, v_cache_strides_}};
        auto k_nt{::ninetoothed::Tensor{k, k_shape_, k_strides_}};
        auto v_nt{::ninetoothed::Tensor{v, v_shape_, v_strides_}};
        auto past_kv_lengths_nt{::ninetoothed::Tensor{past_kv_lengths, past_kv_lengths_shape_, past_kv_lengths_strides_}};
        // k_shape_[3] is the constexpr embedding dim; the trailing 64, 64
        // block sizes must match the grid used when the kernels were
        // prebuilt by the AOT build script.
        if (launch_kv_caching(stream,
                              k_cache_nt,
                              v_cache_nt,
                              k_nt,
                              v_nt,
                              past_kv_lengths_nt,
                              k_shape_[3],
                              dtype_,
                              64, 64)) {
            // Nonzero return: no kernel built for this specialization.
            return INFINI_STATUS_NOT_IMPLEMENTED;
        }
        return INFINI_STATUS_SUCCESS;
    }

private:
    using Size = ::ninetoothed::Tensor<>::Size;
    using Stride = ::ninetoothed::Tensor<>::Stride;
    std::vector<Size> k_cache_shape_;
    std::vector<Stride> k_cache_strides_;
    std::vector<Size> v_cache_shape_;
    std::vector<Stride> v_cache_strides_;
    std::vector<Size> k_shape_;
    std::vector<Stride> k_strides_;
    std::vector<Size> v_shape_;
    std::vector<Stride> v_strides_;
    std::vector<Size> past_kv_lengths_shape_;
    std::vector<Stride> past_kv_lengths_strides_;
    infiniDtype_t dtype_;
};
} // namespace op::kv_caching::ninetoothed
#endif // KV_CACHING_H
import functools
import ninetoothed
from ninetoothed import Tensor
def arrangement(
    k_cache,
    v_cache,
    k,
    v,
    past_lengths,
    block_size_m=ninetoothed.block_size(),
    block_size_n=ninetoothed.block_size(),
):
    """Describe how the cache and the incoming K/V tensors are tiled.

    All four 4-D tensors are tiled identically: `block_size_m` rows along the
    sequence dimension, one element along the others, with the second tile
    level flattening the row blocks per program.
    """
    k_cache_arranged = k_cache.tile((1, block_size_m, 1, -1)).tile((1, 1, -1, 1))
    v_cache_arranged = v_cache.tile((1, block_size_m, 1, -1)).tile((1, 1, -1, 1))
    k_arranged = k.tile((1, block_size_m, 1, -1)).tile((1, 1, -1, 1))
    v_arranged = v.tile((1, block_size_m, 1, -1)).tile((1, 1, -1, 1))
    # Broadcast the 1-D write-offset tensor so every program sees it
    # alongside its K tile.
    past_lengths_arranged = (
        past_lengths.tile((1,))
        .unsqueeze(1)
        .unsqueeze(2)
        .unsqueeze(3)
        .unsqueeze(4)
        .expand((-1, *k_arranged.shape))
    )
    # Positional order must match `application`'s signature.
    return (
        k_cache_arranged,
        v_cache_arranged,
        k_arranged,
        v_arranged,
        past_lengths_arranged,
    )
def application(k_cache, v_cache, k, v, past_lengths):
    """Append the incoming K/V rows into the caches starting at the
    already-cached length."""
    # `past_lengths` was broadcast by `arrangement`; it carries the write
    # offset (number of rows already in the cache) for this program.
    pos = past_lengths
    for i in range(k.shape[-2]):
        # Copy row i of the new K/V block to cache position pos + i.
        k_cache[0, 0, pos + i, 0] = k[0, 0, i, 0]
        v_cache[0, 0, pos + i, 0] = v[0, 0, i, 0]
def premake(emb_dim=None, dtype=None, block_size_m=None, block_size_n=None):
    """Build the (arrangement, application, tensors) triple for one
    compile-time parameter combination, as consumed by the AOT build.

    ``emb_dim`` and ``dtype`` are baked into the generated kernel as
    constexpr values; ``None`` leaves them symbolic.
    """
    arrangement_ = functools.partial(
        arrangement, block_size_m=block_size_m, block_size_n=block_size_n
    )
    # The last (embedding) dimension is a compile-time constant capped at 256.
    shape_options = (None, None, None, {"constexpr": True, "upper_bound": 256})
    k_cache = Tensor(4, dtype=dtype, shape_options=shape_options)
    v_cache = Tensor(4, dtype=dtype, shape_options=shape_options)
    k = Tensor(4, dtype=dtype, shape_options=shape_options)
    v = Tensor(4, dtype=dtype, shape_options=shape_options)
    past_lengths = Tensor(1, dtype=ninetoothed.int64)
    tensors = (k_cache, v_cache, k, v, past_lengths)
    if emb_dim is not None:
        # Pin the embedding dimension on the 4-D tensors only.  The previous
        # code looped over *all* tensors, which also rewrote the 1-D
        # `past_lengths` tensor — replacing its (batch-sized) single
        # dimension with `emb_dim`.  The flash_attention premake correctly
        # excludes its 1-D length tensor; do the same here.
        for tensor in (k_cache, v_cache, k, v):
            tensor.shape = tensor.shape[:-1] + (emb_dim,)
    return arrangement_, application, tensors
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/kv_caching.h"
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_METAX_API) || defined(ENABLE_MOORE_API)
#include "ninetoothed/kv_caching.h"
#endif
#endif
/// Creates a KV-caching descriptor for the device held by `handle`.
/// NineToothed backends exist for NVIDIA, Iluvatar, and MetaX.
__C infiniStatus_t infiniopCreateKVCachingDescriptor(
    infiniopHandle_t handle,
    infiniopKVCachingDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t k_cache,
    infiniopTensorDescriptor_t v_cache,
    infiniopTensorDescriptor_t k,
    infiniopTensorDescriptor_t v,
    infiniopTensorDescriptor_t past_kv_lengths) {

// Expands to one switch case forwarding to the backend Descriptor::create.
// The macro body has no trailing semicolon; the use sites supply it.
#define CREATE(CASE, NAMESPACE) \
    case CASE: \
        return op::kv_caching::NAMESPACE::Descriptor::create( \
            handle, \
            reinterpret_cast<op::kv_caching::NAMESPACE::Descriptor **>(desc_ptr), \
            k_cache, \
            v_cache, \
            k, \
            v, \
            past_kv_lengths)

    switch (handle->device) {
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
        CREATE(INFINI_DEVICE_NVIDIA, ninetoothed);
#endif
#if defined(ENABLE_ILUVATAR_API)
        CREATE(INFINI_DEVICE_ILUVATAR, ninetoothed);
#endif
#if defined(ENABLE_METAX_API)
        CREATE(INFINI_DEVICE_METAX, ninetoothed);
#endif
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef CREATE
}
/// Queries the scratch-workspace size required by `desc` (0 for the
/// NineToothed backend) and writes it to `*size`.
__C infiniStatus_t infiniopGetKVCachingWorkspaceSize(
    infiniopKVCachingDescriptor_t desc,
    size_t *size) {

// Dispatches on the device type recorded in the descriptor.
#define GET_SIZE(CASE, NAMESPACE) \
    case CASE: \
        *size = reinterpret_cast<const op::kv_caching::NAMESPACE::Descriptor *>(desc) \
                    ->get_workspace_size(); \
        return INFINI_STATUS_SUCCESS;

    switch (desc->device_type) {
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
        GET_SIZE(INFINI_DEVICE_NVIDIA, ninetoothed);
#endif
#if defined(ENABLE_ILUVATAR_API)
        GET_SIZE(INFINI_DEVICE_ILUVATAR, ninetoothed);
#endif
#if defined(ENABLE_METAX_API)
        GET_SIZE(INFINI_DEVICE_METAX, ninetoothed);
#endif
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef GET_SIZE
}
/// Appends the new K/V blocks described by `desc` into the caches on
/// `stream`.
__C infiniStatus_t infiniopKVCaching(
    infiniopKVCachingDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *k_cache,
    void *v_cache,
    const void *k,
    const void *v,
    const void *past_kv_lengths,
    void *stream) {

// Dispatches the call to the backend descriptor's calculate().
#define CALCULATE(CASE, NAMESPACE) \
    case CASE: \
        return reinterpret_cast<const op::kv_caching::NAMESPACE::Descriptor *>(desc) \
            ->calculate(workspace, workspace_size, k_cache, v_cache, k, v, past_kv_lengths, stream)

    switch (desc->device_type) {
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
        CALCULATE(INFINI_DEVICE_NVIDIA, ninetoothed);
#endif
#if defined(ENABLE_ILUVATAR_API)
        CALCULATE(INFINI_DEVICE_ILUVATAR, ninetoothed);
#endif
#if defined(ENABLE_METAX_API)
        CALCULATE(INFINI_DEVICE_METAX, ninetoothed);
#endif
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef CALCULATE
}
/// Releases a descriptor previously produced by
/// infiniopCreateKVCachingDescriptor.
__C infiniStatus_t infiniopDestroyKVCachingDescriptor(
    infiniopKVCachingDescriptor_t desc) {

// The cast back to the concrete type is required so the right destructor runs.
#define DELETE(CASE, NAMESPACE) \
    case CASE: \
        delete reinterpret_cast<op::kv_caching::NAMESPACE::Descriptor *>(desc); \
        return INFINI_STATUS_SUCCESS;

    switch (desc->device_type) {
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
        DELETE(INFINI_DEVICE_NVIDIA, ninetoothed);
#endif
#if defined(ENABLE_ILUVATAR_API)
        DELETE(INFINI_DEVICE_ILUVATAR, ninetoothed);
#endif
#if defined(ENABLE_METAX_API)
        DELETE(INFINI_DEVICE_METAX, ninetoothed);
#endif
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef DELETE
}
......@@ -255,6 +255,8 @@ infiniStatus_t Descriptor::calculate(
CALCULATE_LAYER_NORM_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_512)
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
CALCULATE_LAYER_NORM_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_4096)
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_2048) {
CALCULATE_LAYER_NORM_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_2048)
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
......
......@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#include "cpu/layer_norm_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API)
#include "nvidia/layer_norm_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
......@@ -46,6 +46,9 @@ __C infiniStatus_t infiniopCreateLayerNormDescriptor(
#ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia);
#endif
......@@ -76,6 +79,9 @@ __C infiniStatus_t infiniopGetLayerNormWorkspaceSize(infiniopLayerNormDescriptor
#ifdef ENABLE_ILUVATAR_API
GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia);
#endif
......@@ -126,6 +132,9 @@ __C infiniStatus_t infiniopLayerNorm(
#ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
......@@ -156,12 +165,18 @@ infiniopDestroyLayerNormDescriptor(infiniopLayerNormDescriptor_t desc) {
#ifdef ENABLE_NVIDIA_API
DELETE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ALI_API
DELETE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API
DELETE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_METAX_API
DELETE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_ILUVATAR_API
DELETE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......
......@@ -56,6 +56,9 @@ __device__ void logSoftmaxKernel(
}
#if CUDART_VERSION >= 12090
max_val = BlockReduce(temp_storage).Reduce(max_val, ::cuda::maximum());
#elif defined(ENABLE_HYGON_API)
max_val = BlockReduce(temp_storage).Reduce(
max_val, [](const float &a, const float &b) { return (a > b) ? a : b; }, BLOCK_SIZE);
#else
max_val = BlockReduce(temp_storage).Reduce(max_val, cub::Max());
#endif
......
......@@ -117,6 +117,11 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
y, x, _info.x_dtype, _info.y_dtype, _info.batch_size, _info.probs_size, _info.ndim, _info.seq_len,
_info.y_stride_b, _info.y_stride_p, _info.x_stride_b, _info.x_stride_p,
_info.y_stride_0, _info.y_stride_1, _info.x_stride_0, _info.x_stride_1, stream));
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_2048) {
CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_2048>(
y, x, _info.x_dtype, _info.y_dtype, _info.batch_size, _info.probs_size, _info.ndim, _info.seq_len,
_info.y_stride_b, _info.y_stride_p, _info.x_stride_b, _info.x_stride_p,
_info.y_stride_0, _info.y_stride_1, _info.x_stride_0, _info.x_stride_1, stream));
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_4096>(
y, x, _info.x_dtype, _info.y_dtype, _info.batch_size, _info.probs_size, _info.ndim, _info.seq_len,
......
......@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#include "cpu/logsoftmax_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API)
#include "nvidia/logsoftmax_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
......@@ -36,8 +36,11 @@ __C infiniStatus_t infiniopCreateLogSoftmaxDescriptor(
#ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
// CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia);
......@@ -66,8 +69,11 @@ __C infiniStatus_t infiniopGetLogSoftmaxWorkspaceSize(infiniopLogSoftmaxDescript
#ifdef ENABLE_NVIDIA_API
GET(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
// GET(INFINI_DEVICE_ILUVATAR, nvidia);
GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia);
......@@ -101,8 +107,11 @@ __C infiniStatus_t infiniopLogSoftmax(
#ifdef ENABLE_NVIDIA_API
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
// CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia);
......@@ -131,8 +140,11 @@ __C infiniStatus_t infiniopDestroyLogSoftmaxDescriptor(infiniopLogSoftmaxDescrip
#ifdef ENABLE_NVIDIA_API
DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#ifdef ENABLE_ALI_API
DESTROY(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
// DESTROY(INFINI_DEVICE_ILUVATAR, nvidia);
DESTROY(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
DESTROY(INFINI_DEVICE_QY, nvidia);
......
......@@ -19,6 +19,9 @@ __device__ void blockLPNormKernel(
__shared__ float global_max;
#if CUDART_VERSION >= 12090
float max_block = BlockReduce(temp_storage).Reduce(local_max, ::cuda::maximum());
#elif defined(ENABLE_HYGON_API)
float max_block = BlockReduce(temp_storage).Reduce(
local_max, [](const float &a, const float &b) { return (a > b) ? a : b; }, BLOCK_SIZE);
#else
float max_block = BlockReduce(temp_storage).Reduce(local_max, cub::Max());
#endif
......@@ -75,6 +78,9 @@ __device__ void blockLPNormStridesKernel(
__shared__ float global_max;
#if CUDART_VERSION >= 12090
float max_block = BlockReduce(temp_storage).Reduce(local_max, ::cuda::maximum());
#elif defined(ENABLE_HYGON_API)
float max_block = BlockReduce(temp_storage).Reduce(
local_max, [](const float &a, const float &b) { return (a > b) ? a : b; }, BLOCK_SIZE);
#else
float max_block = BlockReduce(temp_storage).Reduce(local_max, cub::Max());
#endif
......
......@@ -155,6 +155,8 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
CALCULATE_LP_NORM_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_1024)
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
CALCULATE_LP_NORM_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_512)
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_2048) {
CALCULATE_LP_NORM_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_2048)
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
CALCULATE_LP_NORM_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_4096)
} else {
......
......@@ -2,7 +2,7 @@
#include "../../handle.h"
#include "infiniop/ops/lp_norm.h"
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API)
#include "nvidia/lp_norm_nvidia.cuh"
#endif
......@@ -36,6 +36,9 @@ __C infiniStatus_t infiniopCreateLPNormDescriptor(
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -60,6 +63,9 @@ __C infiniStatus_t infiniopGetLPNormWorkspaceSize(infiniopLPNormDescriptor_t des
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -97,6 +103,9 @@ __C infiniStatus_t infiniopLPNorm(
#ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -124,6 +133,9 @@ infiniopDestroyLPNormDescriptor(infiniopLPNormDescriptor_t desc) {
#ifdef ENABLE_QY_API
DELETE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_ALI_API
DELETE(INFINI_DEVICE_ALI, nvidia);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......
......@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#include "cpu/mul_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API)
#include "nvidia/mul_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
......@@ -48,6 +48,9 @@ __C infiniStatus_t infiniopCreateMulDescriptor(
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax);
#endif
......@@ -85,6 +88,9 @@ __C infiniStatus_t infiniopGetMulWorkspaceSize(infiniopMulDescriptor_t desc, siz
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax);
#endif
......@@ -131,6 +137,9 @@ __C infiniStatus_t infiniopMul(
#ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax);
#endif
......@@ -170,6 +179,9 @@ infiniopDestroyMulDescriptor(infiniopMulDescriptor_t desc) {
#ifdef ENABLE_QY_API
DELETE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_ALI_API
DELETE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_METAX_API
DELETE(INFINI_DEVICE_METAX, metax);
#endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment