Commit 5de45ee6 authored by yaoht

Integrate FA (flash attention), adapt to DCU, and optimize the add_rms_norm and RoPE operators

parent 93191613
Pipeline #3510 failed with stages in 0 seconds
......@@ -5,9 +5,12 @@
#include <ATen/ATen.h>
#ifdef ENABLE_NVIDIA_API
#if defined(ENABLE_NVIDIA_API)
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#elif defined(ENABLE_HYGON_API)
#include <ATen/hip/HIPContext.h>
#include <c10/hip/HIPGuard.h>
#endif
namespace infinicore::adaptor {
......@@ -29,7 +32,8 @@ inline at::ScalarType to_at_dtype(DataType dtype) {
}
inline at::Device to_at_device(const Device &device) {
if (device.getType() == Device::Type::NVIDIA) {
if (device.getType() == Device::Type::NVIDIA
|| device.getType() == Device::Type::HYGON) {
return at::Device(at::kCUDA, device.getIndex());
} else if (device.getType() == Device::Type::CPU) {
return at::Device(at::kCPU);
......@@ -40,8 +44,14 @@ inline at::Device to_at_device(const Device &device) {
at::Tensor to_aten_tensor(const infinicore::Tensor &t);
#ifdef ENABLE_NVIDIA_API
c10::cuda::CUDAStream get_cuda_stream();
#if defined(ENABLE_HYGON_API)
using TorchStream = c10::hip::HIPStream;
using TorchStreamGuard = c10::hip::HIPStreamGuard;
TorchStream get_cuda_stream();
#elif defined(ENABLE_NVIDIA_API)
using TorchStream = c10::cuda::CUDAStream;
using TorchStreamGuard = c10::cuda::CUDAStreamGuard;
TorchStream get_cuda_stream();
#endif
} // namespace infinicore::adaptor
......
#!/bin/bash
export CUDA_HOME=/opt/dtk/cuda/cuda
xmake clean --all
xmake f -c --hygon-dcu=y --ccl=y --graph=y --cuda=$CUDA_HOME --aten=y --flash-attn-prebuilt=/usr/local/lib/python3.10/dist-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so
xmake build && xmake install
xmake build _infinicore && xmake install _infinicore
pip install -e . --no-build-isolation
\ No newline at end of file
......@@ -32,8 +32,13 @@ at::Tensor to_aten_tensor(const infinicore::Tensor &t) {
options);
}
#ifdef ENABLE_NVIDIA_API
c10::cuda::CUDAStream get_cuda_stream() {
#if defined(ENABLE_HYGON_API)
TorchStream get_cuda_stream() {
return c10::hip::getStreamFromExternal(
hipStream_t(infinicore::context::getStream()), infinicore::context::getDevice().getIndex());
}
#elif defined(ENABLE_NVIDIA_API)
TorchStream get_cuda_stream() {
return c10::cuda::getStreamFromExternal(
cudaStream_t(infinicore::context::getStream()), infinicore::context::getDevice().getIndex());
}
......
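Taken together, the adaptor now exposes one stream type, one guard type and one get_cuda_stream() entry point on both backends, so operator code can stay backend-agnostic. A minimal usage sketch, assuming the aliases above (the wrapper function name, header path and at::relu call are illustrative, not taken from this commit):

#include <ATen/ATen.h>
#include "infinicore/adaptor/aten_adaptor.hpp" // assumed header name for the adaptor declarations above

// Illustrative only: run an arbitrary ATen op on InfiniCore's current stream,
// whether the backend is CUDA (NVIDIA) or HIP (Hygon DCU).
static at::Tensor run_on_infinicore_stream(const infinicore::Tensor &x) {
    infinicore::adaptor::TorchStreamGuard guard(infinicore::adaptor::get_cuda_stream());
    at::Tensor t = infinicore::adaptor::to_aten_tensor(x);
    return at::relu(t); // any ATen kernel launched here runs on the guarded stream
}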
#if defined(ENABLE_FLASH_ATTN) && defined(ENABLE_HYGON_API) && !defined(ENABLE_NVIDIA_API)
#include <ATen/ATen.h>
#include <c10/util/Optional.h>
#include <dlfcn.h>
#include <optional>
#include <stdexcept>
#include <vector>
// ---------------------------------------------------------------------------
// Function pointer types for the extern "C" functions exported by the DCU
// flash_attn shared library (built from flash-attention-cutlass-master).
// We resolve these at runtime via dlsym to avoid hard link-time dependency
// on the prebuilt .so (which requires libtorch_python.so).
// ---------------------------------------------------------------------------
using mha_fwd_kvcache_fn_t = std::vector<at::Tensor> (*)(
at::Tensor &q,
const at::Tensor &kcache,
const at::Tensor &vcache,
c10::optional<const at::Tensor> &k_,
c10::optional<const at::Tensor> &v_,
c10::optional<const at::Tensor> &seqlens_k_,
c10::optional<const at::Tensor> &rotary_cos_,
c10::optional<const at::Tensor> &rotary_sin_,
c10::optional<const at::Tensor> &cache_batch_idx_,
c10::optional<const at::Tensor> &leftpad_k_,
c10::optional<at::Tensor> &block_table_,
c10::optional<at::Tensor> &alibi_slopes_,
c10::optional<at::Tensor> &out_,
const float softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
bool is_rotary_interleaved,
int num_splits,
const c10::optional<at::Tensor> &s_aux_);
using mha_varlen_fwd_fn_t = std::vector<at::Tensor> (*)(
at::Tensor &q,
const at::Tensor &k,
const at::Tensor &v,
c10::optional<at::Tensor> &out_,
const at::Tensor &cu_seqlens_q,
const at::Tensor &cu_seqlens_k,
c10::optional<at::Tensor> &seqused_k,
c10::optional<const at::Tensor> &leftpad_k_,
c10::optional<at::Tensor> &block_table_,
c10::optional<at::Tensor> &alibi_slopes_,
int max_seqlen_q,
const int max_seqlen_k,
const float p_dropout,
const float softmax_scale,
const bool zero_tensors,
bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool return_softmax,
c10::optional<at::Tensor> q_descale_,
c10::optional<at::Tensor> k_descale_,
c10::optional<at::Tensor> v_descale_,
c10::optional<at::Generator> gen_,
const c10::optional<at::Tensor> &s_aux_);
static void *resolve_symbol(const char *name) {
void *sym = dlsym(RTLD_DEFAULT, name);
if (sym) {
return sym;
}
throw std::runtime_error(
std::string("flash_attn symbol not found: ") + name +
". Ensure flash_attn_2_cuda is loaded before calling this function "
"(e.g. import torch; import flash_attn_2_cuda).");
}
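resolve_symbol only works if the flash_attn symbols are already visible to RTLD_DEFAULT, i.e. the prebuilt .so has been loaded into the process (normally because Python imported torch and then flash_attn_2_cuda). A hypothetical fallback that loads it explicitly could look like the sketch below; the helper name is an assumption, the path is simply the one from the build script above, and dlopen would still fail unless libtorch_python.so is already loaded:

static void ensure_flash_attn_loaded() {
    // Hypothetical sketch: dlopen with RTLD_GLOBAL so that later
    // dlsym(RTLD_DEFAULT, ...) lookups can see the exported symbols.
    static void *handle = dlopen(
        "/usr/local/lib/python3.10/dist-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so",
        RTLD_NOW | RTLD_GLOBAL);
    if (!handle) {
        throw std::runtime_error(std::string("dlopen(flash_attn_2_cuda) failed: ") + dlerror());
    }
}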
// ---------------------------------------------------------------------------
// Wrappers in the flash:: namespace.
// These match the signatures declared in
// include/infinicore/adaptor/flash_attention_adaptor.hpp
// and bridge the namespace gap between InfiniCore and the DCU library.
// ---------------------------------------------------------------------------
namespace flash {
std::vector<at::Tensor>
mha_fwd_kvcache(at::Tensor &q,
const at::Tensor &kcache,
const at::Tensor &vcache,
std::optional<const at::Tensor> &k_,
std::optional<const at::Tensor> &v_,
std::optional<const at::Tensor> &seqlens_k_,
std::optional<const at::Tensor> &rotary_cos_,
std::optional<const at::Tensor> &rotary_sin_,
std::optional<const at::Tensor> &cache_batch_idx_,
std::optional<const at::Tensor> &leftpad_k_,
std::optional<at::Tensor> &block_table_,
std::optional<at::Tensor> &alibi_slopes_,
std::optional<at::Tensor> &out_,
const float softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
bool is_rotary_interleaved,
int num_splits) {
static auto fn = reinterpret_cast<mha_fwd_kvcache_fn_t>(
resolve_symbol("mha_fwd_kvcache"));
c10::optional<at::Tensor> s_aux = c10::nullopt;
return fn(
q, kcache, vcache,
k_, v_, seqlens_k_,
rotary_cos_, rotary_sin_, cache_batch_idx_, leftpad_k_,
block_table_, alibi_slopes_, out_,
softmax_scale, is_causal,
window_size_left, window_size_right,
softcap, is_rotary_interleaved, num_splits, s_aux);
}
std::vector<at::Tensor>
mha_varlen_fwd(at::Tensor &q,
const at::Tensor &k,
const at::Tensor &v,
std::optional<at::Tensor> &out_,
const at::Tensor &cu_seqlens_q,
const at::Tensor &cu_seqlens_k,
std::optional<at::Tensor> &seqused_k,
std::optional<const at::Tensor> &leftpad_k_,
std::optional<at::Tensor> &block_table_,
std::optional<at::Tensor> &alibi_slopes_,
int max_seqlen_q,
const int max_seqlen_k,
const float p_dropout,
const float softmax_scale,
const bool zero_tensors,
bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool return_softmax,
std::optional<at::Generator> gen_) {
static auto fn = reinterpret_cast<mha_varlen_fwd_fn_t>(
resolve_symbol("mha_varlen_fwd"));
c10::optional<at::Tensor> q_descale = c10::nullopt;
c10::optional<at::Tensor> k_descale = c10::nullopt;
c10::optional<at::Tensor> v_descale = c10::nullopt;
c10::optional<at::Tensor> s_aux = c10::nullopt;
return fn(
q, k, v, out_,
cu_seqlens_q, cu_seqlens_k,
seqused_k, leftpad_k_, block_table_, alibi_slopes_,
max_seqlen_q, max_seqlen_k,
p_dropout, softmax_scale, zero_tensors, is_causal,
window_size_left, window_size_right,
softcap, return_softmax,
q_descale, k_descale, v_descale, gen_, s_aux);
}
} // namespace flash
#endif // ENABLE_FLASH_ATTN && ENABLE_HYGON_API && !ENABLE_NVIDIA_API
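One implicit assumption worth noting: the flash:: wrappers take std::optional parameters yet forward them to function pointers declared with c10::optional references, which is only well-formed when the PyTorch build aliases c10::optional to std::optional (the case in recent PyTorch releases). A small compile-time guard, not part of this commit and assuming <type_traits> is available, could make that explicit:

static_assert(std::is_same_v<c10::optional<at::Tensor>, std::optional<at::Tensor>>,
              "flash_attn adaptor assumes c10::optional is an alias of std::optional");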
......@@ -3,6 +3,7 @@
#include "../utils.hpp"
#include "infinicore/context/context.hpp"
#include <infinirt.h>
#include <spdlog/spdlog.h>
namespace infinicore::graph {
......@@ -32,9 +33,11 @@ DispatchableGraphOperator::~DispatchableGraphOperator() {
* ========================= */
struct Graph::DeviceGraph {
infinirtGraph_t graph;
infinirtGraphExec_t exec;
infinirtGraphNode_t node;
infinirtGraph_t graph = nullptr;
infinirtGraphExec_t exec = nullptr;
infinirtGraphNode_t node = nullptr;
infinirtStream_t capture_stream = nullptr;
Device capture_device;
std::vector<char> log_buffer;
DeviceGraph() {
......@@ -51,7 +54,11 @@ struct Graph::DeviceGraph {
}
void launch() {
INFINICORE_CHECK_ERROR(infinirtGraphLuanch(exec, context::getStream()));
// Ensure we are on the correct device before launching the graph
if (capture_device != context::getDevice()) {
context::setDevice(capture_device);
}
INFINICORE_CHECK_ERROR(infinirtGraphLuanch(exec, capture_stream));
}
};
......@@ -76,29 +83,41 @@ void Graph::instantiate() {
// Reset device graph
device_graph_ = std::make_unique<DeviceGraph>();
// warmup
// Save the current stream and device — all graph operations must use this stream
auto capture_stream = context::getStream();
auto capture_device = context::getDevice();
// warmup: ensure we are on the correct device and stream
context::setDevice(capture_device);
for (size_t iter = 0; iter < 5; ++iter) {
this->run();
}
infinicore::context::syncStream();
// Ensure device is correct before capture (may have been switched during warmup)
context::setDevice(capture_device);
if (infinirtStreamBeginCapture(
context::getStream(),
capture_stream,
INFINIRT_STREAM_CAPTURE_MODE_RELAXED)
!= INFINI_STATUS_SUCCESS) {
return;
}
// Run and record
// Run and record — all ops must use capture_stream
this->run();
if (infinirtStreamEndCapture(
context::getStream(),
capture_stream,
&device_graph_.get()->graph)
!= INFINI_STATUS_SUCCESS) {
return;
}
// Save the capture stream and device for later launch()
device_graph_.get()->capture_stream = capture_stream;
device_graph_.get()->capture_device = capture_device;
if (infinirtGraphInstantiate(
&device_graph_.get()->exec,
device_graph_.get()->graph,
......
......@@ -45,7 +45,7 @@ Embedding::Embedding(size_t num_embeddings,
Tensor Embedding::forward(const Tensor &indices) const {
// TODO: Implement on-device embedding for all devices, then remove the condition and the classic approach
auto device_type = device_.getType();
if (device_type == Device::Type::NVIDIA || device_type == Device::Type::ILUVATAR || device_type == Device::Type::METAX || device_type == Device::Type::MOORE || device_type == Device::Type::ALI) {
if (device_type == Device::Type::NVIDIA || device_type == Device::Type::ILUVATAR || device_type == Device::Type::METAX || device_type == Device::Type::MOORE || device_type == Device::Type::ALI || device_type == Device::Type::HYGON) {
// Use op::embedding which supports device-side input and batch dimension
return op::embedding(indices->contiguous()->to(device_), weight_);
}
......
......@@ -31,7 +31,9 @@ void RMSNorm::forward_inplace(Tensor &x, Tensor &residual) const {
|| device_.getType() == Device::Type::ILUVATAR
|| device_.getType() == Device::Type::METAX
|| device_.getType() == Device::Type::MOORE
|| device_.getType() == Device::Type::ALI) {
|| device_.getType() == Device::Type::ALI
|| device_.getType() == Device::Type::HYGON) {
op::add_rms_norm_inplace(x, residual, weight_, static_cast<float>(eps_));
} else {
op::add_(residual, x, residual);
......
......@@ -33,7 +33,7 @@ void *plan(Tensor out,
void run(void *planned_meta) {
#ifdef ENABLE_FLASH_ATTN
c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream());
infinicore::adaptor::TorchStreamGuard guard(infinicore::adaptor::get_cuda_stream());
auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);
auto out_tensor = infinicore::adaptor::to_aten_tensor(p->out);
......
......@@ -41,18 +41,25 @@ void *plan(Tensor out,
void run(void *planned_meta) {
#ifdef ENABLE_FLASH_ATTN
c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream());
infinicore::adaptor::TorchStreamGuard guard(infinicore::adaptor::get_cuda_stream());
auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);
auto q = infinicore::adaptor::to_aten_tensor(p->q);
auto k = infinicore::adaptor::to_aten_tensor(p->k);
auto v = infinicore::adaptor::to_aten_tensor(p->v);
auto k = infinicore::adaptor::to_aten_tensor(p->k).contiguous();
auto v = infinicore::adaptor::to_aten_tensor(p->v).contiguous();
auto out = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->out));
auto cu_seqlens_q = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_q);
auto cu_seqlens_kv = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_k);
auto block_table = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->block_table));
// Flash-attn requires cu_seqlens and block_table to be on the same device as q/k/v.
auto device = q.device();
if (!cu_seqlens_q.is_cuda()) cu_seqlens_q = cu_seqlens_q.to(device);
if (!cu_seqlens_kv.is_cuda()) cu_seqlens_kv = cu_seqlens_kv.to(device);
if (block_table.has_value() && !block_table->is_cuda()) block_table = block_table->to(device);
std::optional<at::Tensor> seqused_k = std::nullopt;
std::optional<const at::Tensor> leftpad_k = std::nullopt;
auto block_table = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->block_table));
auto max_seqlen_q = p->max_seqlen_q;
auto max_seqlen_k = p->max_seqlen_k;
auto alibi_slopes = p->alibi_slopes ? std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes)) : std::nullopt;
......
......@@ -60,4 +60,283 @@ __device__ void add_rmsnormBlock(
}
}
// dim=4096, block=1024 => 4 elements per thread: full unroll + register-held sums (no 2nd read of residual_out).
template <typename Tcompute, typename Tdata, typename Tweight>
__device__ void add_rmsnormBlock_dim4096_bs1024(
Tdata *__restrict__ y,
Tdata *__restrict__ residual_out,
ptrdiff_t stride_y_batch,
ptrdiff_t stride_y_nhead,
ptrdiff_t stride_residual_out_batch,
ptrdiff_t stride_residual_out_nhead,
const Tdata *__restrict__ a,
ptrdiff_t stride_a_batch,
ptrdiff_t stride_a_nhead,
const Tdata *__restrict__ b,
ptrdiff_t stride_b_batch,
ptrdiff_t stride_b_nhead,
const Tweight *__restrict__ w,
size_t nhead,
float epsilon) {
constexpr unsigned int BS = 1024;
constexpr size_t DIM = 4096;
const size_t batch_idx = blockIdx.x / nhead;
const size_t head_idx = blockIdx.x % nhead;
Tdata *y_ptr = y + batch_idx * stride_y_batch + head_idx * stride_y_nhead;
const Tdata *a_ptr = a + batch_idx * stride_a_batch + head_idx * stride_a_nhead;
const Tdata *b_ptr = b + batch_idx * stride_b_batch + head_idx * stride_b_nhead;
const Tweight *w_ptr = w;
Tdata *residual_out_ptr = residual_out + batch_idx * stride_residual_out_batch + head_idx * stride_residual_out_nhead;
const unsigned int t = threadIdx.x;
Tcompute s0 = Tcompute(a_ptr[t]) + Tcompute(b_ptr[t]);
Tcompute s1 = Tcompute(a_ptr[t + BS]) + Tcompute(b_ptr[t + BS]);
Tcompute s2 = Tcompute(a_ptr[t + 2 * BS]) + Tcompute(b_ptr[t + 2 * BS]);
Tcompute s3 = Tcompute(a_ptr[t + 3 * BS]) + Tcompute(b_ptr[t + 3 * BS]);
residual_out_ptr[t] = Tdata(s0);
residual_out_ptr[t + BS] = Tdata(s1);
residual_out_ptr[t + 2 * BS] = Tdata(s2);
residual_out_ptr[t + 3 * BS] = Tdata(s3);
Tcompute sum_squared = s0 * s0 + s1 * s1 + s2 * s2 + s3 * s3;
using BlockReduce = cub::BlockReduce<Tcompute, BS>;
__shared__ typename BlockReduce::TempStorage temp_storage;
sum_squared = BlockReduce(temp_storage).Sum(sum_squared);
__shared__ Tcompute rms;
if (t == 0) {
rms = Tcompute(rsqrtf(sum_squared / Tcompute(DIM) + epsilon));
}
__syncthreads();
y_ptr[t] = Tdata(s0 * Tcompute(w_ptr[t]) * rms);
y_ptr[t + BS] = Tdata(s1 * Tcompute(w_ptr[t + BS]) * rms);
y_ptr[t + 2 * BS] = Tdata(s2 * Tcompute(w_ptr[t + 2 * BS]) * rms);
y_ptr[t + 3 * BS] = Tdata(s3 * Tcompute(w_ptr[t + 3 * BS]) * rms);
}
// dim=8192, block=1024 => 8 elements per thread: full unroll + register-held sums (no 2nd read of residual_out).
template <typename Tcompute, typename Tdata, typename Tweight>
__device__ void add_rmsnormBlock_dim8192_bs1024(
Tdata *__restrict__ y,
Tdata *__restrict__ residual_out,
ptrdiff_t stride_y_batch,
ptrdiff_t stride_y_nhead,
ptrdiff_t stride_residual_out_batch,
ptrdiff_t stride_residual_out_nhead,
const Tdata *__restrict__ a,
ptrdiff_t stride_a_batch,
ptrdiff_t stride_a_nhead,
const Tdata *__restrict__ b,
ptrdiff_t stride_b_batch,
ptrdiff_t stride_b_nhead,
const Tweight *__restrict__ w,
size_t nhead,
float epsilon) {
constexpr unsigned int BS = 1024;
constexpr size_t DIM = 8192;
const size_t batch_idx = blockIdx.x / nhead;
const size_t head_idx = blockIdx.x % nhead;
Tdata *y_ptr = y + batch_idx * stride_y_batch + head_idx * stride_y_nhead;
const Tdata *a_ptr = a + batch_idx * stride_a_batch + head_idx * stride_a_nhead;
const Tdata *b_ptr = b + batch_idx * stride_b_batch + head_idx * stride_b_nhead;
const Tweight *w_ptr = w;
Tdata *residual_out_ptr = residual_out + batch_idx * stride_residual_out_batch + head_idx * stride_residual_out_nhead;
const unsigned int t = threadIdx.x;
Tcompute s0 = Tcompute(a_ptr[t]) + Tcompute(b_ptr[t]);
Tcompute s1 = Tcompute(a_ptr[t + BS]) + Tcompute(b_ptr[t + BS]);
Tcompute s2 = Tcompute(a_ptr[t + 2 * BS]) + Tcompute(b_ptr[t + 2 * BS]);
Tcompute s3 = Tcompute(a_ptr[t + 3 * BS]) + Tcompute(b_ptr[t + 3 * BS]);
Tcompute s4 = Tcompute(a_ptr[t + 4 * BS]) + Tcompute(b_ptr[t + 4 * BS]);
Tcompute s5 = Tcompute(a_ptr[t + 5 * BS]) + Tcompute(b_ptr[t + 5 * BS]);
Tcompute s6 = Tcompute(a_ptr[t + 6 * BS]) + Tcompute(b_ptr[t + 6 * BS]);
Tcompute s7 = Tcompute(a_ptr[t + 7 * BS]) + Tcompute(b_ptr[t + 7 * BS]);
residual_out_ptr[t] = Tdata(s0);
residual_out_ptr[t + BS] = Tdata(s1);
residual_out_ptr[t + 2 * BS] = Tdata(s2);
residual_out_ptr[t + 3 * BS] = Tdata(s3);
residual_out_ptr[t + 4 * BS] = Tdata(s4);
residual_out_ptr[t + 5 * BS] = Tdata(s5);
residual_out_ptr[t + 6 * BS] = Tdata(s6);
residual_out_ptr[t + 7 * BS] = Tdata(s7);
Tcompute sum_squared =
s0 * s0 + s1 * s1 + s2 * s2 + s3 * s3 + s4 * s4 + s5 * s5 + s6 * s6 + s7 * s7;
using BlockReduce = cub::BlockReduce<Tcompute, BS>;
__shared__ typename BlockReduce::TempStorage temp_storage;
sum_squared = BlockReduce(temp_storage).Sum(sum_squared);
__shared__ Tcompute rms;
if (t == 0) {
rms = Tcompute(rsqrtf(sum_squared / Tcompute(DIM) + epsilon));
}
__syncthreads();
y_ptr[t] = Tdata(s0 * Tcompute(w_ptr[t]) * rms);
y_ptr[t + BS] = Tdata(s1 * Tcompute(w_ptr[t + BS]) * rms);
y_ptr[t + 2 * BS] = Tdata(s2 * Tcompute(w_ptr[t + 2 * BS]) * rms);
y_ptr[t + 3 * BS] = Tdata(s3 * Tcompute(w_ptr[t + 3 * BS]) * rms);
y_ptr[t + 4 * BS] = Tdata(s4 * Tcompute(w_ptr[t + 4 * BS]) * rms);
y_ptr[t + 5 * BS] = Tdata(s5 * Tcompute(w_ptr[t + 5 * BS]) * rms);
y_ptr[t + 6 * BS] = Tdata(s6 * Tcompute(w_ptr[t + 6 * BS]) * rms);
y_ptr[t + 7 * BS] = Tdata(s7 * Tcompute(w_ptr[t + 7 * BS]) * rms);
}
#endif
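For reference, every fused kernel above (the generic add_rmsnormBlock and the dim-specialized fast paths) computes the same add + RMSNorm: with a and b the two inputs, w the weight vector, d the hidden dimension and epsilon the stabilizer,

$$ r_i = a_i + b_i, \qquad \mathrm{residual\_out}_i = r_i, \qquad y_i = \frac{r_i\, w_i}{\sqrt{\tfrac{1}{d}\sum_{j=1}^{d} r_j^2 + \epsilon}} . $$

The dim=4096 / dim=8192 specializations simply fix d and fully unroll to 4 or 8 elements per thread, keeping every r_i in registers so residual_out is never re-read from global memory.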
//////////////////////////////////////////////////////////////////////
// #ifndef __ADD_RMS_NORM_CUDA_KERNEL_H__
// #define __ADD_RMS_NORM_CUDA_KERNEL_H__
// // Remove the cub header dependency
// // #include <cub/block/block_reduce.cuh>
// template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
// __device__ void add_rmsnormBlock(
// Tdata * y, // [Fix 1] drop __restrict__ to support in-place operation
// Tdata * residual_out, // [Fix 1] drop __restrict__ to support in-place operation
// ptrdiff_t stride_y_batch,
// ptrdiff_t stride_y_nhead,
// ptrdiff_t stride_residual_out_batch,
// ptrdiff_t stride_residual_out_nhead,
// const Tdata * a, // [Fix 1] drop __restrict__ to support in-place operation
// ptrdiff_t stride_a_batch,
// ptrdiff_t stride_a_nhead,
// const Tdata * b, // [Fix 1] drop __restrict__ to support in-place operation
// ptrdiff_t stride_b_batch,
// ptrdiff_t stride_b_nhead,
// const Tweight *__restrict__ w, // weights are never modified, so keeping __restrict__ is safe
// size_t nhead,
// size_t dim,
// float epsilon) {
// size_t batch_idx = blockIdx.x / nhead;
// size_t head_idx = blockIdx.x % nhead;
// auto y_ptr = y + batch_idx * stride_y_batch + head_idx * stride_y_nhead;
// auto a_ptr = a + batch_idx * stride_a_batch + head_idx * stride_a_nhead;
// auto b_ptr = b + batch_idx * stride_b_batch + head_idx * stride_b_nhead;
// auto w_ptr = w;
// Tdata *residual_out_ptr = residual_out + batch_idx * stride_residual_out_batch + head_idx * stride_residual_out_nhead;
// Tcompute sum_squared = 0;
// for (size_t i = threadIdx.x; i < dim; i += BLOCK_SIZE) {
// Tcompute sum_val = Tcompute(a_ptr[i]) + Tcompute(b_ptr[i]);
// residual_out_ptr[i] = Tdata(sum_val); // Store add result
// sum_squared += sum_val * sum_val;
// }
// // [Fix 2] Replace cub::BlockReduce with a generic, safe manual shared-memory reduction
// // so that device-specific warp-size differences cannot cause deadlocks
// __shared__ Tcompute shared_sum[BLOCK_SIZE];
// shared_sum[threadIdx.x] = sum_squared;
// __syncthreads();
// #pragma unroll
// for (unsigned int offset = BLOCK_SIZE / 2; offset > 0; offset /= 2) {
// if (threadIdx.x < offset) {
// shared_sum[threadIdx.x] += shared_sum[threadIdx.x + offset];
// }
// __syncthreads();
// }
// sum_squared = shared_sum[0];
// __shared__ Tcompute rms;
// if (threadIdx.x == 0) {
// rms = Tcompute(rsqrtf(sum_squared / Tcompute(dim) + epsilon));
// }
// __syncthreads();
// // Reuse the residual_out values computed above
// for (size_t i = threadIdx.x; i < dim; i += BLOCK_SIZE) {
// Tcompute sum_val = Tcompute(residual_out_ptr[i]);
// y_ptr[i] = Tdata(sum_val * Tcompute(w_ptr[i]) * rms);
// }
// }
// #endif
////////////////////////////////////////////////////////////////////////////
// #ifndef __ADD_RMS_NORM_CUDA_KERNEL_H__
// #define __ADD_RMS_NORM_CUDA_KERNEL_H__
// #include <cub/block/block_reduce.cuh>
// // Assumed maximum number of elements handled per thread.
// // e.g. a 70B model with dim=8192 and BLOCK_SIZE=1024 needs only 8; 16 is more than enough.
// #define MAX_ELEMS_PER_THREAD 16
// template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
// __device__ void add_rmsnormBlock(
// Tdata *__restrict__ y,
// Tdata *__restrict__ residual_out,
// ptrdiff_t stride_y_batch,
// ptrdiff_t stride_y_seq, // renamed: partitioning is normally per seq_len position, not per nhead
// ptrdiff_t stride_residual_out_batch,
// ptrdiff_t stride_residual_out_seq,
// const Tdata *__restrict__ a,
// ptrdiff_t stride_a_batch,
// ptrdiff_t stride_a_seq,
// const Tdata *__restrict__ b,
// ptrdiff_t stride_b_batch,
// ptrdiff_t stride_b_seq,
// const Tweight *__restrict__ w,
// size_t seq_len, // renamed: replaces nhead
// size_t dim,
// float epsilon) {
// // One block processes one token
// size_t batch_idx = blockIdx.x / seq_len;
// size_t seq_idx = blockIdx.x % seq_len;
// auto y_ptr = y + batch_idx * stride_y_batch + seq_idx * stride_y_seq;
// auto a_ptr = a + batch_idx * stride_a_batch + seq_idx * stride_a_seq;
// auto b_ptr = b + batch_idx * stride_b_batch + seq_idx * stride_b_seq;
// Tdata *residual_out_ptr = residual_out + batch_idx * stride_residual_out_batch + seq_idx * stride_residual_out_seq;
// Tcompute sum_squared = 0;
// // Core of the true fusion: cache this thread's add results in a register array
// Tcompute thread_cache[MAX_ELEMS_PER_THREAD];
// int cache_idx = 0;
// for (size_t i = threadIdx.x; i < dim; i += BLOCK_SIZE) {
// Tcompute sum_val = Tcompute(a_ptr[i]) + Tcompute(b_ptr[i]);
// residual_out_ptr[i] = Tdata(sum_val); // still written back to global memory for the subsequent attention
// thread_cache[cache_idx++] = sum_val; // also kept in fast registers
// sum_squared += sum_val * sum_val;
// }
// // Block-wide reduction of the sum of squares
// using BlockReduce = cub::BlockReduce<Tcompute, BLOCK_SIZE>;
// __shared__ typename BlockReduce::TempStorage temp_storage;
// sum_squared = BlockReduce(temp_storage).Sum(sum_squared);
// __shared__ Tcompute rms;
// if (threadIdx.x == 0) {
// rms = Tcompute(rsqrtf(sum_squared / Tcompute(dim) + epsilon));
// }
// __syncthreads();
// // Second stage: read directly from the thread_cache registers, eliminating the costly second read from global memory
// cache_idx = 0;
// for (size_t i = threadIdx.x; i < dim; i += BLOCK_SIZE) {
// // Use __ldg (if supported) to read the shared weights at full bandwidth
// Tcompute weight_val = Tcompute(__ldg(&w[i]));
// y_ptr[i] = Tdata(thread_cache[cache_idx++] * weight_val * rms);
// }
// }
// #endif
\ No newline at end of file
......@@ -8,6 +8,217 @@
#include "../cuda/kernel.cuh"
// DIM=4096, block=1024, BF16: nv_bfloat162 + float regs; pair idx = tid + i*1024 (same reduction order as scalar fast path).
// (A contiguous longlong2 tiling changed the per-thread partial-sum order seen by the CUB reduction and broke bit-level matching with reference runs.)
__device__ void add_rmsnormBlock_dim4096_bs1024_bf162_vec(
__nv_bfloat16 *__restrict__ y,
__nv_bfloat16 *__restrict__ residual_out,
ptrdiff_t stride_y_batch,
ptrdiff_t stride_y_nhead,
ptrdiff_t stride_residual_out_batch,
ptrdiff_t stride_residual_out_nhead,
const __nv_bfloat16 *__restrict__ a,
ptrdiff_t stride_a_batch,
ptrdiff_t stride_a_nhead,
const __nv_bfloat16 *__restrict__ b,
ptrdiff_t stride_b_batch,
ptrdiff_t stride_b_nhead,
const __nv_bfloat16 *__restrict__ w,
size_t nhead,
float epsilon) {
constexpr unsigned int BS = 1024;
constexpr float DIM_F = 4096.0f;
const size_t batch_idx = blockIdx.x / nhead;
const size_t head_idx = blockIdx.x % nhead;
const __nv_bfloat16 *a_base = a + batch_idx * stride_a_batch + head_idx * stride_a_nhead;
const __nv_bfloat16 *b_base = b + batch_idx * stride_b_batch + head_idx * stride_b_nhead;
__nv_bfloat16 *res_base = residual_out + batch_idx * stride_residual_out_batch + head_idx * stride_residual_out_nhead;
__nv_bfloat16 *y_base = y + batch_idx * stride_y_batch + head_idx * stride_y_nhead;
const __nv_bfloat162 *a_ptr2 = reinterpret_cast<const __nv_bfloat162 *>(a_base);
const __nv_bfloat162 *b_ptr2 = reinterpret_cast<const __nv_bfloat162 *>(b_base);
const __nv_bfloat162 *w_ptr2 = reinterpret_cast<const __nv_bfloat162 *>(w);
__nv_bfloat162 *res_ptr2 = reinterpret_cast<__nv_bfloat162 *>(res_base);
__nv_bfloat162 *y_ptr2 = reinterpret_cast<__nv_bfloat162 *>(y_base);
float sum_squared = 0.0f;
float s1_reg[2];
float s2_reg[2];
#pragma unroll
for (int i = 0; i < 2; ++i) {
const int idx = static_cast<int>(threadIdx.x) + i * static_cast<int>(BS);
const __nv_bfloat162 val_a = a_ptr2[idx];
const __nv_bfloat162 val_b = b_ptr2[idx];
const float f_a1 = __low2float(val_a);
const float f_a2 = __high2float(val_a);
const float f_b1 = __low2float(val_b);
const float f_b2 = __high2float(val_b);
const float t1 = f_a1 + f_b1;
const float t2 = f_a2 + f_b2;
s1_reg[i] = t1;
s2_reg[i] = t2;
res_ptr2[idx] = __floats2bfloat162_rn(t1, t2);
sum_squared += t1 * t1 + t2 * t2;
}
using BlockReduce = cub::BlockReduce<float, BS>;
__shared__ typename BlockReduce::TempStorage temp_storage;
sum_squared = BlockReduce(temp_storage).Sum(sum_squared);
__shared__ float rms;
if (threadIdx.x == 0) {
rms = rsqrtf(sum_squared / DIM_F + epsilon);
}
__syncthreads();
#pragma unroll
for (int i = 0; i < 2; ++i) {
const int idx = static_cast<int>(threadIdx.x) + i * static_cast<int>(BS);
const __nv_bfloat162 val_w = w_ptr2[idx];
const float f_w1 = __low2float(val_w);
const float f_w2 = __high2float(val_w);
const float y1 = s1_reg[i] * f_w1 * rms;
const float y2 = s2_reg[i] * f_w2 * rms;
y_ptr2[idx] = __floats2bfloat162_rn(y1, y2);
}
}
INFINIOP_CUDA_KERNEL add_rmsnormKernel_dim4096_bs1024_bf162_vec(
__nv_bfloat16 *__restrict__ y,
__nv_bfloat16 *__restrict__ residual_out,
ptrdiff_t stride_y_batch,
ptrdiff_t stride_y_nhead,
ptrdiff_t stride_residual_out_batch,
ptrdiff_t stride_residual_out_nhead,
const __nv_bfloat16 *__restrict__ a,
ptrdiff_t stride_a_batch,
ptrdiff_t stride_a_nhead,
const __nv_bfloat16 *__restrict__ b,
ptrdiff_t stride_b_batch,
ptrdiff_t stride_b_nhead,
const __nv_bfloat16 *__restrict__ w,
size_t nhead,
float epsilon) {
add_rmsnormBlock_dim4096_bs1024_bf162_vec(
y, residual_out,
stride_y_batch, stride_y_nhead,
stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
w, nhead, epsilon);
}
// DIM=8192, block=1024: 4x nv_bfloat162 per thread; pair idx = tid + i*1024 (same as scalar tiling; avoids longlong2 reorder issues).
__device__ void add_rmsnormBlock_dim8192_bs1024_bf162_vec(
__nv_bfloat16 *__restrict__ y,
__nv_bfloat16 *__restrict__ residual_out,
ptrdiff_t stride_y_batch,
ptrdiff_t stride_y_nhead,
ptrdiff_t stride_residual_out_batch,
ptrdiff_t stride_residual_out_nhead,
const __nv_bfloat16 *__restrict__ a,
ptrdiff_t stride_a_batch,
ptrdiff_t stride_a_nhead,
const __nv_bfloat16 *__restrict__ b,
ptrdiff_t stride_b_batch,
ptrdiff_t stride_b_nhead,
const __nv_bfloat16 *__restrict__ w,
size_t nhead,
float epsilon) {
constexpr unsigned int BS = 1024;
constexpr float DIM_F = 8192.0f;
const size_t batch_idx = blockIdx.x / nhead;
const size_t head_idx = blockIdx.x % nhead;
const __nv_bfloat16 *a_base = a + batch_idx * stride_a_batch + head_idx * stride_a_nhead;
const __nv_bfloat16 *b_base = b + batch_idx * stride_b_batch + head_idx * stride_b_nhead;
__nv_bfloat16 *res_base = residual_out + batch_idx * stride_residual_out_batch + head_idx * stride_residual_out_nhead;
__nv_bfloat16 *y_base = y + batch_idx * stride_y_batch + head_idx * stride_y_nhead;
const __nv_bfloat162 *a_ptr2 = reinterpret_cast<const __nv_bfloat162 *>(a_base);
const __nv_bfloat162 *b_ptr2 = reinterpret_cast<const __nv_bfloat162 *>(b_base);
const __nv_bfloat162 *w_ptr2 = reinterpret_cast<const __nv_bfloat162 *>(w);
__nv_bfloat162 *res_ptr2 = reinterpret_cast<__nv_bfloat162 *>(res_base);
__nv_bfloat162 *y_ptr2 = reinterpret_cast<__nv_bfloat162 *>(y_base);
float sum_squared = 0.0f;
float s1_reg[4];
float s2_reg[4];
#pragma unroll
for (int i = 0; i < 4; ++i) {
const int idx = static_cast<int>(threadIdx.x) + i * static_cast<int>(BS);
const __nv_bfloat162 val_a = a_ptr2[idx];
const __nv_bfloat162 val_b = b_ptr2[idx];
const float f_a1 = __low2float(val_a);
const float f_a2 = __high2float(val_a);
const float f_b1 = __low2float(val_b);
const float f_b2 = __high2float(val_b);
const float t1 = f_a1 + f_b1;
const float t2 = f_a2 + f_b2;
s1_reg[i] = t1;
s2_reg[i] = t2;
res_ptr2[idx] = __floats2bfloat162_rn(t1, t2);
sum_squared += t1 * t1 + t2 * t2;
}
using BlockReduce = cub::BlockReduce<float, BS>;
__shared__ typename BlockReduce::TempStorage temp_storage;
sum_squared = BlockReduce(temp_storage).Sum(sum_squared);
__shared__ float rms;
if (threadIdx.x == 0) {
rms = rsqrtf(sum_squared / DIM_F + epsilon);
}
__syncthreads();
#pragma unroll
for (int i = 0; i < 4; ++i) {
const int idx = static_cast<int>(threadIdx.x) + i * static_cast<int>(BS);
const __nv_bfloat162 val_w = w_ptr2[idx];
const float f_w1 = __low2float(val_w);
const float f_w2 = __high2float(val_w);
const float y1 = s1_reg[i] * f_w1 * rms;
const float y2 = s2_reg[i] * f_w2 * rms;
y_ptr2[idx] = __floats2bfloat162_rn(y1, y2);
}
}
INFINIOP_CUDA_KERNEL add_rmsnormKernel_dim8192_bs1024_bf162_vec(
__nv_bfloat16 *__restrict__ y,
__nv_bfloat16 *__restrict__ residual_out,
ptrdiff_t stride_y_batch,
ptrdiff_t stride_y_nhead,
ptrdiff_t stride_residual_out_batch,
ptrdiff_t stride_residual_out_nhead,
const __nv_bfloat16 *__restrict__ a,
ptrdiff_t stride_a_batch,
ptrdiff_t stride_a_nhead,
const __nv_bfloat16 *__restrict__ b,
ptrdiff_t stride_b_batch,
ptrdiff_t stride_b_nhead,
const __nv_bfloat16 *__restrict__ w,
size_t nhead,
float epsilon) {
add_rmsnormBlock_dim8192_bs1024_bf162_vec(
y, residual_out,
stride_y_batch, stride_y_nhead,
stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
w, nhead, epsilon);
}
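The __nv_bfloat162 variants pack two bf16 lanes per load/store, halving the per-thread element count relative to the scalar fast paths: 4096 / 2 = 2048 pairs over 1024 threads is 2 pairs per thread (the i < 2 unroll), and 8192 / 2 = 4096 pairs over 1024 threads is 4 pairs per thread (the i < 4 unroll). Note that the reinterpret_cast to __nv_bfloat162 implicitly assumes the y / residual_out / a / b / w rows are contiguous and at least 4-byte aligned; the BF16 dispatch shown below does not appear to check this, so it is a precondition of taking the vectorized path.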
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_CUDA_KERNEL add_rmsnormKernel(
Tdata *__restrict__ y,
......@@ -35,6 +246,58 @@ INFINIOP_CUDA_KERNEL add_rmsnormKernel(
w, nhead, dim, epsilon);
}
template <typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_CUDA_KERNEL add_rmsnormKernel_dim4096_bs1024(
Tdata *__restrict__ y,
Tdata *__restrict__ residual_out,
ptrdiff_t stride_y_batch,
ptrdiff_t stride_y_nhead,
ptrdiff_t stride_residual_out_batch,
ptrdiff_t stride_residual_out_nhead,
const Tdata *__restrict__ a,
ptrdiff_t stride_a_batch,
ptrdiff_t stride_a_nhead,
const Tdata *__restrict__ b,
ptrdiff_t stride_b_batch,
ptrdiff_t stride_b_nhead,
const Tweight *__restrict__ w,
size_t nhead,
float epsilon) {
add_rmsnormBlock_dim4096_bs1024<Tcompute, Tdata, Tweight>(
y, residual_out,
stride_y_batch, stride_y_nhead,
stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
w, nhead, epsilon);
}
template <typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_CUDA_KERNEL add_rmsnormKernel_dim8192_bs1024(
Tdata *__restrict__ y,
Tdata *__restrict__ residual_out,
ptrdiff_t stride_y_batch,
ptrdiff_t stride_y_nhead,
ptrdiff_t stride_residual_out_batch,
ptrdiff_t stride_residual_out_nhead,
const Tdata *__restrict__ a,
ptrdiff_t stride_a_batch,
ptrdiff_t stride_a_nhead,
const Tdata *__restrict__ b,
ptrdiff_t stride_b_batch,
ptrdiff_t stride_b_nhead,
const Tweight *__restrict__ w,
size_t nhead,
float epsilon) {
add_rmsnormBlock_dim8192_bs1024<Tcompute, Tdata, Tweight>(
y, residual_out,
stride_y_batch, stride_y_nhead,
stride_residual_out_batch, stride_residual_out_nhead,
a, stride_a_batch, stride_a_nhead,
b, stride_b_batch, stride_b_nhead,
w, nhead, epsilon);
}
namespace op::add_rms_norm::nvidia {
struct Descriptor::Opaque {
......@@ -97,7 +360,115 @@ infiniStatus_t launchKernel(
dim, \
epsilon)
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
#define LAUNCH_KERNEL_DIM4096_BS1024(Tdata, Tweight, Tcompute) \
add_rmsnormKernel_dim4096_bs1024<Tcompute, Tdata, Tweight><<<batch_size * nhead, CUDA_BLOCK_SIZE_1024, 0, cuda_stream>>>( \
reinterpret_cast<Tdata *>(y), \
reinterpret_cast<Tdata *>(residual_out), \
stride_y_batch, \
stride_y_nhead, \
stride_residual_out_batch, \
stride_residual_out_nhead, \
reinterpret_cast<const Tdata *>(a), \
stride_a_batch, \
stride_a_nhead, \
reinterpret_cast<const Tdata *>(b), \
stride_b_batch, \
stride_b_nhead, \
reinterpret_cast<const Tweight *>(w), \
nhead, \
epsilon)
#define LAUNCH_KERNEL_DIM4096_BS1024_BF162_VEC \
add_rmsnormKernel_dim4096_bs1024_bf162_vec<<<batch_size * nhead, CUDA_BLOCK_SIZE_1024, 0, cuda_stream>>>( \
reinterpret_cast<__nv_bfloat16 *>(y), \
reinterpret_cast<__nv_bfloat16 *>(residual_out), \
stride_y_batch, \
stride_y_nhead, \
stride_residual_out_batch, \
stride_residual_out_nhead, \
reinterpret_cast<const __nv_bfloat16 *>(a), \
stride_a_batch, \
stride_a_nhead, \
reinterpret_cast<const __nv_bfloat16 *>(b), \
stride_b_batch, \
stride_b_nhead, \
reinterpret_cast<const __nv_bfloat16 *>(w), \
nhead, \
epsilon)
#define LAUNCH_KERNEL_DIM8192_BS1024(Tdata, Tweight, Tcompute) \
add_rmsnormKernel_dim8192_bs1024<Tcompute, Tdata, Tweight><<<batch_size * nhead, CUDA_BLOCK_SIZE_1024, 0, cuda_stream>>>( \
reinterpret_cast<Tdata *>(y), \
reinterpret_cast<Tdata *>(residual_out), \
stride_y_batch, \
stride_y_nhead, \
stride_residual_out_batch, \
stride_residual_out_nhead, \
reinterpret_cast<const Tdata *>(a), \
stride_a_batch, \
stride_a_nhead, \
reinterpret_cast<const Tdata *>(b), \
stride_b_batch, \
stride_b_nhead, \
reinterpret_cast<const Tweight *>(w), \
nhead, \
epsilon)
#define LAUNCH_KERNEL_DIM8192_BS1024_BF162_VEC \
add_rmsnormKernel_dim8192_bs1024_bf162_vec<<<batch_size * nhead, CUDA_BLOCK_SIZE_1024, 0, cuda_stream>>>( \
reinterpret_cast<__nv_bfloat16 *>(y), \
reinterpret_cast<__nv_bfloat16 *>(residual_out), \
stride_y_batch, \
stride_y_nhead, \
stride_residual_out_batch, \
stride_residual_out_nhead, \
reinterpret_cast<const __nv_bfloat16 *>(a), \
stride_a_batch, \
stride_a_nhead, \
reinterpret_cast<const __nv_bfloat16 *>(b), \
stride_b_batch, \
stride_b_nhead, \
reinterpret_cast<const __nv_bfloat16 *>(w), \
nhead, \
epsilon)
if (dim == 4096 && BLOCK_SIZE == CUDA_BLOCK_SIZE_1024) {
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL_DIM4096_BS1024(half, half, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL_DIM4096_BS1024(half, __nv_bfloat16, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL_DIM4096_BS1024(half, float, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL_DIM4096_BS1024_BF162_VEC;
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL_DIM4096_BS1024(__nv_bfloat16, half, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL_DIM4096_BS1024(__nv_bfloat16, float, float);
} else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL_DIM4096_BS1024(float, float, float);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else if (dim == 8192 && BLOCK_SIZE == CUDA_BLOCK_SIZE_1024) {
if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL_DIM8192_BS1024(half, half, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL_DIM8192_BS1024(half, __nv_bfloat16, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL_DIM8192_BS1024(half, float, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL_DIM8192_BS1024_BF162_VEC;
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL_DIM8192_BS1024(__nv_bfloat16, half, float);
} else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL_DIM8192_BS1024(__nv_bfloat16, float, float);
} else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
LAUNCH_KERNEL_DIM8192_BS1024(float, float, float);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
LAUNCH_KERNEL(half, half, float);
} else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) {
LAUNCH_KERNEL(half, __nv_bfloat16, float);
......@@ -115,6 +486,10 @@ infiniStatus_t launchKernel(
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
#undef LAUNCH_KERNEL_DIM8192_BS1024_BF162_VEC
#undef LAUNCH_KERNEL_DIM8192_BS1024
#undef LAUNCH_KERNEL_DIM4096_BS1024_BF162_VEC
#undef LAUNCH_KERNEL_DIM4096_BS1024
#undef LAUNCH_KERNEL
return INFINI_STATUS_SUCCESS;
......
......@@ -79,7 +79,8 @@ __device__ __forceinline__ float warpReduceMax(float x) {
}
__device__ __forceinline__ unsigned int cvtaToShared(const void *ptr) {
#if defined(ENABLE_ILUVATAR_API)
#if defined(ENABLE_ILUVATAR_API) || defined(ENABLE_HYGON_API)
// Iluvatar and Hygon DCU (HIP): use raw pointer cast instead of CUDA intrinsic.
return static_cast<unsigned int>(reinterpret_cast<uintptr_t>(ptr));
#else
return static_cast<unsigned int>(__cvta_generic_to_shared(ptr));
......@@ -87,7 +88,8 @@ __device__ __forceinline__ unsigned int cvtaToShared(const void *ptr) {
}
__device__ __forceinline__ void cpAsyncCaSharedGlobal16(void *dst_shared, const void *src_global) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
// cp.async is NVIDIA PTX-only; Hygon DCU (HIP) must use plain loads instead.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && !defined(ENABLE_HYGON_API)
const unsigned int dst = cvtaToShared(dst_shared);
asm volatile("cp.async.ca.shared.global [%0], [%1], 16;\n" ::"r"(dst), "l"(src_global));
#else
......@@ -98,14 +100,14 @@ __device__ __forceinline__ void cpAsyncCaSharedGlobal16(void *dst_shared, const
}
__device__ __forceinline__ void cpAsyncCommit() {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && !defined(ENABLE_HYGON_API)
asm volatile("cp.async.commit_group;\n" ::);
#endif
}
template <int N>
__device__ __forceinline__ void cpAsyncWaitGroup() {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && !defined(ENABLE_HYGON_API)
asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
#endif
}
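On Hygon DCU (and on pre-SM80 NVIDIA parts) the cp.async path above is compiled out, and the elided #else branch has to copy the 16 bytes synchronously instead. A minimal sketch of such a fallback, assuming 16-byte-aligned shared/global pointers (illustrative only; the actual fallback body is collapsed in this diff):

__device__ __forceinline__ void copy16SharedGlobalFallback(void *dst_shared, const void *src_global) {
    // One 16-byte vectorized load/store; uint4 is 16 bytes on both CUDA and HIP.
    *reinterpret_cast<uint4 *>(dst_shared) = *reinterpret_cast<const uint4 *>(src_global);
}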
......@@ -113,7 +115,7 @@ __device__ __forceinline__ void cpAsyncWaitGroup() {
// cp.async.wait_group requires a compile-time immediate, so for small fixed
// stage counts we provide a tiny runtime switch.
__device__ __forceinline__ void cpAsyncWaitGroupRt(int n) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && !defined(ENABLE_HYGON_API)
if (n <= 0) {
cpAsyncWaitGroup<0>();
} else if (n == 1) {
......@@ -1143,8 +1145,7 @@ __device__ void flashAttentionDecodeCtaPipelinedKernel(
// Prefetch the very first token.
int buf = 0;
int t_base = 0;
int token_in_block = 0;
(void)0; // t_base, token_in_block removed (unused)
int logical_block = 0;
{
if (tid == 0) {
......
......@@ -2,7 +2,7 @@
#include "../../handle.h"
#include "infiniop/ops/paged_attention.h"
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_HYGON_API)
#include "nvidia/paged_attention_nvidia.cuh"
#endif
#ifdef ENABLE_MOORE_API
......@@ -48,6 +48,9 @@ __INFINI_C infiniStatus_t infiniopCreatePagedAttentionDescriptor(
#endif
#ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_HYGON_API
CREATE(INFINI_DEVICE_HYGON, nvidia)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -78,6 +81,9 @@ __INFINI_C infiniStatus_t infiniopGetPagedAttentionWorkspaceSize(
#endif
#ifdef ENABLE_ILUVATAR_API
GET(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_HYGON_API
GET(INFINI_DEVICE_HYGON, nvidia)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -112,6 +118,9 @@ __INFINI_C infiniStatus_t infiniopPagedAttention(
#endif
#ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_HYGON_API
CALCULATE(INFINI_DEVICE_HYGON, nvidia)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......@@ -141,6 +150,9 @@ __INFINI_C infiniStatus_t infiniopDestroyPagedAttentionDescriptor(
#endif
#ifdef ENABLE_ILUVATAR_API
DESTROY(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_HYGON_API
DESTROY(INFINI_DEVICE_HYGON, nvidia)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......
......@@ -2306,9 +2306,10 @@ __device__ void PagedAttentionPrefillWarpCta8MmaHd128Kernel(
}
__syncthreads();
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && !defined(ENABLE_HYGON_API)
// WMMA: each warp computes scores for 16 keys (one 16-column slice of the K tile) across all 16 rows.
// For kBlockN=64, only the first 4 warps participate in WMMA score computation.
// nvcuda::wmma is NVIDIA-only; HIP/ROCm does not support it.
namespace wmma = nvcuda::wmma;
constexpr int kNSub = kBlockN / 16;
if (warp_id < kNSub) {
......
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_HYGON_API)
#include <cuda_runtime.h>
#include <cstdint>
......
......@@ -2,7 +2,7 @@
#include "../../handle.h"
#include "infiniop/ops/paged_attention_prefill.h"
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_HYGON_API)
#include "nvidia/paged_attention_prefill_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
......@@ -48,6 +48,9 @@ __INFINI_C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
#ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_HYGON_API
CREATE(INFINI_DEVICE_HYGON, nvidia)
#endif
#ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, moore)
#endif
......@@ -78,6 +81,9 @@ __INFINI_C infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
#ifdef ENABLE_ILUVATAR_API
GET(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_HYGON_API
GET(INFINI_DEVICE_HYGON, nvidia)
#endif
#ifdef ENABLE_MOORE_API
GET(INFINI_DEVICE_MOORE, moore)
#endif
......@@ -115,6 +121,9 @@ __INFINI_C infiniStatus_t infiniopPagedAttentionPrefill(
#ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_HYGON_API
CALCULATE(INFINI_DEVICE_HYGON, nvidia)
#endif
#ifdef ENABLE_MOORE_API
CALCULATE(INFINI_DEVICE_MOORE, moore)
#endif
......@@ -144,6 +153,9 @@ __INFINI_C infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor(
#ifdef ENABLE_ILUVATAR_API
DESTROY(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_HYGON_API
DESTROY(INFINI_DEVICE_HYGON, nvidia)
#endif
#ifdef ENABLE_MOORE_API
DESTROY(INFINI_DEVICE_MOORE, moore)
#endif
......
......@@ -94,6 +94,21 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_cache_slot_stride,
v_cache_slot_stride);
} else if (dtype == INFINI_DTYPE_BF16) {
std::cout<< "NUM_THREADS: " << NUM_THREADS << std::endl;
std::cout<< "grid: " << grid.x << ", " << grid.y << ", " << grid.z << std::endl;
std::cout<< "block: " << block.x << ", " << block.y << ", " << block.z << std::endl;
std::cout<< "shared_mem_size: " << shared_mem_size << std::endl;
std::cout<< "slot_mapping: " << slot_mapping << std::endl;
std::cout<< "head_size: " << head_size << std::endl;
std::cout<< "block_size: " << block_size << std::endl;
std::cout<< "k_src_stride: " << k_src_stride << std::endl;
std::cout<< "v_src_stride: " << v_src_stride << std::endl;
std::cout<< "k_cache_block_stride: " << k_cache_block_stride << std::endl;
std::cout<< "v_cache_block_stride: " << v_cache_block_stride << std::endl;
std::cout<< "k_cache_head_stride: " << k_cache_head_stride << std::endl;
std::cout<< "v_cache_head_stride: " << v_cache_head_stride << std::endl;
std::cout<< "k_cache_slot_stride: " << k_cache_slot_stride << std::endl;
std::cout<< "v_cache_slot_stride: " << v_cache_slot_stride << std::endl;
pagedCaching<__nv_bfloat16, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(__nv_bfloat16 *)k_cache,
......
......@@ -2,7 +2,7 @@
#include "../../handle.h"
#include "infiniop/ops/paged_caching.h"
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_HYGON_API)
#include "nvidia/paged_caching_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
......@@ -41,6 +41,9 @@ __INFINI_C infiniStatus_t infiniopCreatePagedCachingDescriptor(
#ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_HYGON_API
CREATE(INFINI_DEVICE_HYGON, nvidia)
#endif
#ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, moore)
#endif
......@@ -71,6 +74,9 @@ __INFINI_C infiniStatus_t infiniopGetPagedCachingWorkspaceSize(
#ifdef ENABLE_ILUVATAR_API
GET(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_HYGON_API
GET(INFINI_DEVICE_HYGON, nvidia)
#endif
#ifdef ENABLE_MOORE_API
GET(INFINI_DEVICE_MOORE, moore)
#endif
......@@ -105,6 +111,9 @@ __INFINI_C infiniStatus_t infiniopPagedCaching(
#ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_HYGON_API
CALCULATE(INFINI_DEVICE_HYGON, nvidia)
#endif
#ifdef ENABLE_MOORE_API
CALCULATE(INFINI_DEVICE_MOORE, moore)
#endif
......@@ -134,6 +143,9 @@ __INFINI_C infiniStatus_t infiniopDestroyPagedCachingDescriptor(
#ifdef ENABLE_ILUVATAR_API
DESTROY(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_HYGON_API
DESTROY(INFINI_DEVICE_HYGON, nvidia)
#endif
#ifdef ENABLE_MOORE_API
DESTROY(INFINI_DEVICE_MOORE, moore)
#endif
......
......@@ -103,4 +103,50 @@ __device__ void ropeThreadPerItemBlock(
}
}
// grid_dim = dim3(info.seqlen, info.batch, 1);
// dim3 block_dim = dim3(info.table_dim, info.nhead, 1);
template <bool IsGPTJ, typename Tindex, typename Tangle>
__device__ void customropeThreadPerItemBlock(
cuda_bfloat16 *y_,
const cuda_bfloat16 *x_,
const Tindex *__restrict__ pos_ids,
const Tangle *__restrict__ sin_table,
const Tangle *__restrict__ cos_table,
size_t table_dim,
size_t pos_stride_batch, // Stride for batch dimension in pos_ids (0 if 1D)
bool pos_has_batch_dim, // Whether pos_ids has batch dimension
bool has_batch_dim, // Whether tensors have batch dimension
ptrdiff_t y_stride_batch,
ptrdiff_t y_stride_seqlen,
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_batch,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead) {
const size_t batch_idx = blockIdx.y;
const size_t seq_idx = blockIdx.x;
const size_t head_idx = threadIdx.y;
const size_t dim_idx = threadIdx.x;
auto y_offset = batch_idx * y_stride_batch + seq_idx * y_stride_seqlen + head_idx * y_stride_nhead;
auto x_offset = batch_idx * x_stride_batch + seq_idx * x_stride_seqlen + head_idx * x_stride_nhead;
size_t pos_offset = batch_idx * pos_stride_batch + seq_idx;
size_t pos_id = size_t(pos_ids[pos_offset]);
auto table_offset = pos_id * table_dim;
Tangle sin__ = sin_table[table_offset + dim_idx];
Tangle cos__ = cos_table[table_offset + dim_idx];
size_t pos0 = dim_idx;
size_t pos1 = dim_idx + table_dim;
Tangle x0 = __bfloat162float(x_[x_offset + pos0]);
Tangle x1 = __bfloat162float(x_[x_offset + pos1]);
Tangle y0 = x0 * cos__ - x1 * sin__;
Tangle y1 = x0 * sin__ + x1 * cos__;
y_[y_offset + pos0] = __float2bfloat16(y0);
y_[y_offset + pos1] = __float2bfloat16(y1);
}
#endif
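For reference, customropeThreadPerItemBlock applies the standard NeoX-style rotary embedding: thread (dim_idx, head_idx) rotates the pair (x_i, x_{i+D}) of one head vector, where D = table_dim and the angle theta is looked up at row pos_id of the sin/cos tables:

$$ y_i = x_i\cos\theta - x_{i+D}\sin\theta, \qquad y_{i+D} = x_i\sin\theta + x_{i+D}\cos\theta . $$

The GPT-J (interleaved) layout would pair adjacent elements (2i, 2i+1) instead; the dispatch below only routes the IsGPTJ = false path to this kernel, so the template parameter is effectively unused here.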
......@@ -33,6 +33,34 @@ INFINIOP_CUDA_KERNEL ropeThreadPerItemKernel(
x_stride_batch, x_stride_seqlen, x_stride_nhead);
}
template <bool IsGPTJ, typename Tindex, typename Tangle>
INFINIOP_CUDA_KERNEL customropeThreadPerItemKernel(
cuda_bfloat16 *y_,
const cuda_bfloat16 *x_,
const Tindex *__restrict__ pos_ids,
const Tangle *__restrict__ sin_table,
const Tangle *__restrict__ cos_table,
size_t table_dim,
size_t pos_stride_batch, // Stride for batch dimension in pos_ids
bool pos_has_batch_dim, // Whether pos_ids has batch dimension
bool has_batch_dim, // Whether tensors have batch dimension
ptrdiff_t y_stride_batch,
ptrdiff_t y_stride_seqlen,
ptrdiff_t y_stride_nhead,
ptrdiff_t x_stride_batch,
ptrdiff_t x_stride_seqlen,
ptrdiff_t x_stride_nhead) {
customropeThreadPerItemBlock<IsGPTJ>(
y_, x_, pos_ids,
sin_table, cos_table,
table_dim,
pos_stride_batch,
pos_has_batch_dim,
has_batch_dim,
y_stride_batch, y_stride_seqlen, y_stride_nhead,
x_stride_batch, x_stride_seqlen, x_stride_nhead);
}
namespace op::rope::nvidia {
struct Descriptor::Opaque {
......@@ -96,9 +124,15 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
grid_dim = dim3(dimx, dimy, dimz);
} else {
// 3D tensors: use 2D grid [seqlen, nhead], batch dimension is 1
grid_dim = dim3(dimx, dimy);
grid_dim = dim3(dimx, dimy, 1);
}
// printf("block_size = %d info.table_dim = %ld has_batch_dim: %d, is_gpt_j: %d pos_has_batch_dim: %d\n",
// block_size, info.table_dim, info.has_batch_dim, is_gpt_j, info.pos_has_batch_dim);
// [batch, seqlen, nhead, dhead, table_len, table_dim, y_stride_batch, y_stride_seqlen, y_stride_nhead, x_stride_batch, x_stride_seqlen,x_stride_nhead]
// printf("[%ld %ld %ld %ld %ld %ld %ld %ld %ld %ld %ld %ld]\n", info.batch,
// info.seqlen, info.nhead, info.dhead, info.table_len, info.table_dim,
// info.y_stride_batch, info.y_stride_seqlen, info.y_stride_nhead,
// info.x_stride_batch, info.x_stride_seqlen, info.x_stride_nhead);
if (is_gpt_j) {
ropeThreadPerItemKernel<true><<<grid_dim, nthreads, 0, stream>>>(
y, x, pos_ids, sin_table, cos_table, info.table_dim,
......@@ -108,13 +142,35 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
info.y_stride_batch, info.y_stride_seqlen, info.y_stride_nhead,
info.x_stride_batch, info.x_stride_seqlen, info.x_stride_nhead);
} else {
ropeThreadPerItemKernel<false><<<grid_dim, nthreads, 0, stream>>>(
y, x, pos_ids, sin_table, cos_table, info.table_dim,
pos_stride_batch,
info.pos_has_batch_dim,
info.has_batch_dim,
info.y_stride_batch, info.y_stride_seqlen, info.y_stride_nhead,
info.x_stride_batch, info.x_stride_seqlen, info.x_stride_nhead);
if ((std::is_same<Tdata, cuda_bfloat16>::value) && (info.table_dim == 64) && (info.nhead < 16) && (info.seqlen < 505)) {
auto bf16_y = reinterpret_cast<cuda_bfloat16*>(y);
auto bf16_x = reinterpret_cast<const cuda_bfloat16*>(x);
grid_dim = dim3(info.seqlen, info.batch, 1);
dim3 block_dim = dim3(info.table_dim, info.nhead, 1);
customropeThreadPerItemKernel<false><<<grid_dim, block_dim, 0, stream>>>(
bf16_y, bf16_x, pos_ids, sin_table, cos_table, info.table_dim,
pos_stride_batch,
info.pos_has_batch_dim,
info.has_batch_dim,
info.y_stride_batch, info.y_stride_seqlen, info.y_stride_nhead,
info.x_stride_batch, info.x_stride_seqlen, info.x_stride_nhead);
// ropeThreadPerItemKernel<false><<<grid_dim, nthreads, 0, stream>>>(
// y, x, pos_ids, sin_table, cos_table, info.table_dim,
// pos_stride_batch,
// info.pos_has_batch_dim,
// info.has_batch_dim,
// info.y_stride_batch, info.y_stride_seqlen, info.y_stride_nhead,
// info.x_stride_batch, info.x_stride_seqlen, info.x_stride_nhead);
} else {
ropeThreadPerItemKernel<false><<<grid_dim, nthreads, 0, stream>>>(
y, x, pos_ids, sin_table, cos_table, info.table_dim,
pos_stride_batch,
info.pos_has_batch_dim,
info.has_batch_dim,
info.y_stride_batch, info.y_stride_seqlen, info.y_stride_nhead,
info.x_stride_batch, info.x_stride_seqlen, info.x_stride_nhead);
}
}
return INFINI_STATUS_SUCCESS;
......
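A quick sanity check on the launch shape chosen for the custom kernel: block_dim = (table_dim, nhead) = (64, nhead) gives 64 × nhead threads per block, and the nhead < 16 guard keeps that at most 64 × 15 = 960, under the usual 1024-thread per-block limit, while grid_dim = (seqlen, batch) assigns one block per token. Each thread therefore rotates exactly one (x0, x1) pair of one head, matching the grid/block comment above customropeThreadPerItemBlock.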