Project: jerrrrry / infinicore

Commit 8d09630a (unverified), authored Feb 11, 2026 by gongchensu, committed via GitHub on Feb 11, 2026.

    Merge branch 'demo131' into Issue/862

Parents: ab52dead, 012df56c

This commit changes 387 files in total; this page shows 20 of them, with 1568 additions and 14 deletions (+1568, -14).
Changed files on this page:

- src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu (+224, -0)
- src/infiniop/ops/embedding/nvidia/embedding_nvidia.cuh (+8, -0)
- src/infiniop/ops/embedding/operator.cc (+154, -0)
- src/infiniop/ops/flash_attention/ninetoothed/build.py (+44, -0)
- src/infiniop/ops/flash_attention/ninetoothed/descriptor.h (+147, -0)
- src/infiniop/ops/flash_attention/ninetoothed/flash_attention.py (+281, -0)
- src/infiniop/ops/flash_attention/operator.cc (+121, -0)
- src/infiniop/ops/gelu/operator.cc (+14, -1)
- src/infiniop/ops/gemm/bang/gemm_bang.cc (+7, -7)
- src/infiniop/ops/gemm/operator.cc (+13, -1)
- src/infiniop/ops/kv_caching/ninetoothed/build.py (+27, -0)
- src/infiniop/ops/kv_caching/ninetoothed/kv_caching.h (+101, -0)
- src/infiniop/ops/kv_caching/ninetoothed/kv_caching.py (+66, -0)
- src/infiniop/ops/kv_caching/operator.cc (+143, -0)
- src/infiniop/ops/layer_norm/operator.cc (+13, -1)
- src/infiniop/ops/logsoftmax/operator.cc (+13, -1)
- src/infiniop/ops/lp_norm/operator.cc (+13, -1)
- src/infiniop/ops/mul/operator.cc (+13, -1)
- src/infiniop/ops/ones/operator.cc (+17, -1)
- src/infiniop/ops/paged_attention/cuda/kernel.cuh (+149, -0)
src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu (new file, mode 100644)
#include "../../../../utils.h"
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include "../../../tensor.h"
#include "../cuda/embedding_kernel.cuh"
#include "embedding_nvidia.cuh"
#include <cuda_runtime.h>
template
<
typename
T
,
typename
IndexType
>
INFINIOP_CUDA_KERNEL
embeddingKernel
(
T
*
__restrict__
output
,
const
IndexType
*
__restrict__
indices
,
const
T
*
__restrict__
weight
,
size_t
num_indices
,
size_t
embedding_dim
,
size_t
vocab_size
)
{
// Calculate global thread index
size_t
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
idx
<
num_indices
)
{
// Get the index value
IndexType
index_val
=
__ldg
(
&
indices
[
idx
]);
// Bounds check - handle negative indices gracefully
if
(
index_val
>=
0
&&
static_cast
<
size_t
>
(
index_val
)
<
vocab_size
)
{
// Copy embedding vector from weight to output
const
T
*
src
=
weight
+
static_cast
<
size_t
>
(
index_val
)
*
embedding_dim
;
T
*
dst
=
output
+
idx
*
embedding_dim
;
// Choose optimal copy strategy based on type and alignment
if
constexpr
(
std
::
is_same_v
<
T
,
float
>
)
{
// Check alignment for float4 (16 bytes)
bool
aligned_16
=
is_aligned
(
src
,
16
)
&&
is_aligned
(
dst
,
16
);
if
(
aligned_16
&&
embedding_dim
>=
4
&&
embedding_dim
%
4
==
0
)
{
copyVectorizedFloat4
<
IndexType
>
(
dst
,
src
,
embedding_dim
);
}
else
if
(
embedding_dim
>=
2
&&
embedding_dim
%
2
==
0
)
{
// Try float2 if not aligned to 16 bytes
copyVectorizedFloat2
<
IndexType
>
(
dst
,
src
,
embedding_dim
);
}
else
{
copyScalar
<
T
,
IndexType
>
(
dst
,
src
,
embedding_dim
);
}
}
else
if
constexpr
(
std
::
is_same_v
<
T
,
half
>
)
{
// Use half2 for vectorized access
if
(
embedding_dim
>=
2
&&
embedding_dim
%
2
==
0
)
{
copyVectorizedHalf2
<
IndexType
>
(
dst
,
src
,
embedding_dim
);
}
else
{
copyScalar
<
T
,
IndexType
>
(
dst
,
src
,
embedding_dim
);
}
}
else
if
constexpr
(
std
::
is_same_v
<
T
,
cuda_bfloat16
>
)
{
// Use bfloat162 for vectorized access
if
(
embedding_dim
>=
2
&&
embedding_dim
%
2
==
0
)
{
copyVectorizedBFloat162
<
IndexType
>
(
dst
,
src
,
embedding_dim
);
}
else
{
copyScalar
<
T
,
IndexType
>
(
dst
,
src
,
embedding_dim
);
}
}
else
{
// Fallback to scalar copy with __ldg
copyScalar
<
T
,
IndexType
>
(
dst
,
src
,
embedding_dim
);
}
}
}
}
namespace op::embedding::nvidia {

struct Descriptor::Opaque {
    std::shared_ptr<device::nvidia::Handle::Internal> internal;
};

Descriptor::~Descriptor() {
    delete _opaque;
}

infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t output_desc,
    infiniopTensorDescriptor_t input_desc,
    infiniopTensorDescriptor_t weight_desc) {
    auto input_shape = input_desc->shape();
    auto weight_shape = weight_desc->shape();

    // Validate shapes
    CHECK_OR_RETURN(weight_shape.size() == 2, INFINI_STATUS_BAD_TENSOR_SHAPE);
    CHECK_OR_RETURN(output_desc->shape().size() == input_shape.size() + 1,
                    INFINI_STATUS_BAD_TENSOR_SHAPE);

    // Check output shape matches input shape + embedding_dim
    auto output_shape = output_desc->shape();
    size_t embedding_dim = weight_shape[1];
    CHECK_OR_RETURN(output_shape.back() == embedding_dim, INFINI_STATUS_BAD_TENSOR_SHAPE);
    for (size_t i = 0; i < input_shape.size(); ++i) {
        CHECK_OR_RETURN(output_shape[i] == input_shape[i], INFINI_STATUS_BAD_TENSOR_SHAPE);
    }

    // Validate dtypes
    auto input_dtype = input_desc->dtype();
    auto weight_dtype = weight_desc->dtype();
    CHECK_OR_RETURN(input_dtype == INFINI_DTYPE_I32 || input_dtype == INFINI_DTYPE_I64,
                    INFINI_STATUS_BAD_TENSOR_DTYPE);
    CHECK_OR_RETURN(weight_dtype == INFINI_DTYPE_F32 || weight_dtype == INFINI_DTYPE_F16 || weight_dtype == INFINI_DTYPE_BF16,
                    INFINI_STATUS_BAD_TENSOR_DTYPE);
    CHECK_OR_RETURN(output_desc->dtype() == weight_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE);

    // Calculate number of indices (supporting batch dimension)
    size_t num_indices = 1;
    for (auto dim : input_shape) {
        num_indices *= dim;
    }
    size_t vocab_size = weight_shape[0];

    *desc_ptr = new Descriptor(
        num_indices, embedding_dim, vocab_size, input_dtype, weight_dtype,
        new Opaque{reinterpret_cast<device::nvidia::Handle *>(handle)->internal()},
        handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}

infiniStatus_t Descriptor::calculate(
    void *output,
    const void *input,
    const void *weight,
    void *stream) const {
    if (_num_indices == 0) {
        return INFINI_STATUS_SUCCESS;
    }

    auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);

    // Dynamic block size optimization based on embedding_dim:
    // smaller embedding_dim benefits from a larger block size (better occupancy),
    // larger embedding_dim from a smaller block size (more registers per thread).
    size_t block_size = 256; // Default
    if (_embedding_dim <= 64) {
        block_size = 512; // Small embedding_dim: larger block for better occupancy
    } else if (_embedding_dim >= 1024) {
        block_size = 128; // Large embedding_dim: smaller block to reduce register pressure
    }
    size_t grid_size = (_num_indices + block_size - 1) / block_size;

    // Launch kernel based on dtypes
    if (_input_dtype == INFINI_DTYPE_I32) {
        const int32_t *indices_ptr = reinterpret_cast<const int32_t *>(input);
        if (_weight_dtype == INFINI_DTYPE_F32) {
            embeddingKernel<float, int32_t><<<grid_size, block_size, 0, cuda_stream>>>(
                reinterpret_cast<float *>(output), indices_ptr,
                reinterpret_cast<const float *>(weight),
                _num_indices, _embedding_dim, _vocab_size);
        } else if (_weight_dtype == INFINI_DTYPE_F16) {
            embeddingKernel<half, int32_t><<<grid_size, block_size, 0, cuda_stream>>>(
                reinterpret_cast<half *>(output), indices_ptr,
                reinterpret_cast<const half *>(weight),
                _num_indices, _embedding_dim, _vocab_size);
        } else if (_weight_dtype == INFINI_DTYPE_BF16) {
            embeddingKernel<cuda_bfloat16, int32_t><<<grid_size, block_size, 0, cuda_stream>>>(
                reinterpret_cast<cuda_bfloat16 *>(output), indices_ptr,
                reinterpret_cast<const cuda_bfloat16 *>(weight),
                _num_indices, _embedding_dim, _vocab_size);
        } else {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
    } else if (_input_dtype == INFINI_DTYPE_I64) {
        const int64_t *indices_ptr = reinterpret_cast<const int64_t *>(input);
        if (_weight_dtype == INFINI_DTYPE_F32) {
            embeddingKernel<float, int64_t><<<grid_size, block_size, 0, cuda_stream>>>(
                reinterpret_cast<float *>(output), indices_ptr,
                reinterpret_cast<const float *>(weight),
                _num_indices, _embedding_dim, _vocab_size);
        } else if (_weight_dtype == INFINI_DTYPE_F16) {
            embeddingKernel<half, int64_t><<<grid_size, block_size, 0, cuda_stream>>>(
                reinterpret_cast<half *>(output), indices_ptr,
                reinterpret_cast<const half *>(weight),
                _num_indices, _embedding_dim, _vocab_size);
        } else if (_weight_dtype == INFINI_DTYPE_BF16) {
            embeddingKernel<cuda_bfloat16, int64_t><<<grid_size, block_size, 0, cuda_stream>>>(
                reinterpret_cast<cuda_bfloat16 *>(output), indices_ptr,
                reinterpret_cast<const cuda_bfloat16 *>(weight),
                _num_indices, _embedding_dim, _vocab_size);
        } else {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
    } else {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

    // Check for kernel launch errors
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        return INFINI_STATUS_INTERNAL_ERROR;
    }
    return INFINI_STATUS_SUCCESS;
}

} // namespace op::embedding::nvidia
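The copy helpers called by the kernel (copyVectorizedFloat4, copyVectorizedFloat2, copyVectorizedHalf2, copyVectorizedBFloat162, copyScalar) live in the shared ../cuda/embedding_kernel.cuh header, which is not part of this page. As a rough sketch of what the float4 variant might look like, inferred only from the call sites above (the signature, including the unused IndexType parameter, is an assumption):

// Hypothetical sketch only; the real helper is defined in cuda/embedding_kernel.cuh.
template <typename IndexType>
__device__ __forceinline__ void copyVectorizedFloat4(float *dst, const float *src,
                                                     size_t embedding_dim) {
    // The caller has already verified 16-byte alignment of both pointers, so we
    // can reinterpret them as float4 and move 16 bytes per iteration.
    const float4 *src4 = reinterpret_cast<const float4 *>(src);
    float4 *dst4 = reinterpret_cast<float4 *>(dst);
    for (size_t i = 0; i < embedding_dim / 4; ++i) {
        dst4[i] = __ldg(&src4[i]); // route the load through the read-only cache
    }
}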
src/infiniop/ops/embedding/nvidia/embedding_nvidia.cuh (new file, mode 100644)
#ifndef __EMBEDDING_CUDA_H__
#define __EMBEDDING_CUDA_H__

#include "../embedding.h"

DESCRIPTOR(nvidia)

#endif // __EMBEDDING_CUDA_H__
src/infiniop/ops/embedding/operator.cc (new file, mode 100644)
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/embedding.h"
#ifdef ENABLE_CPU_API
#include "cpu/embedding_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) || defined(ENABLE_ALI_API)
#include "nvidia/embedding_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
#include "metax/embedding_metax.cuh"
#endif
#ifdef ENABLE_MOORE_API
#include "moore/embedding_moore.h"
#endif
__C
infiniStatus_t
infiniopCreateEmbeddingDescriptor
(
infiniopHandle_t
handle
,
infiniopEmbeddingDescriptor_t
*
desc_ptr
,
infiniopTensorDescriptor_t
output_desc
,
infiniopTensorDescriptor_t
input_desc
,
infiniopTensorDescriptor_t
weight_desc
)
{
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::embedding::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::embedding::NAMESPACE::Descriptor **>(desc_ptr), \
output_desc, \
input_desc, \
weight_desc)
switch
(
handle
->
device
)
{
#ifdef ENABLE_CPU_API
CREATE
(
INFINI_DEVICE_CPU
,
cpu
);
#endif
#ifdef ENABLE_NVIDIA_API
CREATE
(
INFINI_DEVICE_NVIDIA
,
nvidia
);
#endif
#ifdef ENABLE_ILUVATAR_API
CREATE
(
INFINI_DEVICE_ILUVATAR
,
nvidia
);
#endif
#ifdef ENABLE_ALI_API
CREATE
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
#ifdef ENABLE_QY_API
CREATE
(
INFINI_DEVICE_QY
,
nvidia
);
#endif
#ifdef ENABLE_HYGON_API
CREATE
(
INFINI_DEVICE_HYGON
,
nvidia
);
#endif
#ifdef ENABLE_METAX_API
CREATE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_MOORE_API
CREATE
(
INFINI_DEVICE_MOORE
,
moore
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
#undef CREATE
}
__C
infiniStatus_t
infiniopEmbedding
(
infiniopEmbeddingDescriptor_t
desc
,
void
*
output
,
const
void
*
input
,
const
void
*
weight
,
void
*
stream
)
{
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<const op::embedding::NAMESPACE::Descriptor *>(desc) \
->calculate(output, input, weight, stream)
switch
(
desc
->
device_type
)
{
#ifdef ENABLE_CPU_API
CALCULATE
(
INFINI_DEVICE_CPU
,
cpu
);
#endif
#ifdef ENABLE_NVIDIA_API
CALCULATE
(
INFINI_DEVICE_NVIDIA
,
nvidia
);
#endif
#ifdef ENABLE_ILUVATAR_API
CALCULATE
(
INFINI_DEVICE_ILUVATAR
,
nvidia
);
#endif
#ifdef ENABLE_ALI_API
CALCULATE
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
#ifdef ENABLE_QY_API
CALCULATE
(
INFINI_DEVICE_QY
,
nvidia
);
#endif
#ifdef ENABLE_HYGON_API
CALCULATE
(
INFINI_DEVICE_HYGON
,
nvidia
);
#endif
#ifdef ENABLE_METAX_API
CALCULATE
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_MOORE_API
CALCULATE
(
INFINI_DEVICE_MOORE
,
moore
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
#undef CALCULATE
}
__C
infiniStatus_t
infiniopDestroyEmbeddingDescriptor
(
infiniopEmbeddingDescriptor_t
desc
)
{
#define DESTROY(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<const op::embedding::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS;
switch
(
desc
->
device_type
)
{
#ifdef ENABLE_CPU_API
DESTROY
(
INFINI_DEVICE_CPU
,
cpu
);
#endif
#ifdef ENABLE_NVIDIA_API
DESTROY
(
INFINI_DEVICE_NVIDIA
,
nvidia
);
#endif
#ifdef ENABLE_ILUVATAR_API
DESTROY
(
INFINI_DEVICE_ILUVATAR
,
nvidia
);
#endif
#ifdef ENABLE_ALI_API
DESTROY
(
INFINI_DEVICE_ALI
,
nvidia
);
#endif
#ifdef ENABLE_QY_API
DESTROY
(
INFINI_DEVICE_QY
,
nvidia
);
#endif
#ifdef ENABLE_HYGON_API
DESTROY
(
INFINI_DEVICE_HYGON
,
nvidia
);
#endif
#ifdef ENABLE_METAX_API
DESTROY
(
INFINI_DEVICE_METAX
,
metax
);
#endif
#ifdef ENABLE_MOORE_API
DESTROY
(
INFINI_DEVICE_MOORE
,
moore
);
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
#undef DESTROY
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
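For orientation, a minimal sketch of how a caller might drive this API. Only the three embedding entry points above come from this commit; the handle, device buffers, and tensor descriptors are placeholders assumed to be set up elsewhere in InfiniCore:

// Sketch only: handle/descriptor setup is assumed to happen before this call.
infiniStatus_t run_embedding(infiniopHandle_t handle,
                             infiniopTensorDescriptor_t output_desc, // [batch..., embedding_dim], same dtype as weight
                             infiniopTensorDescriptor_t input_desc,  // I32/I64 indices, any batch shape
                             infiniopTensorDescriptor_t weight_desc, // [vocab_size, embedding_dim], F32/F16/BF16
                             void *d_output, const void *d_indices,
                             const void *d_weight, void *stream) {
    infiniopEmbeddingDescriptor_t desc;
    infiniStatus_t status = infiniopCreateEmbeddingDescriptor(
        handle, &desc, output_desc, input_desc, weight_desc);
    if (status != INFINI_STATUS_SUCCESS) {
        return status;
    }
    // Gathers one weight row per index; out-of-range indices are skipped by the kernel.
    status = infiniopEmbedding(desc, d_output, d_indices, d_weight, stream);
    infiniopDestroyEmbeddingDescriptor(desc);
    return status;
}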
src/infiniop/ops/flash_attention/ninetoothed/build.py (new file, mode 100644)
import ninetoothed
from . import flash_attention
from .flash_attention import CausalVariant
import infiniop.ninetoothed.build
import torch
import os


def build():
    env_vars_to_check = ["MACA_HOME", "MACA_PATH", "MACA_ROOT"]
    if any(var in os.environ for var in env_vars_to_check):
        return

    with_kv_cache_values = (0,)
    emb_dim_values = (16, 32, 64, 128, 256)
    is_causal_values = (0, 1)
    with_attn_mask_values = (0,)
    causal_variant_values = (CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT)
    dtype_values = (ninetoothed.float16, ninetoothed.bfloat16, ninetoothed.float32)
    block_size_m_values = (256,)
    block_size_n_values = (64,)

    constexpr_param_grid = {
        "with_kv_cache": with_kv_cache_values,
        "emb_dim": emb_dim_values,
        "is_causal": is_causal_values,
        "with_attn_mask": with_attn_mask_values,
        "causal_variant": causal_variant_values,
        "dtype": dtype_values,
        "block_size_m": block_size_m_values,
        "block_size_n": block_size_n_values,
    }

    infiniop.ninetoothed.build.build(
        flash_attention.premake,
        constexpr_param_grid,
        caller="cuda",
        op_name="flash_attention",
        output_dir=infiniop.ninetoothed.build.BUILD_DIRECTORY_PATH,
    )
src/infiniop/ops/flash_attention/ninetoothed/descriptor.h (new file, mode 100644)
#ifndef __FLASH_ATTENTION_DESCRIPTOR_H__
#define __FLASH_ATTENTION_DESCRIPTOR_H__

#include "../../../handle.h"
#include "../../../operator.h"
#include "../../../tensor.h"

#include "../../../../../build/ninetoothed/flash_attention.h"
#include "../../../ninetoothed/utils.h"

namespace op::flash_attention::ninetoothed {

class Descriptor final : public InfiniopDescriptor {
public:
    Descriptor(infiniopHandle_t handle,
               infiniopTensorDescriptor_t out_desc,
               infiniopTensorDescriptor_t q_desc,
               infiniopTensorDescriptor_t k_desc,
               infiniopTensorDescriptor_t v_desc,
               infiniopTensorDescriptor_t total_kv_len,
               double scale,
               char is_causal)
        : InfiniopDescriptor{handle->device, handle->device_id},
          _query_shape{q_desc->shape()},
          _query_strides{q_desc->strides()},
          _key_shape{k_desc->shape()},
          _key_strides{k_desc->strides()},
          _value_shape{v_desc->shape()},
          _value_strides{v_desc->strides()},
          _total_kv_shape{total_kv_len->shape()},
          _total_kv_strides{total_kv_len->strides()},
          _output_strides{out_desc->strides()},
          _dtype{q_desc->dtype()},
          _scale{scale},
          _is_causal{is_causal} {}

    ~Descriptor() = default;

    size_t get_workspace_size() const {
        return 0;
    }

    infiniStatus_t calculate(
        void *workspace,
        size_t workspace_size,
        void *out,
        const void *q,
        const void *k,
        const void *v,
        const void *total_kv_len,
        void *stream) const {
        uint64_t empty_shape[4];
        int64_t empty_strides[4];

        auto query{::ninetoothed::Tensor{q, _query_shape, _query_strides}};
        auto key{::ninetoothed::Tensor{k, _key_shape, _key_strides}};
        auto value{::ninetoothed::Tensor{v, _value_shape, _value_strides}};
        auto total_kv_length{::ninetoothed::Tensor{total_kv_len, _total_kv_shape, _total_kv_strides}};
        NineToothedTensor attn_mask{nullptr, empty_shape, empty_strides};
        NineToothedTensor is_causal;
        NineToothedTensor scale{const_cast<double *>(&_scale), nullptr, nullptr};
        auto output{::ninetoothed::Tensor{out, _query_shape, _output_strides}};
        NineToothedTensor with_attn_mask;
        NineToothedTensor causal_variant;

        const auto with_kv_cache_{0};
        const auto emb_dim_{_query_shape[3]};
        const auto is_causal_{_is_causal};
        const auto with_attn_mask_{0};
        const auto causal_variant_{2};
        const auto dtype_{_dtype};
        constexpr auto block_size_m_{256};
        constexpr auto block_size_n_{64};

        if (launch_flash_attention(
                stream,
                query, key, value, total_kv_length,
                attn_mask, is_causal, scale, output,
                with_attn_mask, causal_variant,
                with_kv_cache_, emb_dim_, is_causal_, with_attn_mask_,
                causal_variant_, dtype_, block_size_m_, block_size_n_)) {
            return INFINI_STATUS_NOT_IMPLEMENTED;
        }
        return INFINI_STATUS_SUCCESS;
    }

    static infiniStatus_t create(
        infiniopHandle_t handle,
        Descriptor **desc,
        infiniopTensorDescriptor_t out_desc,
        infiniopTensorDescriptor_t q_desc,
        infiniopTensorDescriptor_t k_desc,
        infiniopTensorDescriptor_t v_desc,
        infiniopTensorDescriptor_t total_kv_len,
        double scale,
        char is_causal) {
        *desc = new Descriptor{handle, out_desc, q_desc, k_desc, v_desc,
                               total_kv_len, scale, is_causal};
        return INFINI_STATUS_SUCCESS;
    }

private:
    using Size = ::ninetoothed::Tensor<>::Size;
    using Stride = ::ninetoothed::Tensor<>::Stride;

    std::vector<Size> _query_shape;
    std::vector<Stride> _query_strides;
    std::vector<Size> _key_shape;
    std::vector<Stride> _key_strides;
    std::vector<Size> _value_shape;
    std::vector<Stride> _value_strides;
    std::vector<Size> _total_kv_shape;
    std::vector<Stride> _total_kv_strides;
    std::vector<Stride> _output_strides;
    infiniDtype_t _dtype;
    double _scale;
    char _is_causal;
};

} // namespace op::flash_attention::ninetoothed

#endif // __FLASH_ATTENTION_DESCRIPTOR_H__
src/infiniop/ops/flash_attention/ninetoothed/flash_attention.py (new file, mode 100644)
import enum
import functools

import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Tensor

BLOCK_SIZE_M = ninetoothed.block_size()
BLOCK_SIZE_N = ninetoothed.block_size()


class CausalVariant(enum.IntEnum):
    """Please refer to `<https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.bias.CausalVariant.html>`_."""

    UPPER_LEFT = enum.auto()
    LOWER_RIGHT = enum.auto()


def arrangement(
    query,
    key,
    value,
    total_kv_len,
    present_key,
    present_value,
    present_key_slot,
    present_value_slot,
    attn_mask,
    is_causal,
    scale,
    output,
    with_attn_mask,
    causal_variant,
    with_kv_cache,
    block_size_m=None,
    block_size_n=None,
):
    def arrange_query_or_output(input):
        arranged = input.tile((1, 1, block_size_m, -1)).tile(
            (1, query.shape[-3] // key.shape[-3], 1, 1)
        )
        arranged.dtype = arranged.dtype.squeeze((0, 2, 3))
        arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1))
        return arranged

    def arrange_key_or_value(input):
        arranged = (
            input.tile((1, 1, block_size_n, -1))
            .tile((1, 1, -1, -1))
            .expand((-1, -1, query_arranged.shape[-2], -1))
        )
        arranged.dtype = arranged.dtype.squeeze((0, 1, 3))
        arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1))
        return arranged

    def arrange_total_kv_len(input, shape):
        arranged = input.tile((1,))
        arranged = arranged.unsqueeze(1).unsqueeze(2).unsqueeze(3).expand(shape)
        return arranged

    def arrange_present_key_or_present_value(input):
        arranged = input.tile((1, 1, block_size_m, block_size_n))
        arranged.dtype = arranged.dtype.squeeze((0, 1))
        return arranged

    def arrange_attn_mask(input):
        arranged = input.tile((1, 1, block_size_m, block_size_n)).tile((1, 1, 1, -1))
        arranged.dtype = arranged.dtype.squeeze((0, 1, 2))
        arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1))
        return arranged

    if block_size_m is None:
        block_size_m = BLOCK_SIZE_M

    if block_size_n is None:
        block_size_n = BLOCK_SIZE_N

    query_arranged = arrange_query_or_output(query)
    key_arranged = arrange_key_or_value(key)
    value_arranged = arrange_key_or_value(value)
    total_kv_len_arranged = arrange_total_kv_len(total_kv_len, query_arranged.shape)
    present_key_arranged = arrange_present_key_or_present_value(present_key)
    present_value_arranged = arrange_present_key_or_present_value(present_value)
    present_key_slot_arranged = arrange_present_key_or_present_value(present_key_slot)
    present_value_slot_arranged = arrange_present_key_or_present_value(present_value_slot)
    attn_mask_arranged = arrange_attn_mask(attn_mask)
    is_causal_arranged = is_causal
    scale_arranged = scale
    output_arranged = arrange_query_or_output(output)
    with_attn_mask_arranged = with_attn_mask
    causal_variant_arranged = causal_variant

    if with_kv_cache:
        return (
            query_arranged,
            key_arranged,
            value_arranged,
            total_kv_len_arranged,
            present_key_arranged,
            present_value_arranged,
            present_key_slot_arranged,
            present_value_slot_arranged,
            attn_mask_arranged,
            is_causal_arranged,
            scale_arranged,
            output_arranged,
            with_attn_mask_arranged,
            causal_variant_arranged,
        )

    return (
        query_arranged,
        key_arranged,
        value_arranged,
        total_kv_len_arranged,
        attn_mask_arranged,
        is_causal_arranged,
        scale_arranged,
        output_arranged,
        with_attn_mask_arranged,
        causal_variant_arranged,
    )


def application_with_kv_cache(
    query,
    key,
    value,
    total_kv_len,
    present_key,
    present_value,
    present_key_slot,
    present_value_slot,
    attn_mask,
    is_causal,
    scale,
    output,
    with_attn_mask,
    causal_variant,
):
    present_key_slot = present_key  # noqa: F841
    present_value_slot = present_value  # noqa: F841

    application_without_kv_cache(
        query,
        key,
        value,
        total_kv_len,
        attn_mask,
        is_causal,
        scale,
        output,
        with_attn_mask,
        causal_variant,
    )


def application_without_kv_cache(
    query,
    key,
    value,
    total_kv_len,
    attn_mask,
    is_causal,
    scale,
    output,
    with_attn_mask,
    causal_variant,
):
    actual_kv_len = total_kv_len[0]

    for i in range(query.shape[0]):
        query_i = (1.4426950408889634 * scale * query[i]).to(query[i].dtype)

        acc = ntl.zeros((query_i.shape[-2], query_i.shape[-1]), dtype=ntl.float32)
        lse = ntl.full((query_i.shape[-2],), 1, dtype=ntl.float32)
        max = ntl.full((query_i.shape[-2],), float("-inf"), dtype=ntl.float32)

        for j in range(-(-actual_kv_len // key.dtype.shape[0])):
            qk = ntl.dot(query_i, ntl.trans(key[j]))

            key_pos = key[j].offsets(-2)
            qk = ntl.where(key_pos < actual_kv_len, qk, float("-inf"))

            if with_attn_mask:
                qk += attn_mask[j]

            if is_causal:
                query_pos = query[i].offsets(-2)

                if causal_variant == 2:  # CausalVariant.LOWER_RIGHT
                    mask = (
                        query_pos[:, None] + actual_kv_len - query.source.shape[-2]
                        >= key_pos[None, :]
                    )
                else:
                    mask = query_pos[:, None] >= key_pos[None, :]

                qk = ntl.where(mask, qk, float("-inf"))

            next_max = ntl.maximum(max, ntl.max(qk, 1))
            stable_qk = ntl.exp2(qk - next_max[:, None])
            alpha = ntl.exp2(max - next_max)

            acc = acc * alpha[:, None] + ntl.dot(stable_qk.to(value[i].dtype), value[j])
            max = next_max
            lse = lse * alpha + ntl.sum(stable_qk, 1)

        acc /= lse[:, None]

        output[i] = acc  # noqa: F841


def premake(
    with_kv_cache,
    emb_dim=None,
    is_causal=None,
    with_attn_mask=None,
    causal_variant=None,
    dtype=None,
    block_size_m=None,
    block_size_n=None,
):
    arrangement_ = functools.partial(
        arrangement,
        with_kv_cache=with_kv_cache,
        block_size_m=block_size_m,
        block_size_n=block_size_n,
    )

    query, key, value, attn_mask, output = (
        Tensor(
            4,
            dtype=dtype,
            shape_options=(None, None, None, {"constexpr": True, "upper_bound": 128}),
        )
        for _ in range(5)
    )
    total_kv_len = Tensor(1, dtype=ninetoothed.int32)
    present_key, present_value, present_key_slot, present_value_slot = (
        Tensor(4, dtype=dtype) for _ in range(4)
    )
    scale = Tensor(0, dtype=ninetoothed.float64)
    is_causal = Tensor(0, constexpr=True, value=is_causal)
    with_attn_mask = Tensor(0, constexpr=True, value=with_attn_mask)
    causal_variant = Tensor(0, constexpr=True, value=causal_variant)

    if emb_dim is not None:
        for tensor in (query, key, value, attn_mask, output):
            tensor.shape = tensor.shape[:-1] + (emb_dim,)

    if with_kv_cache:
        application = application_with_kv_cache
    else:
        application = application_without_kv_cache

    tensors = (
        query,
        key,
        value,
        total_kv_len,
        present_key,
        present_value,
        present_key_slot,
        present_value_slot,
        attn_mask,
        is_causal,
        scale,
        output,
        with_attn_mask,
        causal_variant,
    )

    return arrangement_, application, tensors
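The inner loop of application_without_kv_cache is the standard streaming ("online") softmax: a running maximum, a running normalizer lse, and an accumulator, all rescaled whenever the maximum grows; the 1.4426950408889634 factor is log2(e), which lets the kernel use the cheaper exp2 in place of exp. For intuition only, a scalar C++ sketch of the same recurrence over one query row, scalar values, and plain exp (this illustrates the math, not the generated kernel):

// Scalar illustration of the online-softmax recurrence, not the real kernel.
#include <cmath>
#include <vector>

float attention_row(const std::vector<float> &scores, const std::vector<float> &values) {
    float m = -INFINITY; // running maximum of the scores seen so far
    float l = 0.0f;      // running sum of exp(score - m)
    float acc = 0.0f;    // running weighted sum of values, in the same scale as l
    for (size_t j = 0; j < scores.size(); ++j) {
        float next_m = std::fmax(m, scores[j]);
        float alpha = std::exp(m - next_m);     // rescale factor for the old terms
        float p = std::exp(scores[j] - next_m); // stabilized weight of the new term
        acc = acc * alpha + p * values[j];
        l = l * alpha + p;
        m = next_m;
    }
    return acc / l; // equals softmax(scores) . values, computed in one pass
}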
src/infiniop/ops/flash_attention/operator.cc (new file, mode 100644)
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/flash_attention.h"
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
#include "ninetoothed/descriptor.h"
#endif
#endif
__C
infiniStatus_t
infiniopCreateFlashAttentionDescriptor
(
infiniopHandle_t
handle
,
infiniopFlashAttentionDescriptor_t
*
desc_ptr
,
infiniopTensorDescriptor_t
out_desc
,
infiniopTensorDescriptor_t
q_desc
,
infiniopTensorDescriptor_t
k_desc
,
infiniopTensorDescriptor_t
v_desc
,
infiniopTensorDescriptor_t
total_kv_len
,
float
scale
,
char
is_causal
)
{
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::flash_attention::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::flash_attention::NAMESPACE::Descriptor **>(desc_ptr), \
out_desc, \
q_desc, \
k_desc, \
v_desc, \
total_kv_len, \
scale, \
is_causal);
switch
(
handle
->
device
)
{
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
CREATE
(
INFINI_DEVICE_NVIDIA
,
ninetoothed
);
#endif
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
#undef CREATE
}
__C
infiniStatus_t
infiniopGetFlashAttentionWorkspaceSize
(
infiniopFlashAttentionDescriptor_t
desc
,
size_t
*
size
)
{
#define GET_SIZE(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<const op::flash_attention::NAMESPACE::Descriptor *>(desc) \
->get_workspace_size(); \
return INFINI_STATUS_SUCCESS;
switch
(
desc
->
device_type
)
{
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
GET_SIZE
(
INFINI_DEVICE_NVIDIA
,
ninetoothed
);
#endif
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
#undef GET_SIZE
}
__C
infiniStatus_t
infiniopFlashAttention
(
infiniopFlashAttentionDescriptor_t
desc
,
void
*
workspace
,
size_t
workspace_size
,
void
*
out
,
const
void
*
q
,
const
void
*
k
,
const
void
*
v
,
const
void
*
total_kv_len
,
void
*
stream
)
{
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<const op::flash_attention::NAMESPACE::Descriptor *>(desc) \
->calculate(workspace, workspace_size, out, q, k, v, total_kv_len, stream);
switch
(
desc
->
device_type
)
{
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
CALCULATE
(
INFINI_DEVICE_NVIDIA
,
ninetoothed
);
#endif
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
#undef CALCULATE
}
__C
infiniStatus_t
infiniopDestroyFlashAttentionDescriptor
(
infiniopFlashAttentionDescriptor_t
desc
)
{
#define DESTROY(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<op::flash_attention::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS;
switch
(
desc
->
device_type
)
{
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
DESTROY
(
INFINI_DEVICE_NVIDIA
,
ninetoothed
);
#endif
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
#undef DESTROY
}
src/infiniop/ops/gelu/operator.cc (modified)
@@ -5,7 +5,7 @@
 #ifdef ENABLE_CPU_API
 #include "cpu/gelu_cpu.h"
 #endif
-#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API)
 #include "nvidia/gelu_nvidia.cuh"
 #endif
 #ifdef ENABLE_METAX_API
@@ -49,6 +49,9 @@ __C infiniStatus_t infiniopCreateGeluDescriptor(
 #ifdef ENABLE_KUNLUN_API
         CREATE(INFINI_DEVICE_KUNLUN, kunlun);
 #endif
+#ifdef ENABLE_ALI_API
+        CREATE(INFINI_DEVICE_ALI, nvidia);
+#endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -83,6 +86,10 @@ __C infiniStatus_t infiniopGetGeluWorkspaceSize(infiniopGeluDescriptor_t desc, s
 #ifdef ENABLE_KUNLUN_API
         GET(INFINI_DEVICE_KUNLUN, kunlun);
 #endif
+#ifdef ENABLE_ALI_API
+        GET(INFINI_DEVICE_ALI, nvidia);
+#endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
     }
@@ -124,6 +131,9 @@ __C infiniStatus_t infiniopGelu(
 #ifdef ENABLE_KUNLUN_API
         CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
 #endif
+#ifdef ENABLE_ALI_API
+        CALCULATE(INFINI_DEVICE_ALI, nvidia);
+#endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -160,6 +170,9 @@ infiniopDestroyGeluDescriptor(infiniopGeluDescriptor_t desc) {
 #ifdef ENABLE_KUNLUN_API
         DELETE(INFINI_DEVICE_KUNLUN, kunlun);
 #endif
+#ifdef ENABLE_ALI_API
+        DELETE(INFINI_DEVICE_ALI, nvidia);
+#endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
src/infiniop/ops/gemm/bang/gemm_bang.cc (modified)
@@ -15,8 +15,8 @@ struct Descriptor::Opaque {
         cnnlDestroyTensorDescriptor(a);
         cnnlDestroyTensorDescriptor(b);
         cnnlDestroyTensorDescriptor(c);
-        cnnlMatMulDescDestroy(op);
-        cnnlMatMulAlgoDestroy(algo);
+        cnnlDestroyMatMulDescriptor(op);
+        cnnlDestroyMatMulAlgo(algo);
         cnnlDestroyMatMulHeuristicResult(algoResult);
     }
 };
@@ -85,8 +85,8 @@ infiniStatus_t Descriptor::create(
     cnnlMatMulDescriptor_t op;
     cnnlMatMulAlgo_t algo;
     cnnlMatMulHeuristicResult_t algoResult;
-    CHECK_BANG(cnnlMatMulDescCreate(&op));
-    CHECK_BANG(cnnlMatMulAlgoCreate(&algo));
+    CHECK_BANG(cnnlCreateMatMulDescriptor(&op));
+    CHECK_BANG(cnnlCreateMatMulAlgo(&algo));
     CHECK_BANG(cnnlCreateMatMulHeuristicResult(&algoResult));
     int32_t use_stride = true;
     CHECK_BANG(cnnlSetMatMulDescAttr(
@@ -101,7 +101,7 @@ infiniStatus_t Descriptor::create(
         (cnrtQueue_t)nullptr,
         [&](cnnlHandle_t _handle) {
-            CHECK_BANG(cnnlGetBatchMatMulAlgoHeuristic(
+            CHECK_BANG(cnnlGetBatchMatMulExAlgoHeuristic(
                 _handle, op, a, b, c, NULL, 1, &algoResult, &count));
@@ -109,7 +109,7 @@ infiniStatus_t Descriptor::create(
         }));
     size_t workspace_size;
-    CHECK_BANG(cnnlGetBatchMatMulHeuristicResult(algoResult, algo, &workspace_size));
+    CHECK_BANG(cnnlGetBatchMatMulExHeuristicResult(algoResult, algo, &workspace_size));
     *desc_ptr = new Descriptor(
         dtype,
         info,
         workspace_size,
@@ -135,7 +135,7 @@ infiniStatus_t Descriptor::calculate(
     CHECK_STATUS(_opaque->internal->useCnnl(
         (cnrtQueue_t)stream,
         [&](cnnlHandle_t handle) {
-            CHECK_BANG(cnnlBatchMatMulBCast_v2(
+            CHECK_BANG(cnnlBatchMatMulEx(
                 handle,
                 _opaque->op,
                 _opaque->algo,
src/infiniop/ops/gemm/operator.cc (modified)
@@ -5,7 +5,7 @@
 #ifdef ENABLE_CPU_API
 #include "cpu/gemm_cpu.h"
 #endif
-#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) || defined(ENABLE_ALI_API)
 #include "nvidia/gemm_nvidia.cuh"
 #endif
 #ifdef ENABLE_CAMBRICON_API
@@ -51,6 +51,9 @@ __C infiniStatus_t infiniopCreateGemmDescriptor(
 #ifdef ENABLE_ILUVATAR_API
         CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
 #endif
+#ifdef ENABLE_ALI_API
+        CREATE(INFINI_DEVICE_ALI, nvidia);
+#endif
 #ifdef ENABLE_QY_API
         CREATE(INFINI_DEVICE_QY, nvidia);
 #endif
@@ -102,6 +105,9 @@ infiniopGetGemmWorkspaceSize(
 #ifdef ENABLE_ILUVATAR_API
         GET(INFINI_DEVICE_ILUVATAR, nvidia);
 #endif
+#ifdef ENABLE_ALI_API
+        GET(INFINI_DEVICE_ALI, nvidia);
+#endif
 #ifdef ENABLE_QY_API
         GET(INFINI_DEVICE_QY, nvidia);
 #endif
@@ -160,6 +166,9 @@ __C infiniStatus_t infiniopGemm(
 #ifdef ENABLE_ILUVATAR_API
         CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
 #endif
+#ifdef ENABLE_ALI_API
+        CALCULATE(INFINI_DEVICE_ALI, nvidia);
+#endif
 #ifdef ENABLE_QY_API
         CALCULATE(INFINI_DEVICE_QY, nvidia);
 #endif
@@ -208,6 +217,9 @@ infiniopDestroyGemmDescriptor(infiniopGemmDescriptor_t desc) {
 #ifdef ENABLE_ILUVATAR_API
         DELETE(INFINI_DEVICE_ILUVATAR, nvidia);
 #endif
+#ifdef ENABLE_ALI_API
+        DELETE(INFINI_DEVICE_ALI, nvidia);
+#endif
 #ifdef ENABLE_QY_API
         DELETE(INFINI_DEVICE_QY, nvidia);
 #endif
src/infiniop/ops/kv_caching/ninetoothed/build.py (new file, mode 100644)
import ninetoothed
from . import kv_caching
import infiniop.ninetoothed.build


def build():
    dtype_values = (
        ninetoothed.float16,
        ninetoothed.bfloat16,
        ninetoothed.float32,
    )

    constexpr_param_grid = {
        "emb_dim": (1, 16, 32, 64, 128, 256),
        "dtype": dtype_values,
        "block_size_m": (64,),
        "block_size_n": (64,),
    }

    infiniop.ninetoothed.build.build(
        kv_caching.premake,
        constexpr_param_grid,
        caller="cuda",
        op_name="kv_caching",
        output_dir=infiniop.ninetoothed.build.BUILD_DIRECTORY_PATH,
    )
src/infiniop/ops/kv_caching/ninetoothed/kv_caching.h (new file, mode 100644)
#ifndef KV_CACHING_H
#define KV_CACHING_H

#include "../../../handle.h"
#include "../../../operator.h"
#include "../../../tensor.h"

#include "../../../../../build/ninetoothed/kv_caching.h"
#include "../../../ninetoothed/utils.h"

namespace op::kv_caching::ninetoothed {

class Descriptor final : public InfiniopDescriptor {
public:
    Descriptor(infiniopHandle_t handle,
               infiniopTensorDescriptor_t k_cache_desc,
               infiniopTensorDescriptor_t v_cache_desc,
               infiniopTensorDescriptor_t k_desc,
               infiniopTensorDescriptor_t v_desc,
               infiniopTensorDescriptor_t past_kv_lengths_desc)
        : InfiniopDescriptor{handle->device, handle->device_id},
          k_cache_shape_{k_cache_desc->shape()},
          k_cache_strides_{k_cache_desc->strides()},
          v_cache_shape_{v_cache_desc->shape()},
          v_cache_strides_{v_cache_desc->strides()},
          k_shape_{k_desc->shape()},
          k_strides_{k_desc->strides()},
          v_shape_{v_desc->shape()},
          v_strides_{v_desc->strides()},
          past_kv_lengths_shape_{past_kv_lengths_desc->shape()},
          past_kv_lengths_strides_{past_kv_lengths_desc->strides()},
          dtype_{k_desc->dtype()} {}

    ~Descriptor() = default;

    size_t get_workspace_size() const {
        return 0;
    }

    static infiniStatus_t create(
        infiniopHandle_t handle,
        Descriptor **desc_ptr,
        infiniopTensorDescriptor_t k_cache,
        infiniopTensorDescriptor_t v_cache,
        infiniopTensorDescriptor_t k,
        infiniopTensorDescriptor_t v,
        infiniopTensorDescriptor_t past_kv_lengths) {
        *desc_ptr = new Descriptor{handle, k_cache, v_cache, k, v, past_kv_lengths};
        return INFINI_STATUS_SUCCESS;
    }

    infiniStatus_t calculate(
        void *workspace,
        size_t workspace_size,
        void *k_cache,
        void *v_cache,
        const void *k,
        const void *v,
        const void *past_kv_lengths,
        void *stream) const {
        auto k_cache_nt{::ninetoothed::Tensor{k_cache, k_cache_shape_, k_cache_strides_}};
        auto v_cache_nt{::ninetoothed::Tensor{v_cache, v_cache_shape_, v_cache_strides_}};
        auto k_nt{::ninetoothed::Tensor{k, k_shape_, k_strides_}};
        auto v_nt{::ninetoothed::Tensor{v, v_shape_, v_strides_}};
        auto past_kv_lengths_nt{::ninetoothed::Tensor{
            past_kv_lengths, past_kv_lengths_shape_, past_kv_lengths_strides_}};

        if (launch_kv_caching(stream, k_cache_nt, v_cache_nt, k_nt, v_nt,
                              past_kv_lengths_nt, k_shape_[3], dtype_, 64, 64)) {
            return INFINI_STATUS_NOT_IMPLEMENTED;
        }
        return INFINI_STATUS_SUCCESS;
    }

private:
    using Size = ::ninetoothed::Tensor<>::Size;
    using Stride = ::ninetoothed::Tensor<>::Stride;

    std::vector<Size> k_cache_shape_;
    std::vector<Stride> k_cache_strides_;
    std::vector<Size> v_cache_shape_;
    std::vector<Stride> v_cache_strides_;
    std::vector<Size> k_shape_;
    std::vector<Stride> k_strides_;
    std::vector<Size> v_shape_;
    std::vector<Stride> v_strides_;
    std::vector<Size> past_kv_lengths_shape_;
    std::vector<Stride> past_kv_lengths_strides_;
    infiniDtype_t dtype_;
};

} // namespace op::kv_caching::ninetoothed

#endif // KV_CACHING_H
src/infiniop/ops/kv_caching/ninetoothed/kv_caching.py (new file, mode 100644)
import functools

import ninetoothed
from ninetoothed import Tensor


def arrangement(
    k_cache,
    v_cache,
    k,
    v,
    past_lengths,
    block_size_m=ninetoothed.block_size(),
    block_size_n=ninetoothed.block_size(),
):
    k_cache_arranged = k_cache.tile((1, block_size_m, 1, -1)).tile((1, 1, -1, 1))
    v_cache_arranged = v_cache.tile((1, block_size_m, 1, -1)).tile((1, 1, -1, 1))
    k_arranged = k.tile((1, block_size_m, 1, -1)).tile((1, 1, -1, 1))
    v_arranged = v.tile((1, block_size_m, 1, -1)).tile((1, 1, -1, 1))
    past_lengths_arranged = (
        past_lengths.tile((1,))
        .unsqueeze(1)
        .unsqueeze(2)
        .unsqueeze(3)
        .unsqueeze(4)
        .expand((-1, *k_arranged.shape))
    )

    return (
        k_cache_arranged,
        v_cache_arranged,
        k_arranged,
        v_arranged,
        past_lengths_arranged,
    )


def application(k_cache, v_cache, k, v, past_lengths):
    pos = past_lengths
    for i in range(k.shape[-2]):
        k_cache[0, 0, pos + i, 0] = k[0, 0, i, 0]
        v_cache[0, 0, pos + i, 0] = v[0, 0, i, 0]


def premake(emb_dim=None, dtype=None, block_size_m=None, block_size_n=None):
    arrangement_ = functools.partial(
        arrangement, block_size_m=block_size_m, block_size_n=block_size_n
    )

    shape_options = (None, None, None, {"constexpr": True, "upper_bound": 256})
    tensors = (
        Tensor(4, dtype=dtype, shape_options=shape_options),
        Tensor(4, dtype=dtype, shape_options=shape_options),
        Tensor(4, dtype=dtype, shape_options=shape_options),
        Tensor(4, dtype=dtype, shape_options=shape_options),
        Tensor(1, dtype=ninetoothed.int64),
    )

    if emb_dim is not None:
        for tensor in tensors:
            tensor.shape = tensor.shape[:-1] + (emb_dim,)

    return arrangement_, application, tensors
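Conceptually, application appends the incoming K/V rows into the caches at offset past_lengths along the sequence axis. A plain C++ sketch of the equivalent semantics for one (batch, head) slice, assuming contiguous [seq][emb_dim] layouts (illustration only; the real kernel operates on the tiled device tensors produced by arrangement):

// Conceptual host-side equivalent of the kv-caching kernel, per (batch, head).
void append_kv(float *k_cache, float *v_cache,  // [max_len][emb_dim], flattened
               const float *k, const float *v,  // [seq_len][emb_dim], flattened
               size_t past_len, size_t seq_len, size_t emb_dim) {
    for (size_t i = 0; i < seq_len; ++i) {
        for (size_t d = 0; d < emb_dim; ++d) {
            // Write new row i at cache position past_len + i.
            k_cache[(past_len + i) * emb_dim + d] = k[i * emb_dim + d];
            v_cache[(past_len + i) * emb_dim + d] = v[i * emb_dim + d];
        }
    }
}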
src/infiniop/ops/kv_caching/operator.cc (new file, mode 100644)
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/kv_caching.h"
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_METAX_API) || defined(ENABLE_MOORE_API)
#include "ninetoothed/kv_caching.h"
#endif
#endif
__C
infiniStatus_t
infiniopCreateKVCachingDescriptor
(
infiniopHandle_t
handle
,
infiniopKVCachingDescriptor_t
*
desc_ptr
,
infiniopTensorDescriptor_t
k_cache
,
infiniopTensorDescriptor_t
v_cache
,
infiniopTensorDescriptor_t
k
,
infiniopTensorDescriptor_t
v
,
infiniopTensorDescriptor_t
past_kv_lengths
)
{
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::kv_caching::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::kv_caching::NAMESPACE::Descriptor **>(desc_ptr), \
k_cache, \
v_cache, \
k, \
v, \
past_kv_lengths)
switch
(
handle
->
device
)
{
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
CREATE
(
INFINI_DEVICE_NVIDIA
,
ninetoothed
);
#endif
#if defined(ENABLE_ILUVATAR_API)
CREATE
(
INFINI_DEVICE_ILUVATAR
,
ninetoothed
);
#endif
#if defined(ENABLE_METAX_API)
CREATE
(
INFINI_DEVICE_METAX
,
ninetoothed
);
#endif
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
#undef CREATE
}
__C
infiniStatus_t
infiniopGetKVCachingWorkspaceSize
(
infiniopKVCachingDescriptor_t
desc
,
size_t
*
size
)
{
#define GET_SIZE(CASE, NAMESPACE) \
case CASE: \
*size = reinterpret_cast<const op::kv_caching::NAMESPACE::Descriptor *>(desc) \
->get_workspace_size(); \
return INFINI_STATUS_SUCCESS;
switch
(
desc
->
device_type
)
{
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
GET_SIZE
(
INFINI_DEVICE_NVIDIA
,
ninetoothed
);
#endif
#if defined(ENABLE_ILUVATAR_API)
GET_SIZE
(
INFINI_DEVICE_ILUVATAR
,
ninetoothed
);
#endif
#if defined(ENABLE_METAX_API)
GET_SIZE
(
INFINI_DEVICE_METAX
,
ninetoothed
);
#endif
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
#undef GET_SIZE
}
__C
infiniStatus_t
infiniopKVCaching
(
infiniopKVCachingDescriptor_t
desc
,
void
*
workspace
,
size_t
workspace_size
,
void
*
k_cache
,
void
*
v_cache
,
const
void
*
k
,
const
void
*
v
,
const
void
*
past_kv_lengths
,
void
*
stream
)
{
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<const op::kv_caching::NAMESPACE::Descriptor *>(desc) \
->calculate(workspace, workspace_size, k_cache, v_cache, k, v, past_kv_lengths, stream)
switch
(
desc
->
device_type
)
{
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
CALCULATE
(
INFINI_DEVICE_NVIDIA
,
ninetoothed
);
#endif
#if defined(ENABLE_ILUVATAR_API)
CALCULATE
(
INFINI_DEVICE_ILUVATAR
,
ninetoothed
);
#endif
#if defined(ENABLE_METAX_API)
CALCULATE
(
INFINI_DEVICE_METAX
,
ninetoothed
);
#endif
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
#undef CALCULATE
}
__C
infiniStatus_t
infiniopDestroyKVCachingDescriptor
(
infiniopKVCachingDescriptor_t
desc
)
{
#define DELETE(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<op::kv_caching::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS;
switch
(
desc
->
device_type
)
{
#if defined(ENABLE_NINETOOTHED)
#if defined(ENABLE_NVIDIA_API)
DELETE
(
INFINI_DEVICE_NVIDIA
,
ninetoothed
);
#endif
#if defined(ENABLE_ILUVATAR_API)
DELETE
(
INFINI_DEVICE_ILUVATAR
,
ninetoothed
);
#endif
#if defined(ENABLE_METAX_API)
DELETE
(
INFINI_DEVICE_METAX
,
ninetoothed
);
#endif
#endif
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
#undef DELETE
}
src/infiniop/ops/layer_norm/operator.cc (modified)
@@ -5,7 +5,7 @@
 #ifdef ENABLE_CPU_API
 #include "cpu/layer_norm_cpu.h"
 #endif
-#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API)
 #include "nvidia/layer_norm_nvidia.cuh"
 #endif
 #ifdef ENABLE_METAX_API
@@ -46,6 +46,9 @@ __C infiniStatus_t infiniopCreateLayerNormDescriptor(
 #ifdef ENABLE_ILUVATAR_API
         CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
 #endif
+#ifdef ENABLE_ALI_API
+        CREATE(INFINI_DEVICE_ALI, nvidia);
+#endif
 #ifdef ENABLE_QY_API
         CREATE(INFINI_DEVICE_QY, nvidia);
 #endif
@@ -76,6 +79,9 @@ __C infiniStatus_t infiniopGetLayerNormWorkspaceSize(infiniopLayerNormDescriptor
 #ifdef ENABLE_ILUVATAR_API
         GET(INFINI_DEVICE_ILUVATAR, nvidia);
 #endif
+#ifdef ENABLE_ALI_API
+        GET(INFINI_DEVICE_ALI, nvidia);
+#endif
 #ifdef ENABLE_QY_API
         GET(INFINI_DEVICE_QY, nvidia);
 #endif
@@ -126,6 +132,9 @@ __C infiniStatus_t infiniopLayerNorm(
 #ifdef ENABLE_ILUVATAR_API
         CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
 #endif
+#ifdef ENABLE_ALI_API
+        CALCULATE(INFINI_DEVICE_ALI, nvidia);
+#endif
 #ifdef ENABLE_QY_API
         CALCULATE(INFINI_DEVICE_QY, nvidia);
 #endif
@@ -156,6 +165,9 @@ infiniopDestroyLayerNormDescriptor(infiniopLayerNormDescriptor_t desc) {
 #ifdef ENABLE_NVIDIA_API
         DELETE(INFINI_DEVICE_NVIDIA, nvidia);
 #endif
+#ifdef ENABLE_ALI_API
+        DELETE(INFINI_DEVICE_ALI, nvidia);
+#endif
 #ifdef ENABLE_QY_API
         DELETE(INFINI_DEVICE_QY, nvidia);
 #endif
src/infiniop/ops/logsoftmax/operator.cc (modified)
@@ -5,7 +5,7 @@
 #ifdef ENABLE_CPU_API
 #include "cpu/logsoftmax_cpu.h"
 #endif
-#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API)
 #include "nvidia/logsoftmax_nvidia.cuh"
 #endif
 #ifdef ENABLE_METAX_API
@@ -36,6 +36,9 @@ __C infiniStatus_t infiniopCreateLogSoftmaxDescriptor(
 #ifdef ENABLE_NVIDIA_API
         CREATE(INFINI_DEVICE_NVIDIA, nvidia)
 #endif
+#ifdef ENABLE_ALI_API
+        CREATE(INFINI_DEVICE_ALI, nvidia);
+#endif
 #ifdef ENABLE_ILUVATAR_API
         // CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
 #endif
@@ -66,6 +69,9 @@ __C infiniStatus_t infiniopGetLogSoftmaxWorkspaceSize(infiniopLogSoftmaxDescript
 #ifdef ENABLE_NVIDIA_API
         GET(INFINI_DEVICE_NVIDIA, nvidia)
 #endif
+#ifdef ENABLE_ALI_API
+        GET(INFINI_DEVICE_ALI, nvidia);
+#endif
 #ifdef ENABLE_ILUVATAR_API
         // GET(INFINI_DEVICE_ILUVATAR, nvidia);
 #endif
@@ -101,6 +107,9 @@ __C infiniStatus_t infiniopLogSoftmax(
 #ifdef ENABLE_NVIDIA_API
         CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
 #endif
+#ifdef ENABLE_ALI_API
+        CALCULATE(INFINI_DEVICE_ALI, nvidia);
+#endif
 #ifdef ENABLE_ILUVATAR_API
         // CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
 #endif
@@ -131,6 +140,9 @@ __C infiniStatus_t infiniopDestroyLogSoftmaxDescriptor(infiniopLogSoftmaxDescrip
 #ifdef ENABLE_NVIDIA_API
         DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
 #endif
+#ifdef ENABLE_ALI_API
+        DESTROY(INFINI_DEVICE_ALI, nvidia);
+#endif
 #ifdef ENABLE_ILUVATAR_API
         // DESTROY(INFINI_DEVICE_ILUVATAR, nvidia);
 #endif
src/infiniop/ops/lp_norm/operator.cc (modified)
@@ -2,7 +2,7 @@
 #include "../../handle.h"
 #include "infiniop/ops/lp_norm.h"
-#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API)
 #include "nvidia/lp_norm_nvidia.cuh"
 #endif
@@ -36,6 +36,9 @@ __C infiniStatus_t infiniopCreateLPNormDescriptor(
 #ifdef ENABLE_QY_API
         CREATE(INFINI_DEVICE_QY, nvidia);
 #endif
+#ifdef ENABLE_ALI_API
+        CREATE(INFINI_DEVICE_ALI, nvidia);
+#endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -60,6 +63,9 @@ __C infiniStatus_t infiniopGetLPNormWorkspaceSize(infiniopLPNormDescriptor_t des
 #ifdef ENABLE_QY_API
         GET(INFINI_DEVICE_QY, nvidia);
 #endif
+#ifdef ENABLE_ALI_API
+        GET(INFINI_DEVICE_ALI, nvidia);
+#endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -97,6 +103,9 @@ __C infiniStatus_t infiniopLPNorm(
 #ifdef ENABLE_QY_API
         CALCULATE(INFINI_DEVICE_QY, nvidia);
 #endif
+#ifdef ENABLE_ALI_API
+        CALCULATE(INFINI_DEVICE_ALI, nvidia);
+#endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -124,6 +133,9 @@ infiniopDestroyLPNormDescriptor(infiniopLPNormDescriptor_t desc) {
 #ifdef ENABLE_QY_API
         DELETE(INFINI_DEVICE_QY, nvidia);
 #endif
+#ifdef ENABLE_ALI_API
+        DELETE(INFINI_DEVICE_ALI, nvidia);
+#endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
src/infiniop/ops/mul/operator.cc (modified)
@@ -5,7 +5,7 @@
 #ifdef ENABLE_CPU_API
 #include "cpu/mul_cpu.h"
 #endif
-#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API)
 #include "nvidia/mul_nvidia.cuh"
 #endif
 #ifdef ENABLE_METAX_API
@@ -48,6 +48,9 @@ __C infiniStatus_t infiniopCreateMulDescriptor(
 #ifdef ENABLE_QY_API
         CREATE(INFINI_DEVICE_QY, nvidia);
 #endif
+#ifdef ENABLE_ALI_API
+        CREATE(INFINI_DEVICE_ALI, nvidia);
+#endif
 #ifdef ENABLE_METAX_API
         CREATE(INFINI_DEVICE_METAX, metax);
 #endif
@@ -85,6 +88,9 @@ __C infiniStatus_t infiniopGetMulWorkspaceSize(infiniopMulDescriptor_t desc, siz
 #ifdef ENABLE_QY_API
         GET(INFINI_DEVICE_QY, nvidia);
 #endif
+#ifdef ENABLE_ALI_API
+        GET(INFINI_DEVICE_ALI, nvidia);
+#endif
 #ifdef ENABLE_METAX_API
         GET(INFINI_DEVICE_METAX, metax);
 #endif
@@ -131,6 +137,9 @@ __C infiniStatus_t infiniopMul(
 #ifdef ENABLE_QY_API
         CALCULATE(INFINI_DEVICE_QY, nvidia);
 #endif
+#ifdef ENABLE_ALI_API
+        CALCULATE(INFINI_DEVICE_ALI, nvidia);
+#endif
 #ifdef ENABLE_METAX_API
         CALCULATE(INFINI_DEVICE_METAX, metax);
 #endif
@@ -170,6 +179,9 @@ infiniopDestroyMulDescriptor(infiniopMulDescriptor_t desc) {
 #ifdef ENABLE_QY_API
         DELETE(INFINI_DEVICE_QY, nvidia);
 #endif
+#ifdef ENABLE_ALI_API
+        DELETE(INFINI_DEVICE_ALI, nvidia);
+#endif
 #ifdef ENABLE_METAX_API
         DELETE(INFINI_DEVICE_METAX, metax);
 #endif
src/infiniop/ops/ones/operator.cc (modified)
@@ -5,7 +5,7 @@
 #ifdef ENABLE_CPU_API
 #include "cpu/ones_cpu.h"
 #endif
-#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API)
 #include "nvidia/ones_nvidia.cuh"
 #endif
 #ifdef ENABLE_METAX_API
@@ -49,6 +49,10 @@ __C infiniStatus_t infiniopCreateOnesDescriptor(
 #ifdef ENABLE_MOORE_API
         CREATE(INFINI_DEVICE_MOORE, moore);
 #endif
+#ifdef ENABLE_ALI_API
+        CREATE(INFINI_DEVICE_ALI, nvidia);
+#endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
     }
@@ -82,6 +86,10 @@ __C infiniStatus_t infiniopGetOnesWorkspaceSize(infiniopOnesDescriptor_t desc, s
 #ifdef ENABLE_MOORE_API
         GET(INFINI_DEVICE_MOORE, moore);
 #endif
+#ifdef ENABLE_ALI_API
+        GET(INFINI_DEVICE_ALI, nvidia);
+#endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
     }
@@ -123,6 +131,10 @@ __C infiniStatus_t infiniopOnes(
 #ifdef ENABLE_MOORE_API
         CALCULATE(INFINI_DEVICE_MOORE, moore);
 #endif
+#ifdef ENABLE_ALI_API
+        CALCULATE(INFINI_DEVICE_ALI, nvidia);
+#endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
     }
@@ -158,6 +170,10 @@ infiniopDestroyOnesDescriptor(infiniopOnesDescriptor_t desc) {
 #ifdef ENABLE_MOORE_API
         DELETE(INFINI_DEVICE_MOORE, moore);
 #endif
+#ifdef ENABLE_ALI_API
+        DELETE(INFINI_DEVICE_ALI, nvidia);
+#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
src/infiniop/ops/paged_attention/cuda/kernel.cuh (new file, mode 100644)
#ifndef __PAGED_ATTENTION_KERNEL_CUH__
#define __PAGED_ATTENTION_KERNEL_CUH__

// This kernel is refactored to be high-performance, adopting parallelism strategies
// from industry-standard implementations like vLLM. It fixes functional and performance
// issues in the original draft.

namespace op::paged_attention::cuda {

template <typename Tdata, typename Tcompute, size_t HEAD_SIZE, size_t NUM_THREADS>
__device__ void pagedAttentionKernel(
    Tdata *out_,
    const Tdata *q_,
    const Tdata *k_cache_,
    const Tdata *v_cache_,
    const int64_t *block_tables_,
    const int64_t *seq_lens_,
    const float *alibi_slopes_,
    const size_t num_kv_heads,
    const float scale,
    const size_t max_num_blocks_per_seq,
    const size_t block_size,
    const ptrdiff_t q_stride,
    const ptrdiff_t kv_block_stride,
    const ptrdiff_t kv_head_stride,
    const ptrdiff_t o_stride) {
    //================================================================================
    // 1. Setup & Query Loading
    //================================================================================
    const int seq_idx = blockIdx.y;
    const int head_idx = blockIdx.x;
    const int num_heads = gridDim.x;

    const int64_t seq_len = seq_lens_[seq_idx];
    if (seq_len == 0) {
        return;
    }

    const size_t num_queries_per_kv = num_heads / num_kv_heads;
    const size_t kv_head_idx = head_idx / num_queries_per_kv;
    const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];

    const int64_t *block_table = block_tables_ + seq_idx * max_num_blocks_per_seq;
    const Tdata *q_ptr = q_ + seq_idx * q_stride + head_idx * HEAD_SIZE;
    Tdata *out_ptr = out_ + seq_idx * o_stride + head_idx * HEAD_SIZE;

    extern __shared__ char shared_mem_char[];
    Tcompute *shared_mem = reinterpret_cast<Tcompute *>(shared_mem_char);
    Tcompute *q_shared = shared_mem;
    Tcompute *logits = shared_mem + HEAD_SIZE;

    for (size_t i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
        q_shared[i] = static_cast<Tcompute>(q_ptr[i]);
    }
    __syncthreads();

    //================================================================================
    // 2. Compute QK Dot Product & Find Max Logit
    //================================================================================
    for (size_t token_idx = threadIdx.x; token_idx < seq_len; token_idx += NUM_THREADS) {
        const int64_t block_idx = token_idx / block_size;
        const int64_t token_in_block_idx = token_idx % block_size;
        const int64_t physical_block_num = block_table[block_idx];

        const Tdata *k_vec_ptr = k_cache_
                               + physical_block_num * kv_block_stride
                               + kv_head_idx * kv_head_stride
                               + token_in_block_idx * HEAD_SIZE;

        Tcompute qk = 0.0f;
#pragma unroll
        for (size_t i = 0; i < HEAD_SIZE / 8; ++i) {
            const size_t offset = i * 8;
            // Manually unrolled by a factor of 8
            qk += q_shared[offset + 0] * static_cast<Tcompute>(k_vec_ptr[offset + 0]);
            qk += q_shared[offset + 1] * static_cast<Tcompute>(k_vec_ptr[offset + 1]);
            qk += q_shared[offset + 2] * static_cast<Tcompute>(k_vec_ptr[offset + 2]);
            qk += q_shared[offset + 3] * static_cast<Tcompute>(k_vec_ptr[offset + 3]);
            qk += q_shared[offset + 4] * static_cast<Tcompute>(k_vec_ptr[offset + 4]);
            qk += q_shared[offset + 5] * static_cast<Tcompute>(k_vec_ptr[offset + 5]);
            qk += q_shared[offset + 6] * static_cast<Tcompute>(k_vec_ptr[offset + 6]);
            qk += q_shared[offset + 7] * static_cast<Tcompute>(k_vec_ptr[offset + 7]);
        }
        qk *= scale;

        if (alibi_slope != 0.0f) {
            qk += alibi_slope * (token_idx - seq_len + 1);
        }
        logits[token_idx] = qk;
    }
    __syncthreads();

    __shared__ Tcompute global_qk_max;
    Tcompute global_qk_max_0 = op::common_cuda::reduce_op::max<NUM_THREADS, Tcompute>(logits, seq_len);
    if (threadIdx.x == 0) {
        global_qk_max = global_qk_max_0;
    }
    __syncthreads();

    //================================================================================
    // 3. Compute Softmax
    //================================================================================
    for (size_t i = threadIdx.x; i < seq_len; i += NUM_THREADS) {
        Tcompute val = expf(logits[i] - global_qk_max); // subtract the global maximum
        logits[i] = val;
    }
    __syncthreads();

    __shared__ Tcompute inv_sum;
    Tcompute exp_sum_0 = op::common_cuda::reduce_op::sum<NUM_THREADS, Tcompute, Tcompute>(logits, seq_len);
    if (threadIdx.x == 0) {
        inv_sum = 1.0f / (exp_sum_0 + 1e-6f);
    }
    __syncthreads();

    for (size_t i = threadIdx.x; i < seq_len; i += NUM_THREADS) {
        logits[i] *= inv_sum;
    }
    __syncthreads();

    //================================================================================
    // 4. Aggregate Values (V) weighted by probabilities
    //================================================================================
    for (size_t h_dim = threadIdx.x; h_dim < HEAD_SIZE; h_dim += NUM_THREADS) {
        Tcompute acc = 0.0f;
        for (size_t token_idx = 0; token_idx < seq_len; ++token_idx) {
            const size_t block_idx = token_idx / block_size;
            const size_t token_in_block_idx = token_idx % block_size;
            const int64_t physical_block_num = block_table[block_idx];

            const Tcompute prob = logits[token_idx];
            const Tdata *v_vec_ptr = v_cache_
                                   + physical_block_num * kv_block_stride
                                   + kv_head_idx * kv_head_stride
                                   + token_in_block_idx * HEAD_SIZE;
            const Tdata v_val = v_vec_ptr[h_dim];
            acc += prob * static_cast<Tcompute>(v_val);
        }
        out_ptr[h_dim] = static_cast<Tdata>(acc);
    }
}

} // namespace op::paged_attention::cuda

#endif // __PAGED_ATTENTION_KERNEL_CUH__
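The launch wrapper is not included on this page, but the configuration can be read off the kernel body: one thread block per (head, sequence) pair, NUM_THREADS threads per block, and dynamic shared memory holding the staged query plus one logit slot per token. A hedged sketch with an assumed __global__ entry point (the kernel above is __device__, so some wrapper of this shape is implied; its name and the max_seq_len parameter are assumptions):

// Hypothetical entry point and launch config; only pagedAttentionKernel above
// is part of this commit.
template <typename Tdata, typename Tcompute, size_t HEAD_SIZE, size_t NUM_THREADS>
__global__ void pagedAttentionEntry(
    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
    const int64_t *block_tables, const int64_t *seq_lens, const float *alibi_slopes,
    size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t block_size,
    ptrdiff_t q_stride, ptrdiff_t kv_block_stride, ptrdiff_t kv_head_stride,
    ptrdiff_t o_stride) {
    op::paged_attention::cuda::pagedAttentionKernel<Tdata, Tcompute, HEAD_SIZE, NUM_THREADS>(
        out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, num_kv_heads,
        scale, max_num_blocks_per_seq, block_size, q_stride, kv_block_stride,
        kv_head_stride, o_stride);
}

// Implied launch, assuming max_seq_len is an upper bound on all seq_lens entries:
//   dim3 grid(num_heads, num_seqs);  // blockIdx.x = head, blockIdx.y = sequence
//   size_t shmem = (HEAD_SIZE + max_seq_len) * sizeof(Tcompute); // q_shared + logits
//   pagedAttentionEntry<half, float, 128, 128>
//       <<<grid, 128, shmem, stream>>>(/* arguments as above */);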