Commit 835209e7 authored by wooway777

issue/900 - support embedding on iluvatar, metax, and moore

parent cc2cc3a1
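For context, the operator added here follows the library's create/calculate/destroy pattern. Below is a rough host-side call sequence; the header name and the argument order of infiniopEmbedding are assumptions inferred from Descriptor::create(handle, &desc, output_desc, input_desc, weight_desc) and Descriptor::calculate(output, input, weight, stream) in this commit, not the verbatim public API.

// Hypothetical usage sketch (header name and argument order assumed, see note above).
#include <infiniop.h> // assumed umbrella header

infiniStatus_t run_embedding(infiniopHandle_t handle,
                             infiniopTensorDescriptor_t output_desc,
                             infiniopTensorDescriptor_t indices_desc,
                             infiniopTensorDescriptor_t weight_desc,
                             void *output, const void *indices, const void *weight,
                             void *stream) {
    infiniopEmbeddingDescriptor_t desc = nullptr;
    infiniStatus_t status = infiniopCreateEmbeddingDescriptor(
        handle, &desc, output_desc, indices_desc, weight_desc);
    if (status != INFINI_STATUS_SUCCESS) {
        return status;
    }
    // indices: I32/I64; weight: F32/F16/BF16; output shape = indices shape + [embedding_dim]
    status = infiniopEmbedding(desc, output, indices, weight, stream); // assumed argument order
    infiniopDestroyEmbeddingDescriptor(desc);
    return status;
}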
#ifndef __EMBEDDING_CUDA_KERNEL_CUH__
#define __EMBEDDING_CUDA_KERNEL_CUH__
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <type_traits>
namespace op::embedding::nvidia {
// Helper function to check memory alignment
__forceinline__ __device__ bool is_aligned(const void *ptr, size_t alignment) {
// Use size_t for pointer arithmetic in device code (more compatible)
@@ -118,61 +113,4 @@ __forceinline__ __device__ void copyScalar(
}
}
template <typename T, typename IndexType>
INFINIOP_CUDA_KERNEL embeddingKernel(
T *__restrict__ output,
const IndexType *__restrict__ indices,
const T *__restrict__ weight,
size_t num_indices,
size_t embedding_dim,
size_t vocab_size) {
// Calculate global thread index
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < num_indices) {
// Get the index value
IndexType index_val = __ldg(&indices[idx]);
// Bounds check - handle negative indices gracefully
if (index_val >= 0 && static_cast<size_t>(index_val) < vocab_size) {
// Copy embedding vector from weight to output
const T *src = weight + static_cast<size_t>(index_val) * embedding_dim;
T *dst = output + idx * embedding_dim;
// Choose optimal copy strategy based on type and alignment
if constexpr (std::is_same_v<T, float>) {
// Check alignment for float4 (16 bytes)
bool aligned_16 = is_aligned(src, 16) && is_aligned(dst, 16);
if (aligned_16 && embedding_dim >= 4 && embedding_dim % 4 == 0) {
copyVectorizedFloat4<IndexType>(dst, src, embedding_dim);
} else if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
// Try float2 if not aligned to 16 bytes
copyVectorizedFloat2<IndexType>(dst, src, embedding_dim);
} else {
copyScalar<T, IndexType>(dst, src, embedding_dim);
}
} else if constexpr (std::is_same_v<T, half>) {
// Use half2 for vectorized access
if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
copyVectorizedHalf2<IndexType>(dst, src, embedding_dim);
} else {
copyScalar<T, IndexType>(dst, src, embedding_dim);
}
} else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
// Use bfloat162 for vectorized access
if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
copyVectorizedBFloat162<IndexType>(dst, src, embedding_dim);
} else {
copyScalar<T, IndexType>(dst, src, embedding_dim);
}
} else {
// Fallback to scalar copy with __ldg
copyScalar<T, IndexType>(dst, src, embedding_dim);
}
}
}
}
} // namespace op::embedding::nvidia
#endif // __EMBEDDING_CUDA_KERNEL_CUH__
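The kernel above is launched with one thread per index; each thread copies one embedding row. A self-contained CUDA sketch of that launch geometry is shown below, using a simplified scalar-copy stand-in for the repository's embeddingKernel (file name, values, and the simplified kernel are illustrative only, not part of the commit):

// embedding_launch_sketch.cu - illustrative stand-in; not part of the commit.
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>
#include <vector>

// Simplified scalar-only kernel mirroring the indexing and bounds check above.
__global__ void embeddingSketch(float *out, const int32_t *idx, const float *weight,
                                size_t num_indices, size_t dim, size_t vocab) {
    size_t i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < num_indices) {
        int32_t v = idx[i];
        if (v >= 0 && static_cast<size_t>(v) < vocab) {
            const float *src = weight + static_cast<size_t>(v) * dim;
            float *dst = out + i * dim;
            for (size_t j = 0; j < dim; ++j) {
                dst[j] = src[j]; // one thread copies one whole embedding row
            }
        }
    }
}

int main() {
    const size_t vocab = 8, dim = 4, n = 5;
    std::vector<float> h_weight(vocab * dim);
    for (size_t i = 0; i < h_weight.size(); ++i) h_weight[i] = static_cast<float>(i);
    std::vector<int32_t> h_idx = {0, 3, 7, 2, 3};
    float *d_weight, *d_out;
    int32_t *d_idx;
    cudaMalloc(&d_weight, h_weight.size() * sizeof(float));
    cudaMalloc(&d_out, n * dim * sizeof(float));
    cudaMalloc(&d_idx, n * sizeof(int32_t));
    cudaMemcpy(d_weight, h_weight.data(), h_weight.size() * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(d_idx, h_idx.data(), n * sizeof(int32_t), cudaMemcpyHostToDevice);
    // Same geometry as Descriptor::calculate below: ceil(num_indices / block_size).
    size_t block_size = 256;
    size_t grid_size = (n + block_size - 1) / block_size;
    embeddingSketch<<<grid_size, block_size>>>(d_out, d_idx, d_weight, n, dim, vocab);
    std::vector<float> h_out(n * dim);
    cudaMemcpy(h_out.data(), d_out, h_out.size() * sizeof(float), cudaMemcpyDeviceToHost);
    // Row 1 looked up index 3, so it should start with weight[3 * dim] = 12.0f here.
    printf("out[1][0] = %.1f (expected 12.0)\n", h_out[dim]);
    cudaFree(d_weight);
    cudaFree(d_out);
    cudaFree(d_idx);
    return 0;
}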
#ifndef __EMBEDDING_METAX_H__
#define __EMBEDDING_METAX_H__
#include "../embedding.h"
DESCRIPTOR(metax)
#endif // __EMBEDDING_METAX_H__
#include "../../../../utils.h"
#include "../../../devices/metax/metax_common.h"
#include "../../../devices/metax/metax_kernel_common.h"
#include "../../../tensor.h"
#include "../cuda/embedding_kernel.cuh"
#include "embedding_metax.cuh"
template <typename T, typename IndexType>
INFINIOP_METAX_KERNEL embeddingKernel(
T *__restrict__ output,
const IndexType *__restrict__ indices,
const T *__restrict__ weight,
size_t num_indices,
size_t embedding_dim,
size_t vocab_size) {
// Calculate global thread index
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < num_indices) {
// Get the index value
IndexType index_val = __ldg(&indices[idx]);
// Bounds check - handle negative indices gracefully
if (index_val >= 0 && static_cast<size_t>(index_val) < vocab_size) {
// Copy embedding vector from weight to output
const T *src = weight + static_cast<size_t>(index_val) * embedding_dim;
T *dst = output + idx * embedding_dim;
// Choose optimal copy strategy based on type and alignment
if constexpr (std::is_same_v<T, float>) {
// Check alignment for float4 (16 bytes)
bool aligned_16 = is_aligned(src, 16) && is_aligned(dst, 16);
if (aligned_16 && embedding_dim >= 4 && embedding_dim % 4 == 0) {
copyVectorizedFloat4<IndexType>(dst, src, embedding_dim);
} else if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
// Try float2 if not aligned to 16 bytes
copyVectorizedFloat2<IndexType>(dst, src, embedding_dim);
} else {
copyScalar<T, IndexType>(dst, src, embedding_dim);
}
} else if constexpr (std::is_same_v<T, half>) {
// Use half2 for vectorized access
if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
copyVectorizedHalf2<IndexType>(dst, src, embedding_dim);
} else {
copyScalar<T, IndexType>(dst, src, embedding_dim);
}
} else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
// Use bfloat162 for vectorized access
if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
copyVectorizedBFloat162<IndexType>(dst, src, embedding_dim);
} else {
copyScalar<T, IndexType>(dst, src, embedding_dim);
}
} else {
// Fallback to scalar copy with __ldg
copyScalar<T, IndexType>(dst, src, embedding_dim);
}
}
}
}
namespace op::embedding::metax {
struct Descriptor::Opaque {
std::shared_ptr<device::metax::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t output_desc,
infiniopTensorDescriptor_t input_desc,
infiniopTensorDescriptor_t weight_desc) {
auto input_shape = input_desc->shape();
auto weight_shape = weight_desc->shape();
// Validate shapes
CHECK_OR_RETURN(weight_shape.size() == 2, INFINI_STATUS_BAD_TENSOR_SHAPE);
CHECK_OR_RETURN(output_desc->shape().size() == input_shape.size() + 1, INFINI_STATUS_BAD_TENSOR_SHAPE);
// Check output shape matches input shape + embedding_dim
auto output_shape = output_desc->shape();
size_t embedding_dim = weight_shape[1];
CHECK_OR_RETURN(output_shape.back() == embedding_dim, INFINI_STATUS_BAD_TENSOR_SHAPE);
for (size_t i = 0; i < input_shape.size(); ++i) {
CHECK_OR_RETURN(output_shape[i] == input_shape[i], INFINI_STATUS_BAD_TENSOR_SHAPE);
}
// Validate dtypes
auto input_dtype = input_desc->dtype();
auto weight_dtype = weight_desc->dtype();
CHECK_OR_RETURN(input_dtype == INFINI_DTYPE_I32 || input_dtype == INFINI_DTYPE_I64,
INFINI_STATUS_BAD_TENSOR_DTYPE);
CHECK_OR_RETURN(weight_dtype == INFINI_DTYPE_F32 || weight_dtype == INFINI_DTYPE_F16 ||
weight_dtype == INFINI_DTYPE_BF16, INFINI_STATUS_BAD_TENSOR_DTYPE);
CHECK_OR_RETURN(output_desc->dtype() == weight_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE);
// Calculate number of indices (supporting batch dimension)
size_t num_indices = 1;
for (auto dim : input_shape) {
num_indices *= dim;
}
size_t vocab_size = weight_shape[0];
*desc_ptr = new Descriptor(
num_indices,
embedding_dim,
vocab_size,
input_dtype,
weight_dtype,
new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
handle->device,
handle->device_id);
return INFINI_STATUS_SUCCESS;
}
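As a concrete illustration of the shape contract checked above (values hypothetical): indices of shape [2, 3] looked up in a weight of shape [vocab_size, embedding_dim] = [50000, 768] require an output of shape [2, 3, 768]; the loop then yields num_indices = 2 * 3 = 6, with vocab_size = 50000.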
infiniStatus_t Descriptor::calculate(
void *output,
const void *input,
const void *weight,
void *stream) const {
if (_num_indices == 0) {
return INFINI_STATUS_SUCCESS;
}
auto hc_stream = reinterpret_cast<hcStream_t>(stream);
// Dynamic block size optimization based on embedding_dim for Metax platform
size_t block_size = 256; // Default block size for Metax
if (_embedding_dim <= 64) {
block_size = 512; // Small embedding_dim: use larger block for better occupancy
} else if (_embedding_dim >= 1024) {
block_size = 128; // Large embedding_dim: use smaller block to reduce register pressure
}
size_t grid_size = (_num_indices + block_size - 1) / block_size;
// Launch kernel based on dtypes for Metax platform
if (_input_dtype == INFINI_DTYPE_I32) {
const int32_t *indices_ptr = reinterpret_cast<const int32_t *>(input);
if (_weight_dtype == INFINI_DTYPE_F32) {
embeddingKernel<float, int32_t><<<grid_size, block_size, 0, hc_stream>>>(
reinterpret_cast<float *>(output),
indices_ptr,
reinterpret_cast<const float *>(weight),
_num_indices,
_embedding_dim,
_vocab_size);
} else if (_weight_dtype == INFINI_DTYPE_F16) {
embeddingKernel<half, int32_t><<<grid_size, block_size, 0, hc_stream>>>(
reinterpret_cast<half *>(output),
indices_ptr,
reinterpret_cast<const half *>(weight),
_num_indices,
_embedding_dim,
_vocab_size);
} else if (_weight_dtype == INFINI_DTYPE_BF16) {
// Use Metax's bfloat16 type
embeddingKernel<__hpcc_bfloat16, int32_t><<<grid_size, block_size, 0, hc_stream>>>(
reinterpret_cast<__hpcc_bfloat16 *>(output),
indices_ptr,
reinterpret_cast<const __hpcc_bfloat16 *>(weight),
_num_indices,
_embedding_dim,
_vocab_size);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else if (_input_dtype == INFINI_DTYPE_I64) {
const int64_t *indices_ptr = reinterpret_cast<const int64_t *>(input);
if (_weight_dtype == INFINI_DTYPE_F32) {
embeddingKernel<float, int64_t><<<grid_size, block_size, 0, hc_stream>>>(
reinterpret_cast<float *>(output),
indices_ptr,
reinterpret_cast<const float *>(weight),
_num_indices,
_embedding_dim,
_vocab_size);
} else if (_weight_dtype == INFINI_DTYPE_F16) {
embeddingKernel<half, int64_t><<<grid_size, block_size, 0, hc_stream>>>(
reinterpret_cast<half *>(output),
indices_ptr,
reinterpret_cast<const half *>(weight),
_num_indices,
_embedding_dim,
_vocab_size);
} else if (_weight_dtype == INFINI_DTYPE_BF16) {
embeddingKernel<__hpcc_bfloat16, int64_t><<<grid_size, block_size, 0, hc_stream>>>(
reinterpret_cast<__hpcc_bfloat16 *>(output),
indices_ptr,
reinterpret_cast<const __hpcc_bfloat16 *>(weight),
_num_indices,
_embedding_dim,
_vocab_size);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::embedding::metax
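The launch configuration in Descriptor::calculate above is plain ceiling division over the flattened index count. A minimal host-side sketch of that arithmetic with hypothetical sizes (the thresholds are copied from the Metax heuristic above):

#include <cstddef>
#include <cstdio>

// Mirrors the Metax block-size heuristic from Descriptor::calculate; illustrative only.
static size_t pick_block_size(size_t embedding_dim) {
    size_t block_size = 256;              // default
    if (embedding_dim <= 64) {
        block_size = 512;                 // small rows: favor occupancy
    } else if (embedding_dim >= 1024) {
        block_size = 128;                 // large rows: reduce register pressure
    }
    return block_size;
}

int main() {
    const size_t num_indices = 10000;     // hypothetical flattened index count
    const size_t embedding_dim = 64;
    size_t block_size = pick_block_size(embedding_dim);             // 512
    size_t grid_size = (num_indices + block_size - 1) / block_size; // ceil(10000/512) = 20
    printf("block=%zu grid=%zu\n", block_size, grid_size);
    return 0;
}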
#ifndef __EMBEDDING_MOORE_H__
#define __EMBEDDING_MOORE_H__
#include "../embedding.h"
DESCRIPTOR(moore)
#endif // __EMBEDDING_MOORE_H__
#include "../../../../utils.h"
#include "../../../devices/moore/moore_common.h"
#include "../../../devices/moore/moore_kernel_common.h"
#include "../../../tensor.h"
#include "embedding_moore_kernel.h"
#include "embedding_moore.h"
#include <musa_runtime.h>
template <typename T, typename IndexType>
INFINIOP_MOORE_KERNEL embeddingKernel(
T *__restrict__ output,
const IndexType *__restrict__ indices,
const T *__restrict__ weight,
size_t num_indices,
size_t embedding_dim,
size_t vocab_size) {
// Calculate global thread index
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < num_indices) {
// Get the index value with Moore-optimized memory access
IndexType index_val = indices[idx];
// Bounds check - handle negative indices gracefully
if (index_val >= 0 && static_cast<size_t>(index_val) < vocab_size) {
// Copy embedding vector from weight to output
const T *src = weight + static_cast<size_t>(index_val) * embedding_dim;
T *dst = output + idx * embedding_dim;
// Choose optimal copy strategy based on type and alignment
if constexpr (std::is_same_v<T, float>) {
// Check alignment for float4 (16 bytes)
bool aligned_16 = is_aligned(src, 16) && is_aligned(dst, 16);
if (aligned_16 && embedding_dim >= 4 && embedding_dim % 4 == 0) {
copyVectorizedFloat4<IndexType>(dst, src, embedding_dim);
} else if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
// Try float2 if not aligned to 16 bytes
copyVectorizedFloat2<IndexType>(dst, src, embedding_dim);
} else {
copyScalar<T, IndexType>(dst, src, embedding_dim);
}
} else if constexpr (std::is_same_v<T, half>) {
// Use half2 for vectorized access
if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
copyVectorizedHalf2<IndexType>(dst, src, embedding_dim);
} else {
copyScalar<T, IndexType>(dst, src, embedding_dim);
}
} else if constexpr (std::is_same_v<T, __mt_bfloat16>) {
// Use mt_bfloat162 for vectorized access (Moore-specific type)
if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
copyVectorizedBFloat162<IndexType>(dst, src, embedding_dim);
} else {
copyScalar<T, IndexType>(dst, src, embedding_dim);
}
} else {
// Fallback to scalar copy with Moore-optimized memory access
copyScalar<T, IndexType>(dst, src, embedding_dim);
}
}
}
}
namespace op::embedding::moore {
struct Descriptor::Opaque {
std::shared_ptr<device::moore::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t output_desc,
infiniopTensorDescriptor_t input_desc,
infiniopTensorDescriptor_t weight_desc) {
auto input_shape = input_desc->shape();
auto weight_shape = weight_desc->shape();
// Validate shapes
CHECK_OR_RETURN(weight_shape.size() == 2, INFINI_STATUS_BAD_TENSOR_SHAPE);
CHECK_OR_RETURN(output_desc->shape().size() == input_shape.size() + 1, INFINI_STATUS_BAD_TENSOR_SHAPE);
// Check output shape matches input shape + embedding_dim
auto output_shape = output_desc->shape();
size_t embedding_dim = weight_shape[1];
CHECK_OR_RETURN(output_shape.back() == embedding_dim, INFINI_STATUS_BAD_TENSOR_SHAPE);
for (size_t i = 0; i < input_shape.size(); ++i) {
CHECK_OR_RETURN(output_shape[i] == input_shape[i], INFINI_STATUS_BAD_TENSOR_SHAPE);
}
// Validate dtypes
auto input_dtype = input_desc->dtype();
auto weight_dtype = weight_desc->dtype();
CHECK_OR_RETURN(input_dtype == INFINI_DTYPE_I32 || input_dtype == INFINI_DTYPE_I64,
INFINI_STATUS_BAD_TENSOR_DTYPE);
CHECK_OR_RETURN(weight_dtype == INFINI_DTYPE_F32 || weight_dtype == INFINI_DTYPE_F16 || weight_dtype == INFINI_DTYPE_BF16, INFINI_STATUS_BAD_TENSOR_DTYPE);
CHECK_OR_RETURN(output_desc->dtype() == weight_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE);
// Calculate number of indices (supporting batch dimension)
size_t num_indices = 1;
for (auto dim : input_shape) {
num_indices *= dim;
}
size_t vocab_size = weight_shape[0];
*desc_ptr = new Descriptor(
num_indices,
embedding_dim,
vocab_size,
input_dtype,
weight_dtype,
new Opaque{reinterpret_cast<device::moore::Handle *>(handle)->internal()},
handle->device,
handle->device_id);
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
void *output,
const void *input,
const void *weight,
void *stream) const {
if (_num_indices == 0) {
return INFINI_STATUS_SUCCESS;
}
auto musa_stream = reinterpret_cast<musaStream_t>(stream);
// Dynamic block size optimization based on embedding_dim
// Moore platform typically has different performance characteristics
size_t block_size = 256; // Default for Moore
if (_embedding_dim <= 64) {
block_size = 512; // Small embedding_dim: use larger block for better occupancy
} else if (_embedding_dim >= 1024) {
block_size = 128; // Large embedding_dim: use smaller block to reduce register pressure
} else if (_embedding_dim <= 256) {
block_size = 384; // Medium embedding_dim: balanced configuration
}
size_t grid_size = (_num_indices + block_size - 1) / block_size;
// Launch kernel based on dtypes
// Note: Moore uses __mt_bfloat16 instead of __nv_bfloat16
if (_input_dtype == INFINI_DTYPE_I32) {
const int32_t *indices_ptr = reinterpret_cast<const int32_t *>(input);
if (_weight_dtype == INFINI_DTYPE_F32) {
embeddingKernel<float, int32_t><<<grid_size, block_size, 0, musa_stream>>>(
reinterpret_cast<float *>(output),
indices_ptr,
reinterpret_cast<const float *>(weight),
_num_indices,
_embedding_dim,
_vocab_size);
} else if (_weight_dtype == INFINI_DTYPE_F16) {
embeddingKernel<half, int32_t><<<grid_size, block_size, 0, musa_stream>>>(
reinterpret_cast<half *>(output),
indices_ptr,
reinterpret_cast<const half *>(weight),
_num_indices,
_embedding_dim,
_vocab_size);
} else if (_weight_dtype == INFINI_DTYPE_BF16) {
// Use Moore's bfloat16 type
embeddingKernel<__mt_bfloat16, int32_t><<<grid_size, block_size, 0, musa_stream>>>(
reinterpret_cast<__mt_bfloat16 *>(output),
indices_ptr,
reinterpret_cast<const __mt_bfloat16 *>(weight),
_num_indices,
_embedding_dim,
_vocab_size);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else if (_input_dtype == INFINI_DTYPE_I64) {
const int64_t *indices_ptr = reinterpret_cast<const int64_t *>(input);
if (_weight_dtype == INFINI_DTYPE_F32) {
embeddingKernel<float, int64_t><<<grid_size, block_size, 0, musa_stream>>>(
reinterpret_cast<float *>(output),
indices_ptr,
reinterpret_cast<const float *>(weight),
_num_indices,
_embedding_dim,
_vocab_size);
} else if (_weight_dtype == INFINI_DTYPE_F16) {
embeddingKernel<half, int64_t><<<grid_size, block_size, 0, musa_stream>>>(
reinterpret_cast<half *>(output),
indices_ptr,
reinterpret_cast<const half *>(weight),
_num_indices,
_embedding_dim,
_vocab_size);
} else if (_weight_dtype == INFINI_DTYPE_BF16) {
embeddingKernel<__mt_bfloat16, int64_t><<<grid_size, block_size, 0, musa_stream>>>(
reinterpret_cast<__mt_bfloat16 *>(output),
indices_ptr,
reinterpret_cast<const __mt_bfloat16 *>(weight),
_num_indices,
_embedding_dim,
_vocab_size);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
// Check for kernel launch errors
musaError_t err = musaGetLastError();
if (err != musaSuccess) {
return INFINI_STATUS_INTERNAL_ERROR;
}
return INFINI_STATUS_SUCCESS;
}
} // namespace op::embedding::moore
#ifndef __EMBEDDING_MOORE_KERNEL_CUH__
#define __EMBEDDING_MOORE_KERNEL_CUH__
#include <type_traits>
// Helper function to check memory alignment
__forceinline__ __device__ bool is_aligned(const void *ptr, size_t alignment) {
// Use size_t for pointer arithmetic in device code (more compatible)
return (reinterpret_cast<size_t>(ptr) % alignment == 0);
}
// Vectorized copy for float type using float4
template <typename IndexType>
__forceinline__ __device__ void copyVectorizedFloat4(
float *__restrict__ dst,
const float *__restrict__ src,
size_t embedding_dim) {
// Use float4 for vectorized access (16 bytes, 4 floats)
const float4 *src_vec = reinterpret_cast<const float4 *>(src);
float4 *dst_vec = reinterpret_cast<float4 *>(dst);
size_t vec_count = embedding_dim / 4;
// Vectorized copy with __ldg equivalent for Moore platform
for (size_t i = 0; i < vec_count; ++i) {
dst_vec[i] = src_vec[i];
}
// Copy remaining elements
size_t remaining = embedding_dim % 4;
if (remaining > 0) {
size_t offset = vec_count * 4;
for (size_t i = 0; i < remaining; ++i) {
dst[offset + i] = src[offset + i];
}
}
}
// Vectorized copy for float type using float2 (fallback when not aligned to 16 bytes)
template <typename IndexType>
__forceinline__ __device__ void copyVectorizedFloat2(
float *__restrict__ dst,
const float *__restrict__ src,
size_t embedding_dim) {
// Use float2 for vectorized access (8 bytes, 2 floats)
const float2 *src_vec = reinterpret_cast<const float2 *>(src);
float2 *dst_vec = reinterpret_cast<float2 *>(dst);
size_t vec_count = embedding_dim / 2;
// Vectorized copy with Moore-optimized memory access
for (size_t i = 0; i < vec_count; ++i) {
dst_vec[i] = src_vec[i];
}
// Copy remaining element if odd
if (embedding_dim % 2 != 0) {
dst[embedding_dim - 1] = src[embedding_dim - 1];
}
}
// Vectorized copy for half type using half2
template <typename IndexType>
__forceinline__ __device__ void copyVectorizedHalf2(
half *__restrict__ dst,
const half *__restrict__ src,
size_t embedding_dim) {
// Use half2 for vectorized access (4 bytes, 2 halfs)
const half2 *src_vec = reinterpret_cast<const half2 *>(src);
half2 *dst_vec = reinterpret_cast<half2 *>(dst);
size_t vec_count = embedding_dim / 2;
// Vectorized copy optimized for Moore architecture
for (size_t i = 0; i < vec_count; ++i) {
dst_vec[i] = src_vec[i];
}
// Copy remaining element if odd
if (embedding_dim % 2 != 0) {
dst[embedding_dim - 1] = src[embedding_dim - 1];
}
}
// Vectorized copy for Moore bfloat16 type using bfloat162
template <typename IndexType>
__forceinline__ __device__ void copyVectorizedBFloat162(
__mt_bfloat16 *__restrict__ dst,
const __mt_bfloat16 *__restrict__ src,
size_t embedding_dim) {
// Use mt_bfloat162 for vectorized access (4 bytes, 2 bfloat16s)
const __mt_bfloat162 *src_vec = reinterpret_cast<const __mt_bfloat162 *>(src);
__mt_bfloat162 *dst_vec = reinterpret_cast<__mt_bfloat162 *>(dst);
size_t vec_count = embedding_dim / 2;
// Vectorized copy with Moore-specific optimization
for (size_t i = 0; i < vec_count; ++i) {
dst_vec[i] = src_vec[i];
}
// Copy remaining element if odd
if (embedding_dim % 2 != 0) {
dst[embedding_dim - 1] = src[embedding_dim - 1];
}
}
// Scalar copy fallback with Moore-optimized memory access
template <typename T, typename IndexType>
__forceinline__ __device__ void copyScalar(
T *__restrict__ dst,
const T *__restrict__ src,
size_t embedding_dim) {
// Scalar copy with Moore read-only weight optimization
for (size_t i = 0; i < embedding_dim; ++i) {
dst[i] = src[i];
}
}
#endif // __EMBEDDING_MOORE_KERNEL_CUH__
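The float dispatch in the kernels above picks float4, float2, or scalar copies from the same three conditions each time. Below is a host-side analogue of that test, useful for sanity-checking which path a given embedding_dim takes (pointer names and values are hypothetical; the device helpers themselves cannot run on the host):

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Host-side analogue of is_aligned and the float copy-path selection in the kernels.
static bool is_aligned_host(const void *ptr, size_t alignment) {
    return reinterpret_cast<uintptr_t>(ptr) % alignment == 0;
}

static const char *float_copy_path(const float *src, const float *dst, size_t embedding_dim) {
    bool aligned_16 = is_aligned_host(src, 16) && is_aligned_host(dst, 16);
    if (aligned_16 && embedding_dim >= 4 && embedding_dim % 4 == 0) {
        return "float4";
    }
    if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
        return "float2";
    }
    return "scalar";
}

int main() {
    alignas(16) float buf[32] = {};
    printf("%s\n", float_copy_path(buf, buf, 32)); // float4: 16-byte aligned, dim % 4 == 0
    printf("%s\n", float_copy_path(buf, buf, 30)); // float2: even dim, but not divisible by 4
    printf("%s\n", float_copy_path(buf, buf, 7));  // scalar: odd dim
    return 0;
}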
@@ -2,10 +2,65 @@
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include "../../../tensor.h"
#include "../cuda/embedding_kernel.cuh"
#include "embedding_nvidia.cuh"
#include <cuda_runtime.h>
template <typename T, typename IndexType>
INFINIOP_CUDA_KERNEL embeddingKernel(
T *__restrict__ output,
const IndexType *__restrict__ indices,
const T *__restrict__ weight,
size_t num_indices,
size_t embedding_dim,
size_t vocab_size) {
// Calculate global thread index
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < num_indices) {
// Get the index value
IndexType index_val = __ldg(&indices[idx]);
// Bounds check - handle negative indices gracefully
if (index_val >= 0 && static_cast<size_t>(index_val) < vocab_size) {
// Copy embedding vector from weight to output
const T *src = weight + static_cast<size_t>(index_val) * embedding_dim;
T *dst = output + idx * embedding_dim;
// Choose optimal copy strategy based on type and alignment
if constexpr (std::is_same_v<T, float>) {
// Check alignment for float4 (16 bytes)
bool aligned_16 = is_aligned(src, 16) && is_aligned(dst, 16);
if (aligned_16 && embedding_dim >= 4 && embedding_dim % 4 == 0) {
copyVectorizedFloat4<IndexType>(dst, src, embedding_dim);
} else if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
// Try float2 if not aligned to 16 bytes
copyVectorizedFloat2<IndexType>(dst, src, embedding_dim);
} else {
copyScalar<T, IndexType>(dst, src, embedding_dim);
}
} else if constexpr (std::is_same_v<T, half>) {
// Use half2 for vectorized access
if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
copyVectorizedHalf2<IndexType>(dst, src, embedding_dim);
} else {
copyScalar<T, IndexType>(dst, src, embedding_dim);
}
} else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
// Use bfloat162 for vectorized access
if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
copyVectorizedBFloat162<IndexType>(dst, src, embedding_dim);
} else {
copyScalar<T, IndexType>(dst, src, embedding_dim);
}
} else {
// Fallback to scalar copy with __ldg
copyScalar<T, IndexType>(dst, src, embedding_dim);
}
}
}
}
namespace op::embedding::nvidia {
struct Descriptor::Opaque {
...
@@ -8,6 +8,12 @@
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
#include "nvidia/embedding_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
#include "metax/embedding_metax.cuh"
#endif
#ifdef ENABLE_MOORE_API
#include "moore/embedding_moore.h"
#endif
__C infiniStatus_t infiniopCreateEmbeddingDescriptor(
infiniopHandle_t handle,
@@ -30,18 +36,24 @@ __C infiniStatus_t infiniopCreateEmbeddingDescriptor(
#ifdef ENABLE_CPU_API
CREATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
CREATE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, moore);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -67,18 +79,24 @@ __C infiniStatus_t infiniopEmbedding(
#ifdef ENABLE_CPU_API
CALCULATE(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
CALCULATE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_MOORE_API
CALCULATE(INFINI_DEVICE_MOORE, moore);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -89,30 +107,39 @@ __C infiniStatus_t infiniopEmbedding(
__C infiniStatus_t infiniopDestroyEmbeddingDescriptor(infiniopEmbeddingDescriptor_t desc) {
#define DESTROY(CASE, NAMESPACE) \
case CASE: \
delete reinterpret_cast<const op::embedding::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS;
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
DESTROY(INFINI_DEVICE_CPU, cpu);
#endif
#ifdef ENABLE_NVIDIA_API
DESTROY(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
DESTROY(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_QY_API
DESTROY(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_HYGON_API
DESTROY(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_METAX_API
DESTROY(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_MOORE_API
DESTROY(INFINI_DEVICE_MOORE, moore);
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
#undef DESTROY
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}