Unverified Commit 784139b9 authored by thatPepe's avatar thatPepe Committed by GitHub
Browse files

Merge pull request #990 from InfiniTensor/demo131

Demo-131 Cuda graph with optimized paged attention
parents 3c8fb3c0 1d6527cb
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API)
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
......
......@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#include "cpu/topkrouter_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API)
#include "nvidia/topkrouter_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
......@@ -38,6 +38,9 @@ __C infiniStatus_t infiniopCreateTopkrouterDescriptor(infiniopHandle_t handle, i
#endif
#ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
}
......@@ -67,6 +70,9 @@ __C infiniStatus_t infiniopGetTopkrouterWorkspaceSize(infiniopTopkrouterDescript
#endif
#ifdef ENABLE_KUNLUN_API
GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia);
#endif
}
......@@ -99,6 +105,9 @@ __C infiniStatus_t infiniopTopkrouter(infiniopTopkrouterDescriptor_t desc, void
#endif
#ifdef ENABLE_KUNLUN_API
CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
}
......@@ -128,6 +137,9 @@ __C infiniStatus_t infiniopDestroyTopkrouterDescriptor(infiniopTopkrouterDescrip
#endif
#ifdef ENABLE_KUNLUN_API
DESTROY(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_ALI_API
DESTROY(INFINI_DEVICE_ALI, nvidia);
#endif
}
......
......@@ -19,6 +19,15 @@ inline __device__ float exp_func(T x) {
return __expf(data);
}
// Warp-level sum reduction (XOR butterfly shuffle), added for the Hygon platform
// where the cub-based path below is unavailable. After the loop, every lane that
// participated holds the sum of `val` across the warp group.
// `warp_threads` must be a power of two; the call site below uses WarpSum<32>.
// NOTE(review): the hard-coded 0xffffffff mask assumes a 32-lane warp. Hygon DCUs
// typically have 64-wide wavefronts — confirm that __shfl_xor_sync with this mask
// reduces over the intended 32-lane group on that toolchain.
template <int warp_threads>
__inline__ __device__ float WarpSum(float val) {
    for (int mask = warp_threads / 2; mask > 0; mask /= 2) {
        val += __shfl_xor_sync(0xffffffff, val, mask);
    }
    return val;
}
template <typename T, int BLOCK_SIZE = 128>
__global__ void softmax_topk_row_kernel(float *values_topk, // 输出数据, 形状[N, topk]
int *indices_topk, // 输出索引, 形状[N, topk]
......@@ -57,6 +66,9 @@ __global__ void softmax_topk_row_kernel(float *values_topk, // 输出数据, 形
__shared__ typename BlockReduce::TempStorage temp_storage_max;
#if CUDART_VERSION >= 12090
T value_max = BlockReduce(temp_storage_max).Reduce(thread_max, ::cuda::maximum());
#elif defined(ENABLE_HYGON_API)
T value_max = BlockReduce(temp_storage_max).Reduce(
thread_max, [](const T &a, const T &b) { return (a > b) ? a : b; }, BLOCK_SIZE);
#else
T value_max = BlockReduce(temp_storage_max).Reduce(thread_max, cub::Max());
#endif
......@@ -117,12 +129,19 @@ __global__ void softmax_topk_row_kernel(float *values_topk, // 输出数据, 形
// 第五步: topk的和 //
// ------------------------------------------------ //
{
#ifdef ENABLE_HYGON_API
float warp_sum = WarpSum<32>(value);
if (0 == tid) {
shared_sum = warp_sum + 1e-9f;
}
#else
typedef cub::WarpReduce<float, 32> WarpReduce;
__shared__ typename WarpReduce::TempStorage temp_storage;
float warp_sum = WarpReduce(temp_storage).Sum(value);
if (0 == tid) {
shared_sum = warp_sum + 1e-9f;
}
#endif
}
__syncwarp();
......
......@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#include "cpu/topksoftmax_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API)
#include "nvidia/topksoftmax_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
......@@ -33,6 +33,12 @@ __C infiniStatus_t infiniopCreateTopksoftmaxDescriptor(infiniopHandle_t handle,
#endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
}
......@@ -60,6 +66,12 @@ __C infiniStatus_t infiniopGetTopksoftmaxWorkspaceSize(infiniopTopksoftmaxDescri
#endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
}
......@@ -92,6 +104,12 @@ __C infiniStatus_t infiniopTopksoftmax(infiniopTopksoftmaxDescriptor_t desc, voi
#endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
}
......@@ -119,6 +137,12 @@ __C infiniStatus_t infiniopDestroyTopksoftmaxDescriptor(infiniopTopksoftmaxDescr
#endif
#ifdef ENABLE_METAX_API
DESTROY(INFINI_DEVICE_METAX, metax);
#endif
#ifdef ENABLE_ALI_API
DESTROY(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
DESTROY(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
}
......
......@@ -27,8 +27,10 @@ public:
return 0;
} else if constexpr (std::is_same_v<T, uint64_t>) { // 10
return 0;
#ifndef ENABLE_HYGON_API
} else if constexpr (std::is_same_v<T, cuda_fp8_e4m3>) { // 11
return cuda_fp8_e4m3(0.0f);
#endif
} else if constexpr (std::is_same_v<T, half>) { // 12
return __float2half(0.0f);
} else if constexpr (std::is_same_v<T, float>) { // 13
......
......@@ -79,8 +79,10 @@ infiniStatus_t Descriptor::calculate(
return _device_info->calculate<256, cuda::ZerosOp, uint32_t>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_U64: // 10
return _device_info->calculate<256, cuda::ZerosOp, uint64_t>(_info, workspace, output, inputs, stream);
#ifndef ENABLE_HYGON_API
case INFINI_DTYPE_F8: // 11
return _device_info->calculate<256, cuda::ZerosOp, cuda_fp8_e4m3>(_info, workspace, output, inputs, stream);
#endif
case INFINI_DTYPE_F16: // 12
return _device_info->calculate<256, cuda::ZerosOp, half>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_F32: // 13
......
......@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#include "cpu/zeros_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API)
#include "nvidia/zeros_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
......@@ -37,6 +37,9 @@ __C infiniStatus_t infiniopCreateZerosDescriptor(
#ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia);
#endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
......@@ -73,6 +76,9 @@ __C infiniStatus_t infiniopGetZerosWorkspaceSize(infiniopZerosDescriptor_t desc,
#ifdef ENABLE_ILUVATAR_API
GET(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_ALI_API
GET(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia);
#endif
......@@ -114,6 +120,9 @@ __C infiniStatus_t infiniopZeros(
#ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_ALI_API
CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API
CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
......@@ -149,6 +158,9 @@ infiniopDestroyZerosDescriptor(infiniopZerosDescriptor_t desc) {
#ifdef ENABLE_ILUVATAR_API
DELETE(INFINI_DEVICE_ILUVATAR, nvidia);
#endif
#ifdef ENABLE_ALI_API
DELETE(INFINI_DEVICE_ALI, nvidia);
#endif
#ifdef ENABLE_QY_API
DELETE(INFINI_DEVICE_QY, nvidia);
#endif
......
......@@ -50,7 +50,7 @@ __mlu_func__ float sum(const T *source, T *src, float *dst, int num_elements, in
size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed);
if (curr_batch < max_batch) {
__bang_write_zero(src, max_batch + offset);
__bang_write_value(src, max_batch + offset, 0);
}
__memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
......@@ -81,7 +81,7 @@ __mlu_func__ float sumBatched(const T *source, T *src, float *dst, int num_eleme
size_t remainder = curr_batch % batch_size;
// Ensure NRAM buffer is zeroed
__bang_write_zero(src, max_batch + offset);
__bang_write_value(src, max_batch + offset, 0);
// Copy data to NRAM
__memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
......@@ -120,7 +120,7 @@ __mlu_func__ float sumSquared(const T *source, T *src, float *dst, int num_eleme
size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed);
if (curr_batch < max_batch) {
__bang_write_zero(src, max_batch + offset);
__bang_write_value(src, max_batch + offset, 0);
}
__memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
......@@ -165,7 +165,7 @@ __mlu_func__ float sumSquaredBatched(const T *source, T *src, float *dst, int nu
size_t remainder = curr_batch % batch_size;
// Ensure NRAM buffer is zeroed
__bang_write_zero(src, max_batch + offset);
__bang_write_value(src, max_batch + offset, 0);
// Copy data to NRAM
__memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
......@@ -235,7 +235,7 @@ __mlu_func__ float max(const T *source, T *src, float *dst, int num_elements, in
size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed);
if (curr_batch < max_batch) {
__bang_write_zero(src, max_batch + offset);
__bang_write_value(src, max_batch + offset, 0);
}
__memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
......@@ -264,7 +264,7 @@ __mlu_func__ float maxBatched(const T *source, T *src, float *dst, int num_eleme
size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed);
if (curr_batch < max_batch) {
__bang_write_zero(src, max_batch + offset);
__bang_write_value(src, max_batch + offset, 0);
}
__memcpy(src + offset, source + processed, curr_batch * sizeof(T), GDRAM2NRAM);
......
......@@ -23,6 +23,7 @@ void printUsage() {
<< " qy" << std::endl
<< " kunlun" << std::endl
<< " hygon" << std::endl
<< " ali" << std::endl
<< std::endl;
exit(EXIT_FAILURE);
}
......@@ -55,6 +56,7 @@ ParsedArgs parseArgs(int argc, char *argv[]) {
else PARSE_DEVICE("--qy", INFINI_DEVICE_QY)
else PARSE_DEVICE("--kunlun", INFINI_DEVICE_KUNLUN)
else PARSE_DEVICE("--hygon", INFINI_DEVICE_HYGON)
else PARSE_DEVICE("--ali", INFINI_DEVICE_ALI)
else {
printUsage();
}
......
......@@ -21,6 +21,8 @@ namespace infinirt::iluvatar {
namespace infinirt::qy {
#elif defined(ENABLE_HYGON_API)
namespace infinirt::hygon {
#elif defined(ENABLE_ALI_API)
namespace infinirt::ali {
#else
namespace infinirt::cuda { // 默认回退
#endif
......
......@@ -38,4 +38,13 @@ INFINIRT_DEVICE_API_NOOP
#endif
} // namespace infinirt::hygon
// ALI namespace
namespace infinirt::ali {
#ifdef ENABLE_ALI_API
INFINIRT_DEVICE_API_IMPL
#else
INFINIRT_DEVICE_API_NOOP
#endif
} // namespace infinirt::ali
#endif // __INFINIRT_CUDA_H__
......@@ -81,6 +81,9 @@ __C infiniStatus_t infinirtGetDevice(infiniDevice_t *device_ptr, int *device_id_
case INFINI_DEVICE_HYGON: \
_status = infinirt::hygon::API PARAMS; \
break; \
case INFINI_DEVICE_ALI: \
_status = infinirt::ali::API PARAMS; \
break; \
default: \
_status = INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; \
} \
......
......@@ -154,15 +154,32 @@ infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) {
}
infiniStatus_t streamBeginCapture(infinirtStream_t stream, infinirtStreamCaptureMode_t mode) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
hcStreamCaptureMode graph_mode;
if (mode == INFINIRT_STREAM_CAPTURE_MODE_GLOBAL) {
graph_mode = hcStreamCaptureModeGlobal;
} else if (mode == INFINIRT_STREAM_CAPTURE_MODE_THREAD_LOCAL) {
graph_mode = hcStreamCaptureModeThreadLocal;
} else if (mode == INFINIRT_STREAM_CAPTURE_MODE_RELAXED) {
graph_mode = hcStreamCaptureModeRelaxed;
} else {
return INFINI_STATUS_BAD_PARAM;
}
CHECK_MACART(hcStreamBeginCapture((hcStream_t)stream, graph_mode));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t streamEndCapture(infinirtStream_t stream, infinirtGraph_t *graph_ptr) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
hcGraph_t graph;
CHECK_MACART(hcStreamEndCapture((hcStream_t)stream, &graph));
*graph_ptr = graph;
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t graphDestroy(infinirtGraph_t graph) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
CHECK_MACART(hcGraphDestroy((hcGraph_t)graph));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t graphInstantiate(
......@@ -171,15 +188,23 @@ infiniStatus_t graphInstantiate(
infinirtGraphNode_t *node_ptr,
char *log_buffer,
size_t buffer_size) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
CHECK_MACART(hcGraphInstantiate(
(hcGraphExec_t *)graph_exec_ptr,
(hcGraph_t)graph,
(hcGraphNode_t *)node_ptr,
log_buffer,
buffer_size));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t graphExecDestroy(infinirtGraphExec_t graph_exec) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
CHECK_MACART(hcGraphExecDestroy((hcGraphExec_t)graph_exec));
return INFINI_STATUS_SUCCESS;
}
infiniStatus_t graphLuanch(infinirtGraphExec_t graph_exec, infinirtStream_t stream) {
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
CHECK_MACART(hcGraphLaunch((hcGraphExec_t)graph_exec, (hcStream_t)stream));
return INFINI_STATUS_SUCCESS;
}
} // namespace infinirt::metax
......@@ -342,7 +342,10 @@ class BaseOperatorTest(ABC):
for i, inp in enumerate(inputs):
if isinstance(inp, torch.Tensor):
# Clone only if this input will be used for comparison
if comparison_target == i:
if comparison_target == i or (
isinstance(comparison_target, (list, tuple))
and i in comparison_target
):
cloned_inp = clone_torch_tensor(inp)
infini_tensor = infinicore_tensor_from_torch(cloned_inp)
cloned_tensors.append(cloned_inp)
......@@ -508,7 +511,9 @@ class BaseOperatorTest(ABC):
# Handle multiple outputs comparison
# Determine what to compare based on comparison_target
if comparison_target is None:
if comparison_target is None or isinstance(
comparison_target, (list, tuple)
):
# Compare return values (out-of-place multiple outputs)
torch_comparison = torch_result
infini_comparison = infini_result
......@@ -573,7 +578,9 @@ class BaseOperatorTest(ABC):
# ==========================================================================
else:
# Determine comparison targets for single output
if comparison_target is None:
if comparison_target is None or isinstance(
comparison_target, (list, tuple)
):
# Compare return values (out-of-place)
torch_comparison = torch_result
infini_comparison = infini_result
......
......@@ -24,6 +24,7 @@ def get_supported_hardware_platforms():
("--kunlun", "Kunlun XPUs (requires torch_xmlir)"),
("--hygon", "Hygon DCUs"),
("--qy", "QY GPUs"),
("--ali", "Ali PPU accelerators"),
]
......@@ -230,13 +231,21 @@ def get_test_devices(args):
if args.qy:
try:
# Iluvatar GPU detection
# QY GPU detection
import torch
devices_to_test.append(InfiniDeviceEnum.QY)
except ImportError:
print("Warning: QY GPU support not available")
if args.ali:
try:
import torch
devices_to_test.append(InfiniDeviceEnum.ALI)
except ImportError:
print("Warning: Ali PPU support not available")
# Default to CPU if no devices specified
if not devices_to_test:
devices_to_test = [InfiniDeviceEnum.CPU]
......
......@@ -9,6 +9,7 @@ class InfiniDeviceEnum:
KUNLUN = 7
HYGON = 8
QY = 9
ALI = 10
InfiniDeviceNames = {
......@@ -22,6 +23,7 @@ InfiniDeviceNames = {
InfiniDeviceEnum.QY: "Qy",
InfiniDeviceEnum.KUNLUN: "Kunlun",
InfiniDeviceEnum.HYGON: "Hygon",
InfiniDeviceEnum.ALI: "Ali",
}
torch_device_map = {
......@@ -35,4 +37,5 @@ torch_device_map = {
InfiniDeviceEnum.KUNLUN: "cuda",
InfiniDeviceEnum.HYGON: "cuda",
InfiniDeviceEnum.QY: "cuda",
InfiniDeviceEnum.ALI: "cuda",
}
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
from framework import BaseOperatorTest, GenericTestRunner, TensorSpec, TestCase
from framework.tensor import TensorInitializer
import infinicore
# Test cases format: (nlayers, batch_size, hidden_size, nhead, nkvhead, dim, seqlen, past_seqlen, max_seqlen)
_TEST_CASES_DATA = [
(28, 1, 3584, 28, 28, 128, 1, 256, 512),
]
_TOLERANCE_MAP = {
infinicore.float16: {"atol": 1e-4, "rtol": 1e-2},
infinicore.float32: {"atol": 1e-4, "rtol": 1e-3},
infinicore.bfloat16: {"atol": 1e-4, "rtol": 5e-2},
}
_TENSOR_DTYPES = [infinicore.float16, infinicore.float32, infinicore.bfloat16]
def parse_test_cases():
    """Build one graph-recording TestCase per (model config, dtype) pair.

    Each case carries the full set of decoder-step inputs (hidden states,
    position ids, KV caches, projection weights, norm weight and RoPE tables)
    plus scalar hyper-parameters, in the positional order expected by the
    operator under test.

    Returns:
        list[TestCase]: out-of-place cases with comparison_target=None.
    """

    def _spec(shape, dtype):
        # Shared initializer for all floating-point tensors in a case.
        return TensorSpec.from_tensor(shape, dtype=dtype, scale=1e-1, bias=-5e-2)

    cases = []
    for cfg in _TEST_CASES_DATA:
        (
            nlayers,
            batch_size,
            hidden_size,
            nhead,
            nkvhead,
            dim,
            seqlen,
            past_seqlen,
            max_seqlen,
        ) = cfg
        for dtype in _TENSOR_DTYPES:
            kv_shape = (nlayers, batch_size, nkvhead, max_seqlen, dim)
            rope_table_shape = (max_seqlen, dim // 2)
            # Integer position ids are initialized separately (random ints in
            # [0, max_seqlen)) rather than with the float scale/bias scheme.
            pos_ids = TensorSpec.from_tensor(
                (batch_size, seqlen),
                dtype=infinicore.int64,
                init_mode=TensorInitializer.RANDINT,
                low=0,
                high=max_seqlen,
            )
            inputs = [
                _spec((batch_size, seqlen, hidden_size), dtype),  # hidden_states
                pos_ids,
                nhead,
                nkvhead,
                dim,
                past_seqlen,
                nlayers,
                _spec(kv_shape, dtype),  # k_cache
                _spec(kv_shape, dtype),  # v_cache
                _spec((nhead * dim, hidden_size), dtype),  # q_proj_w
                _spec((nkvhead * dim, hidden_size), dtype),  # k_proj_w
                _spec((nkvhead * dim, hidden_size), dtype),  # v_proj_w
                _spec((hidden_size, nhead * dim), dtype),  # o_proj_w
                _spec((hidden_size,), dtype),  # norm_w
                _spec(rope_table_shape, dtype),  # sin_table
                _spec(rope_table_shape, dtype),  # cos_table
            ]
            # Out-of-place: return values of both operators are compared.
            cases.append(
                TestCase(
                    inputs=inputs,
                    kwargs={},
                    output_spec=None,
                    comparison_target=None,
                    tolerance=_TOLERANCE_MAP[dtype],
                    description="Graph",
                )
            )
    return cases
def torch_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    sin: torch.Tensor,
    cos: torch.Tensor,
    pos_ids: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply interleaved (GPT-J style) rotary position embedding to q and k.

    Args:
        q, k: [B, H, S, D] query/key tensors; D must be even.
        sin, cos: [max_S, D//2] per-position rotation tables.
        pos_ids: [B, S] integer positions used to gather sin/cos rows.

    Returns:
        (q_rot, k_rot), same shapes as the inputs.
    """
    B, H, S, D = q.shape
    assert D % 2 == 0

    def _rotate_half(x: torch.Tensor) -> torch.Tensor:
        # Interleaved pairs (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...).
        even = x[..., 0::2]
        odd = x[..., 1::2]
        return torch.stack((-odd, even), dim=-1).flatten(-2)

    # Gather per-token tables -> [B, S, D//2]; add a head axis -> [B, 1, S, D//2];
    # then duplicate each angle so both members of an interleaved pair share it.
    sin_full = torch.repeat_interleave(sin[pos_ids].unsqueeze(1), 2, dim=-1)
    cos_full = torch.repeat_interleave(cos[pos_ids].unsqueeze(1), 2, dim=-1)

    q_rot = q * cos_full + _rotate_half(q) * sin_full
    k_rot = k * cos_full + _rotate_half(k) * sin_full
    return q_rot, k_rot
class OpTest(BaseOperatorTest):
    """Test Operator Graph.

    End-to-end graph-recording test: a small multi-layer decoder step
    (RMSNorm -> QKV projection -> RoPE -> KV-cache update -> GQA attention ->
    output projection) is computed once with eager PyTorch as the reference
    and once with infinicore ops recorded into a graph and replayed.
    """

    def __init__(self):
        super().__init__("Graph")

    def get_test_cases(self):
        # One case per (model config, dtype) pair; see parse_test_cases().
        return parse_test_cases()

    def torch_operator(
        self,
        hidden_states,
        pos_ids,
        nhead,
        nkvhead,
        dim,
        past_seqlen,
        nlayers,
        k_cache,
        v_cache,
        q_proj_w,
        k_proj_w,
        v_proj_w,
        o_proj_w,
        norm_w,
        sin_table,
        cos_table,
        **kwargs,
    ):
        """Eager PyTorch reference for the decoder step.

        Note: the same projection/norm weights are reused in every layer —
        this is a test fixture, not a faithful model. k_cache/v_cache are
        mutated in place (rows [past_seqlen, past_seqlen+S) are written).
        """
        B, S, D = hidden_states.shape
        for layer in range(nlayers):
            # ---- RMSNorm ----
            var = hidden_states.pow(2).mean(-1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(var + 1e-5) * norm_w
            # ---- QKV projection ----
            q = hidden_states @ q_proj_w.T
            k = hidden_states @ k_proj_w.T
            v = hidden_states @ v_proj_w.T
            q = q.view(B, S, nhead, dim).transpose(1, 2)  # [B,H,S,Dh]
            k = k.view(B, S, nkvhead, dim).transpose(1, 2)
            v = v.view(B, S, nkvhead, dim).transpose(1, 2)
            # ---- RoPE ----
            q, k = torch_rope(
                q,
                k,
                sin_table,
                cos_table,
                pos_ids,
            )
            # ---- KV cache update ----
            # Append the new S tokens after the first past_seqlen cached ones.
            k_cache[layer, :, :, past_seqlen : past_seqlen + S, :] = k
            v_cache[layer, :, :, past_seqlen : past_seqlen + S, :] = v
            K = k_cache[layer, :, :, 0 : past_seqlen + S, :]
            V = v_cache[layer, :, :, 0 : past_seqlen + S, :]
            # ---- Scaled Dot Product Attention (fused) ----
            # NOTE(review): redefined on every layer iteration; harmless for a
            # test but could be hoisted out of the loop.
            def scaled_dot_product_attention(
                query, key, value, is_causal=False, enable_gqa=False
            ) -> torch.Tensor:
                # S = query length, L = key/value (total) length; these
                # intentionally shadow the outer S.
                S, L = query.size(-2), key.size(-2)
                scale_factor = query.size(-1) ** -0.5
                attn_bias = torch.zeros(S, L, dtype=query.dtype, device=query.device)
                if is_causal:
                    # Build a causal mask aligned to the *end* of the key axis:
                    # query i attends to keys up to position (L - S + i).
                    mask = torch.tril(attn_bias + 1, diagonal=-1).flip(dims=[-2, -1])
                    attn_bias = torch.where(mask == 1, -torch.inf, 0.0)
                if enable_gqa:
                    # Expand KV heads to match query heads (grouped-query attn).
                    key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
                    value = value.repeat_interleave(
                        query.size(-3) // value.size(-3), -3
                    )
                attn_weight = query @ key.transpose(-2, -1) * scale_factor
                attn_weight += attn_bias
                attn_weight = torch.softmax(attn_weight, dim=-1)
                return attn_weight @ value

            attn_out = scaled_dot_product_attention(
                q,
                K,
                V,
                is_causal=True,
                enable_gqa=True,
            )  # [B,H,S,Dh]
            # ---- Output projection ----
            attn_out = attn_out.transpose(1, 2).contiguous()
            attn_out = attn_out.view(B, S, nhead * dim)
            hidden_states = attn_out @ o_proj_w.T
        return hidden_states

    def infinicore_operator(
        self,
        hidden_states,
        pos_ids,
        nhead,
        nkvhead,
        dim,
        past_seqlen,
        nlayers,
        k_cache,
        v_cache,
        q_proj_w,
        k_proj_w,
        v_proj_w,
        o_proj_w,
        norm_w,
        sin_table,
        cos_table,
        **kwargs,
    ):
        """Record the whole decoder step into an infinicore graph and run it.

        All ops between start_graph_recording() and stop_graph_recording()
        are captured rather than executed; op_graph.run() replays them.
        """
        input_hidden_states = hidden_states
        B, S, D = input_hidden_states.shape
        infinicore.start_graph_recording()
        for layer in range(nlayers):
            hidden_states = infinicore.nn.functional.rms_norm(
                hidden_states, norm_w.shape, norm_w, 1e-5
            )
            q = infinicore.nn.functional.linear(hidden_states, q_proj_w)
            k = infinicore.nn.functional.linear(hidden_states, k_proj_w)
            v = infinicore.nn.functional.linear(hidden_states, v_proj_w)
            q = q.view((B, S, nhead, dim))
            k = k.view((B, S, nkvhead, dim))
            v = v.view((B, S, nkvhead, dim))
            # GPT_J algo = interleaved pairs, matching torch_rope above.
            q = infinicore.nn.functional.rope(
                q,
                pos_ids,
                sin_table,
                cos_table,
                infinicore.nn.functional.RopeAlgo.GPT_J,
            )
            k = infinicore.nn.functional.rope(
                k,
                pos_ids,
                sin_table,
                cos_table,
                infinicore.nn.functional.RopeAlgo.GPT_J,
            )
            # Views into this layer's cache rows: [B, KVH, total_len, D]
            full_k = (
                k_cache.narrow(0, layer, 1).squeeze(0).narrow(2, 0, past_seqlen + S)
            )
            full_v = (
                v_cache.narrow(0, layer, 1).squeeze(0).narrow(2, 0, past_seqlen + S)
            )
            # In-place append of the new tokens ([B,S,KVH,D] -> [B,KVH,S,D]).
            full_k.narrow(2, past_seqlen, S).copy_(k.permute((0, 2, 1, 3)))
            full_v.narrow(2, past_seqlen, S).copy_(v.permute((0, 2, 1, 3)))
            G = nhead // nkvhead  # query heads per KV head (GQA group size)
            L = past_seqlen + S
            # Fold the GQA group into the row axis so one batched matmul
            # serves all query heads sharing a KV head.
            full_q = (
                q.permute((0, 2, 1, 3)).contiguous().view((B * nkvhead, G * S, dim))
            )
            full_k = full_k.view((B * nkvhead, L, dim))
            full_v = full_v.view((B * nkvhead, L, dim))
            attn_score = infinicore.matmul(
                full_q, full_k.permute((0, 2, 1)), alpha=dim**-0.5
            )
            # [B * H, S, total_len]
            attn_score = attn_score.view((B * nhead, S, L))
            infinicore.nn.functional.causal_softmax(attn_score, out=attn_score)
            attn_out = infinicore.matmul(attn_score, full_v)
            attn_out = (
                attn_out.view((B, nhead, S, dim))
                .permute((0, 2, 1, 3))
                .contiguous()
                .view((B, S, nhead * dim))
            )
            hidden_states = infinicore.nn.functional.linear(attn_out, o_proj_w)
        op_graph = infinicore.stop_graph_recording()
        op_graph.run()
        return hidden_states
def main():
    """Entry point: run the Graph operator test suite and exit with its status."""
    GenericTestRunner(OpTest).run_and_exit()


if __name__ == "__main__":
    main()
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework import BaseOperatorTest, TensorSpec, TestCase, GenericTestRunner
# Test cases format: (in_shape, proj_w_shape)
_TEST_CASES_DATA = [
((32, 4096), (4096, 4096)),
]
_TOLERANCE_MAP = {
infinicore.float16: {"atol": 0, "rtol": 1e-2},
infinicore.float32: {"atol": 1e-4, "rtol": 1e-3},
infinicore.bfloat16: {"atol": 0, "rtol": 5e-2},
}
_TENSOR_DTYPES = [infinicore.float16, infinicore.float32, infinicore.bfloat16]
def parse_test_cases():
    """Build one out-of-place TestCase per (shape config, dtype) pair.

    Each case supplies the input, the projection weight, and a same-shaped
    temp/staging tensor used by the graph-recording path.
    """
    return [
        TestCase(
            inputs=[
                TensorSpec.from_tensor(in_shape, dtype=dtype),  # input
                TensorSpec.from_tensor(proj_w_shape, dtype=dtype),  # weight
                TensorSpec.from_tensor(in_shape, dtype=dtype),  # staging buffer
            ],
            kwargs={},
            output_spec=None,
            comparison_target=None,
            tolerance=_TOLERANCE_MAP[dtype],
            description="Graph",
        )
        for in_shape, proj_w_shape in _TEST_CASES_DATA
        for dtype in _TENSOR_DTYPES
    ]
class OpTest(BaseOperatorTest):
    """Graph-recording smoke test: capture one matmul into a graph and replay it."""

    def __init__(self):
        super().__init__("Graph")

    def get_test_cases(self):
        return parse_test_cases()

    def torch_operator(self, *args, **kwargs):
        # Reference path: plain matmul of input and weight (staging buffer unused).
        return torch.matmul(args[0], args[1])

    def infinicore_operator(self, *args, **kwargs):
        """Record a matmul against a staging tensor, then feed real data and replay."""
        src, weight, staging = args[0], args[1], args[2]
        # Capture the op graph against the (still-uninitialized) staging buffer.
        infinicore.start_graph_recording()
        result = infinicore.matmul(staging, weight)
        op_graph = infinicore.stop_graph_recording()
        # Copy the real input in *after* recording, then replay the captured graph;
        # `result` is populated by the replay.
        staging.copy_(src)
        op_graph.run()
        return result
def main():
    """Entry point: run the Graph operator test suite and exit with its status."""
    GenericTestRunner(OpTest).run_and_exit()


if __name__ == "__main__":
    main()
"""
Test if embedding supports CUDA Graph recording
Usage:
python test/infinicore/nn/test_embedding_graph_recording.py
Key verification points:
1. Before modification: indices->to(cpu_device) triggers synchronous D2H copy, causing graph recording to fail
2. After modification: Uses device-side CUDA kernel, fully asynchronous, supports graph recording
Expected results:
- Before modification: Graph recording fails, device-side input may fail
- After modification: Graph recording succeeds, device-side input succeeds
"""
import infinicore
import torch
def test_embedding_graph_recording():
    """Test if embedding supports CUDA Graph recording.

    Warms up an infinicore Embedding with a pre-allocated output, then
    captures F.embedding inside torch.cuda.graph and replays it. Falls back
    to test_embedding_async_verification() when the PyTorch graph API is
    unavailable or raises an unrelated error.

    Returns:
        bool: True if recording (or the fallback verification) succeeds.
    """
    print("=" * 60)
    print("Testing Embedding Graph Recording Support")
    print("=" * 60)
    # Check if CUDA is available
    if not torch.cuda.is_available():
        print("⚠ CUDA not available, skipping graph recording test")
        return False
    device = infinicore.device("cuda", 0)
    # Create embedding module
    vocab_size = 1000
    embedding_dim = 128
    embedding = infinicore.nn.Embedding(
        num_embeddings=vocab_size,
        embedding_dim=embedding_dim,
        dtype=infinicore.float32,
        device=device,
    )
    # Create device-side input_ids (key point: unsupported before modification, supported after)
    batch_size = 4
    seq_len = 32
    input_ids_device = infinicore.from_list(
        [[i % vocab_size for i in range(seq_len)] for _ in range(batch_size)],
        dtype=infinicore.int64,
        device=device,
    )
    print(f"\n1. Input tensor information:")
    print(f" - Shape: {input_ids_device.shape}")
    print(f" - Device: {input_ids_device.device.type}")
    print(f" - Dtype: {input_ids_device.dtype}")
    # Attempt CUDA Graph recording
    print(f"\n2. Attempting CUDA Graph recording...")
    # Use PyTorch's CUDA Graph API for testing (simpler and more reliable)
    try:
        # Set device
        infinicore.set_device(device)
        # Use PyTorch's CUDA Graph API
        # Note: PyTorch 2.0+ supports torch.cuda.graph
        try:
            # Method 1: Use PyTorch CUDA Graph (recommended)
            print(" Using PyTorch CUDA Graph API for testing...")
            # Create warmup input
            warmup_input = input_ids_device
            # Warmup (need to execute once before graph recording, including memory allocation)
            embedding.forward(warmup_input)
            infinicore.sync_stream()  # Synchronize to ensure warmup completes
            # Pre-allocate output tensor (CUDA Graph doesn't support dynamic memory allocation)
            # Output shape: input_shape + [embedding_dim]
            output_shape = list(input_ids_device.shape) + [embedding_dim]
            output = infinicore.empty(
                output_shape, dtype=embedding.weight.dtype, device=device
            )
            # Warmup embedding (ensure memory allocation is complete)
            import infinicore.nn.functional as F

            F.embedding(warmup_input, embedding.weight, out=output)
            infinicore.sync_stream()
            # Start graph recording (using pre-allocated output)
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph):
                # Use embedding's out parameter (in-place), passing pre-allocated output.
                # Any synchronous op inside this block would abort the capture.
                F.embedding(input_ids_device, embedding.weight, out=output)
            print(" ✓ Graph recording successful!")
            print(" ✓ Embedding supports CUDA Graph recording")
            # Verify graph can be replayed
            graph.replay()
            infinicore.sync_stream()
            print(" ✓ Graph can be successfully replayed")
            return True
        except AttributeError:
            # PyTorch version may not support torch.cuda.graph
            print(
                " ⚠ PyTorch version doesn't support torch.cuda.graph, using simplified verification method"
            )
            return test_embedding_async_verification(embedding, input_ids_device)
    except RuntimeError as e:
        error_msg = str(e)
        # Heuristic: capture/graph-related errors mean recording itself failed;
        # anything else falls through to the simplified verification.
        if "capture" in error_msg.lower() or "graph" in error_msg.lower():
            print(f" ✗ Graph recording failed: {e}")
            print(
                " ✗ Embedding doesn't support CUDA Graph recording (may contain synchronous operations)"
            )
            return False
        else:
            print(f" ⚠ Graph recording test exception: {e}")
            return test_embedding_async_verification(embedding, input_ids_device)
    except Exception as e:
        print(f" ⚠ Graph recording test exception: {e}")
        print(" Using simplified verification method...")
        import traceback

        traceback.print_exc()
        return test_embedding_async_verification(embedding, input_ids_device)
def test_embedding_async_verification(embedding, input_ids_device):
    """
    Simplified verification: Check if there are synchronous operations.

    Key checkpoints:
    1. Whether input can be on device (needed CPU before modification, supports device after)
    2. Whether operations are fully asynchronous (no synchronization points)

    Args:
        embedding: infinicore.nn.Embedding instance to exercise.
        input_ids_device: index tensor expected to live on a CUDA device.

    Returns:
        bool: True if the forward pass runs on-device with the expected shape.
    """
    print("\n3. Simplified verification: Checking asynchronous operation support")
    # Verification 1: Input can be on device
    if input_ids_device.device.type != "cuda":
        print(" ✗ Input not on device, cannot verify")
        return False
    print(" ✓ Input is on device")
    # Verification 2: Execute forward, check for synchronous operations
    # Before modification, this would call indices->to(cpu_device), triggering synchronization
    # After modification, directly uses device-side kernel, fully asynchronous
    try:
        # Record start time
        start_event = infinicore.DeviceEvent(enable_timing=True)
        end_event = infinicore.DeviceEvent(enable_timing=True)
        start_event.record()
        output = embedding.forward(input_ids_device)
        end_event.record()
        # Don't synchronize immediately, check if operation is asynchronous
        # If operation is asynchronous, query should return False (not completed)
        # If operation is synchronous, may have already completed
        # Wait a short time
        import time

        time.sleep(0.001)  # 1ms
        # Check event status.
        # NOTE(review): this sleep-then-query check is a timing heuristic — a very
        # fast kernel can legitimately finish within 1 ms, so a completed event is
        # reported as a warning, not a failure.
        is_complete = end_event.query()
        if not is_complete:
            print(" ✓ Operation is asynchronous (event not immediately completed)")
        else:
            print(
                " ⚠ Operation may contain synchronization points (event immediately completed)"
            )
        # Synchronize and measure time
        end_event.synchronize()
        elapsed = start_event.elapsed_time(end_event)
        print(f" ✓ Forward execution time: {elapsed:.3f} ms")
        print(f" ✓ Output shape: {output.shape}")
        print(f" ✓ Output device: {output.device.type}")
        # Verify output correctness: device placement and shape only
        # (values are not checked here).
        embedding_dim = embedding.embedding_dim()
        expected_shape = (*input_ids_device.shape, embedding_dim)
        if output.device.type == "cuda" and output.shape == expected_shape:
            print(" ✓ Output on device, shape correct")
            return True
        else:
            print(f" ✗ Output verification failed")
            print(
                f" Expected shape: {expected_shape}, actual shape: {output.shape}"
            )
            print(f" Expected device: cuda, actual device: {output.device.type}")
            return False
    except Exception as e:
        print(f" ✗ Verification failed: {e}")
        import traceback

        traceback.print_exc()
        return False
def test_embedding_device_input_support():
    """Verify that embedding.forward accepts index tensors resident on the device.

    Returns:
        bool: True when a device-side forward pass succeeds, False otherwise
        (including when CUDA is unavailable).
    """
    print("\n" + "=" * 60)
    print("Testing Embedding Device-side Input Support")
    print("=" * 60)
    if not torch.cuda.is_available():
        print("⚠ CUDA not available, skipping test")
        return False

    device = infinicore.device("cuda", 0)
    # A small table keeps the test fast; the exact sizes are arbitrary.
    embedding = infinicore.nn.Embedding(
        num_embeddings=100,
        embedding_dim=64,
        dtype=infinicore.float32,
        device=device,
    )

    # Test 1: Device-side input (supported after modification)
    print("\nTest 1: Device-side input")
    try:
        ids_on_device = infinicore.from_list(
            [[1, 2, 3, 4, 5]], dtype=infinicore.int64, device=device
        )
        result = embedding.forward(ids_on_device)
        print(f" ✓ Device-side input successful")
        print(f" - Input device: {ids_on_device.device.type}")
        print(f" - Output device: {result.device.type}")
        print(f" - Output shape: {result.shape}")
        return True
    except Exception as e:
        print(f" ✗ Device-side input failed: {e}")
        return False
def main():
    """Run all embedding graph-recording checks and print a pass/fail summary.

    Returns:
        bool: True when every sub-test passed.
    """
    print("\n" + "=" * 60)
    print("Embedding Graph Recording Support Verification")
    print("=" * 60)

    # (label, outcome) pairs, executed in order.
    results = [
        ("CUDA Graph Recording", test_embedding_graph_recording()),
        ("Device-side Input", test_embedding_device_input_support()),
    ]

    print("\n" + "=" * 60)
    print("Test Results Summary")
    print("=" * 60)
    for test_name, result in results:
        print(f"{test_name}: {'✓ Passed' if result else '✗ Failed'}")

    all_passed = all(outcome for _, outcome in results)
    print("\n" + "=" * 60)
    if all_passed:
        print("✓ All tests passed! Embedding supports graph recording")
    else:
        print("✗ Some tests failed, embedding may not fully support graph recording")
    print("=" * 60)
    return all_passed


if __name__ == "__main__":
    exit(0 if main() else 1)
......@@ -114,14 +114,9 @@ class OpTest(BaseOperatorTest):
def infinicore_operator(self, x, weight):
"""InfiniCore nn.Embedding implementation"""
if x.device.type != "cpu":
# 将 input的数据 转移到 cpu 上
x_torch = convert_infinicore_to_torch(x)
x_torch_cpu = x_torch.contiguous().cpu()
x = infinicore.from_torch(x_torch_cpu)
# Note: embedding now supports device-side input for graph recording
# No need to convert to CPU anymore - the implementation handles both CPU and device inputs
num_embeddings, embedding_dim = weight.shape
model = infinicore.nn.Embedding(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment