Commit 5de45ee6 authored by yaoht's avatar yaoht
Browse files

接入fa,适配dcu,优化addrmsnorm和rope算子

parent 93191613
Pipeline #3510 failed with stages
in 0 seconds
...@@ -5,9 +5,12 @@ ...@@ -5,9 +5,12 @@
#include <ATen/ATen.h> #include <ATen/ATen.h>
#ifdef ENABLE_NVIDIA_API #if defined(ENABLE_NVIDIA_API)
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#elif defined(ENABLE_HYGON_API)
#include <ATen/hip/HIPContext.h>
#include <c10/hip/HIPGuard.h>
#endif #endif
namespace infinicore::adaptor { namespace infinicore::adaptor {
...@@ -29,7 +32,8 @@ inline at::ScalarType to_at_dtype(DataType dtype) { ...@@ -29,7 +32,8 @@ inline at::ScalarType to_at_dtype(DataType dtype) {
} }
inline at::Device to_at_device(const Device &device) { inline at::Device to_at_device(const Device &device) {
if (device.getType() == Device::Type::NVIDIA) { if (device.getType() == Device::Type::NVIDIA
|| device.getType() == Device::Type::HYGON) {
return at::Device(at::kCUDA, device.getIndex()); return at::Device(at::kCUDA, device.getIndex());
} else if (device.getType() == Device::Type::CPU) { } else if (device.getType() == Device::Type::CPU) {
return at::Device(at::kCPU); return at::Device(at::kCPU);
...@@ -40,8 +44,14 @@ inline at::Device to_at_device(const Device &device) { ...@@ -40,8 +44,14 @@ inline at::Device to_at_device(const Device &device) {
at::Tensor to_aten_tensor(const infinicore::Tensor &t); at::Tensor to_aten_tensor(const infinicore::Tensor &t);
#ifdef ENABLE_NVIDIA_API #if defined(ENABLE_HYGON_API)
c10::cuda::CUDAStream get_cuda_stream(); using TorchStream = c10::hip::HIPStream;
using TorchStreamGuard = c10::hip::HIPStreamGuard;
TorchStream get_cuda_stream();
#elif defined(ENABLE_NVIDIA_API)
using TorchStream = c10::cuda::CUDAStream;
using TorchStreamGuard = c10::cuda::CUDAStreamGuard;
TorchStream get_cuda_stream();
#endif #endif
} // namespace infinicore::adaptor } // namespace infinicore::adaptor
......
#!/bin/bash
# Build & install InfiniCore for Hygon DCU (DTK CUDA-compat layer) with the
# prebuilt flash-attn extension wired in.
#
# Fix: abort on the first failed step — every later step depends on the one
# before it (a failed `xmake f` previously still ran the build/install).
set -euo pipefail

export CUDA_HOME=/opt/dtk/cuda/cuda

# Full clean + reconfigure: enable DCU, CCL, graph capture and ATen adaptor,
# pointing at the flash_attn_2_cuda extension shipped in the Python env.
xmake clean --all
xmake f -c --hygon-dcu=y --ccl=y --graph=y --cuda="$CUDA_HOME" --aten=y --flash-attn-prebuilt=/usr/local/lib/python3.10/dist-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so
xmake build && xmake install
xmake build _infinicore && xmake install _infinicore
pip install -e . --no-build-isolation
\ No newline at end of file
...@@ -32,8 +32,13 @@ at::Tensor to_aten_tensor(const infinicore::Tensor &t) { ...@@ -32,8 +32,13 @@ at::Tensor to_aten_tensor(const infinicore::Tensor &t) {
options); options);
} }
#ifdef ENABLE_NVIDIA_API #if defined(ENABLE_HYGON_API)
c10::cuda::CUDAStream get_cuda_stream() { TorchStream get_cuda_stream() {
return c10::hip::getStreamFromExternal(
hipStream_t(infinicore::context::getStream()), infinicore::context::getDevice().getIndex());
}
#elif defined(ENABLE_NVIDIA_API)
TorchStream get_cuda_stream() {
return c10::cuda::getStreamFromExternal( return c10::cuda::getStreamFromExternal(
cudaStream_t(infinicore::context::getStream()), infinicore::context::getDevice().getIndex()); cudaStream_t(infinicore::context::getStream()), infinicore::context::getDevice().getIndex());
} }
......
#if defined(ENABLE_FLASH_ATTN) && defined(ENABLE_HYGON_API) && !defined(ENABLE_NVIDIA_API)
#include <ATen/ATen.h>
#include <c10/util/Optional.h>
#include <dlfcn.h>
#include <optional>
#include <stdexcept>
#include <vector>
// ---------------------------------------------------------------------------
// Function pointer types for the extern "C" functions exported by the DCU
// flash_attn shared library (built from flash-attention-cutlass-master).
// We resolve these at runtime via dlsym to avoid hard link-time dependency
// on the prebuilt .so (which requires libtorch_python.so).
// ---------------------------------------------------------------------------
// Signature of the exported `mha_fwd_kvcache` entry point (KV-cache decode
// path).  The parameter list MUST match the library build exactly: the call
// goes through a reinterpret_cast'ed function pointer, so any mismatch in
// order, type or count silently corrupts arguments instead of failing to
// compile/link.
// NOTE(review): the declarations below use `c10::optional`, while the
// wrappers further down pass `std::optional` lvalues — this only compiles
// because `c10::optional` aliases `std::optional` in recent libtorch.
// Confirm against the toolchain the prebuilt .so was built with.
using mha_fwd_kvcache_fn_t = std::vector<at::Tensor> (*)(
    at::Tensor &q,
    const at::Tensor &kcache,
    const at::Tensor &vcache,
    c10::optional<const at::Tensor> &k_,
    c10::optional<const at::Tensor> &v_,
    c10::optional<const at::Tensor> &seqlens_k_,
    c10::optional<const at::Tensor> &rotary_cos_,
    c10::optional<const at::Tensor> &rotary_sin_,
    c10::optional<const at::Tensor> &cache_batch_idx_,
    c10::optional<const at::Tensor> &leftpad_k_,
    c10::optional<at::Tensor> &block_table_,
    c10::optional<at::Tensor> &alibi_slopes_,
    c10::optional<at::Tensor> &out_,
    const float softmax_scale,
    bool is_causal,
    int window_size_left,
    int window_size_right,
    const float softcap,
    bool is_rotary_interleaved,
    int num_splits,
    const c10::optional<at::Tensor> &s_aux_); // trailing arg not exposed by InfiniCore; wrappers pass nullopt

// Signature of the exported `mha_varlen_fwd` entry point (varlen/prefill
// path).  Same caveat as above: argument order is load-bearing.
using mha_varlen_fwd_fn_t = std::vector<at::Tensor> (*)(
    at::Tensor &q,
    const at::Tensor &k,
    const at::Tensor &v,
    c10::optional<at::Tensor> &out_,
    const at::Tensor &cu_seqlens_q,
    const at::Tensor &cu_seqlens_k,
    c10::optional<at::Tensor> &seqused_k,
    c10::optional<const at::Tensor> &leftpad_k_,
    c10::optional<at::Tensor> &block_table_,
    c10::optional<at::Tensor> &alibi_slopes_,
    int max_seqlen_q,
    const int max_seqlen_k,
    const float p_dropout,
    const float softmax_scale,
    const bool zero_tensors,
    bool is_causal,
    int window_size_left,
    int window_size_right,
    const float softcap,
    const bool return_softmax,
    c10::optional<at::Tensor> q_descale_, // descale/gen/s_aux args not exposed by InfiniCore; wrappers pass nullopt
    c10::optional<at::Tensor> k_descale_,
    c10::optional<at::Tensor> v_descale_,
    c10::optional<at::Generator> gen_,
    const c10::optional<at::Tensor> &s_aux_);
// Look up `name` among all objects already loaded into the process
// (RTLD_DEFAULT searches the global symbol scope, which includes a
// previously-imported flash_attn_2_cuda extension).  Throws
// std::runtime_error when the symbol cannot be found.
static void *resolve_symbol(const char *name) {
    if (void *addr = dlsym(RTLD_DEFAULT, name)) {
        return addr;
    }
    throw std::runtime_error(
        std::string("flash_attn symbol not found: ") + name +
        ". Ensure flash_attn_2_cuda is loaded before calling this function "
        "(e.g. import torch; import flash_attn_2_cuda).");
}
// ---------------------------------------------------------------------------
// Wrappers in the flash:: namespace.
// These match the signatures declared in
// include/infinicore/adaptor/flash_attention_adaptor.hpp
// and bridge the namespace gap between InfiniCore and the DCU library.
// ---------------------------------------------------------------------------
namespace flash {
// Bridge for the KV-cache decode path: forwards InfiniCore's
// std::optional-based call into the DCU flash_attn library entry point
// resolved at runtime via dlsym (see resolve_symbol above).
// Returns the tensor list produced by the library — NOTE(review): its exact
// contents (out, softmax_lse, ...) depend on the prebuilt .so; confirm there.
// Argument order must mirror mha_fwd_kvcache_fn_t exactly.
std::vector<at::Tensor>
mha_fwd_kvcache(at::Tensor &q,
                const at::Tensor &kcache,
                const at::Tensor &vcache,
                std::optional<const at::Tensor> &k_,
                std::optional<const at::Tensor> &v_,
                std::optional<const at::Tensor> &seqlens_k_,
                std::optional<const at::Tensor> &rotary_cos_,
                std::optional<const at::Tensor> &rotary_sin_,
                std::optional<const at::Tensor> &cache_batch_idx_,
                std::optional<const at::Tensor> &leftpad_k_,
                std::optional<at::Tensor> &block_table_,
                std::optional<at::Tensor> &alibi_slopes_,
                std::optional<at::Tensor> &out_,
                const float softmax_scale,
                bool is_causal,
                int window_size_left,
                int window_size_right,
                const float softcap,
                bool is_rotary_interleaved,
                int num_splits) {
    // Resolved once per process (function-local static); throws on first
    // call if the extension library is not loaded yet.
    static auto fn = reinterpret_cast<mha_fwd_kvcache_fn_t>(
        resolve_symbol("mha_fwd_kvcache"));
    // The library takes an extra trailing s_aux argument that InfiniCore
    // does not expose; always pass an empty optional.
    c10::optional<at::Tensor> s_aux = c10::nullopt;
    return fn(
        q, kcache, vcache,
        k_, v_, seqlens_k_,
        rotary_cos_, rotary_sin_, cache_batch_idx_, leftpad_k_,
        block_table_, alibi_slopes_, out_,
        softmax_scale, is_causal,
        window_size_left, window_size_right,
        softcap, is_rotary_interleaved, num_splits, s_aux);
}
// Bridge for the varlen/prefill path: forwards into the dlsym-resolved
// `mha_varlen_fwd` library entry point.  The library signature carries four
// extra parameters InfiniCore does not expose (q/k/v descale and s_aux);
// they are always passed as empty optionals.
// Returns the tensor list produced by the library — NOTE(review): exact
// contents depend on the prebuilt .so; confirm there.
// Argument order must mirror mha_varlen_fwd_fn_t exactly.
std::vector<at::Tensor>
mha_varlen_fwd(at::Tensor &q,
               const at::Tensor &k,
               const at::Tensor &v,
               std::optional<at::Tensor> &out_,
               const at::Tensor &cu_seqlens_q,
               const at::Tensor &cu_seqlens_k,
               std::optional<at::Tensor> &seqused_k,
               std::optional<const at::Tensor> &leftpad_k_,
               std::optional<at::Tensor> &block_table_,
               std::optional<at::Tensor> &alibi_slopes_,
               int max_seqlen_q,
               const int max_seqlen_k,
               const float p_dropout,
               const float softmax_scale,
               const bool zero_tensors,
               bool is_causal,
               int window_size_left,
               int window_size_right,
               const float softcap,
               const bool return_softmax,
               std::optional<at::Generator> gen_) {
    // Resolved once per process; throws on first call if the extension
    // library is not loaded.
    static auto fn = reinterpret_cast<mha_varlen_fwd_fn_t>(
        resolve_symbol("mha_varlen_fwd"));
    // Placeholders for library-only parameters (fp8 descales, s_aux).
    c10::optional<at::Tensor> q_descale = c10::nullopt;
    c10::optional<at::Tensor> k_descale = c10::nullopt;
    c10::optional<at::Tensor> v_descale = c10::nullopt;
    c10::optional<at::Tensor> s_aux = c10::nullopt;
    return fn(
        q, k, v, out_,
        cu_seqlens_q, cu_seqlens_k,
        seqused_k, leftpad_k_, block_table_, alibi_slopes_,
        max_seqlen_q, max_seqlen_k,
        p_dropout, softmax_scale, zero_tensors, is_causal,
        window_size_left, window_size_right,
        softcap, return_softmax,
        q_descale, k_descale, v_descale, gen_, s_aux);
}
} // namespace flash
#endif // ENABLE_FLASH_ATTN && ENABLE_HYGON_API && !ENABLE_NVIDIA_API
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include "../utils.hpp" #include "../utils.hpp"
#include "infinicore/context/context.hpp" #include "infinicore/context/context.hpp"
#include <infinirt.h> #include <infinirt.h>
#include <spdlog/spdlog.h>
namespace infinicore::graph { namespace infinicore::graph {
...@@ -32,9 +33,11 @@ DispatchableGraphOperator::~DispatchableGraphOperator() { ...@@ -32,9 +33,11 @@ DispatchableGraphOperator::~DispatchableGraphOperator() {
* ========================= */ * ========================= */
struct Graph::DeviceGraph { struct Graph::DeviceGraph {
infinirtGraph_t graph; infinirtGraph_t graph = nullptr;
infinirtGraphExec_t exec; infinirtGraphExec_t exec = nullptr;
infinirtGraphNode_t node; infinirtGraphNode_t node = nullptr;
infinirtStream_t capture_stream = nullptr;
Device capture_device;
std::vector<char> log_buffer; std::vector<char> log_buffer;
DeviceGraph() { DeviceGraph() {
...@@ -51,7 +54,11 @@ struct Graph::DeviceGraph { ...@@ -51,7 +54,11 @@ struct Graph::DeviceGraph {
} }
void launch() { void launch() {
INFINICORE_CHECK_ERROR(infinirtGraphLuanch(exec, context::getStream())); // Ensure we are on the correct device before launching the graph
if (capture_device != context::getDevice()) {
context::setDevice(capture_device);
}
INFINICORE_CHECK_ERROR(infinirtGraphLuanch(exec, capture_stream));
} }
}; };
...@@ -76,29 +83,41 @@ void Graph::instantiate() { ...@@ -76,29 +83,41 @@ void Graph::instantiate() {
// Reset device graph // Reset device graph
device_graph_ = std::make_unique<DeviceGraph>(); device_graph_ = std::make_unique<DeviceGraph>();
// warmup // Save the current stream and device — all graph operations must use this stream
auto capture_stream = context::getStream();
auto capture_device = context::getDevice();
// warmup: ensure we are on the correct device and stream
context::setDevice(capture_device);
for (size_t iter = 0; iter < 5; ++iter) { for (size_t iter = 0; iter < 5; ++iter) {
this->run(); this->run();
} }
infinicore::context::syncStream(); infinicore::context::syncStream();
// Ensure device is correct before capture (may have been switched during warmup)
context::setDevice(capture_device);
if (infinirtStreamBeginCapture( if (infinirtStreamBeginCapture(
context::getStream(), capture_stream,
INFINIRT_STREAM_CAPTURE_MODE_RELAXED) INFINIRT_STREAM_CAPTURE_MODE_RELAXED)
!= INFINI_STATUS_SUCCESS) { != INFINI_STATUS_SUCCESS) {
return; return;
} }
// Run and record // Run and record — all ops must use capture_stream
this->run(); this->run();
if (infinirtStreamEndCapture( if (infinirtStreamEndCapture(
context::getStream(), capture_stream,
&device_graph_.get()->graph) &device_graph_.get()->graph)
!= INFINI_STATUS_SUCCESS) { != INFINI_STATUS_SUCCESS) {
return; return;
} }
// Save the capture stream and device for later launch()
device_graph_.get()->capture_stream = capture_stream;
device_graph_.get()->capture_device = capture_device;
if (infinirtGraphInstantiate( if (infinirtGraphInstantiate(
&device_graph_.get()->exec, &device_graph_.get()->exec,
device_graph_.get()->graph, device_graph_.get()->graph,
......
...@@ -45,7 +45,7 @@ Embedding::Embedding(size_t num_embeddings, ...@@ -45,7 +45,7 @@ Embedding::Embedding(size_t num_embeddings,
Tensor Embedding::forward(const Tensor &indices) const { Tensor Embedding::forward(const Tensor &indices) const {
// TODO: Implement on-device embedding for all devices, then remove the condition and the classic approach // TODO: Implement on-device embedding for all devices, then remove the condition and the classic approach
auto device_type = device_.getType(); auto device_type = device_.getType();
if (device_type == Device::Type::NVIDIA || device_type == Device::Type::ILUVATAR || device_type == Device::Type::METAX || device_type == Device::Type::MOORE || device_type == Device::Type::ALI) { if (device_type == Device::Type::NVIDIA || device_type == Device::Type::ILUVATAR || device_type == Device::Type::METAX || device_type == Device::Type::MOORE || device_type == Device::Type::ALI|| device_type == Device::Type::HYGON) {
// Use op::embedding which supports device-side input and batch dimension // Use op::embedding which supports device-side input and batch dimension
return op::embedding(indices->contiguous()->to(device_), weight_); return op::embedding(indices->contiguous()->to(device_), weight_);
} }
......
...@@ -31,7 +31,9 @@ void RMSNorm::forward_inplace(Tensor &x, Tensor &residual) const { ...@@ -31,7 +31,9 @@ void RMSNorm::forward_inplace(Tensor &x, Tensor &residual) const {
|| device_.getType() == Device::Type::ILUVATAR || device_.getType() == Device::Type::ILUVATAR
|| device_.getType() == Device::Type::METAX || device_.getType() == Device::Type::METAX
|| device_.getType() == Device::Type::MOORE || device_.getType() == Device::Type::MOORE
|| device_.getType() == Device::Type::ALI) { || device_.getType() == Device::Type::ALI
|| device_.getType() == Device::Type::HYGON) {
// ){
op::add_rms_norm_inplace(x, residual, weight_, static_cast<float>(eps_)); op::add_rms_norm_inplace(x, residual, weight_, static_cast<float>(eps_));
} else { } else {
op::add_(residual, x, residual); op::add_(residual, x, residual);
......
...@@ -33,7 +33,7 @@ void *plan(Tensor out, ...@@ -33,7 +33,7 @@ void *plan(Tensor out,
void run(void *planned_meta) { void run(void *planned_meta) {
#ifdef ENABLE_FLASH_ATTN #ifdef ENABLE_FLASH_ATTN
c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream()); infinicore::adaptor::TorchStreamGuard guard(infinicore::adaptor::get_cuda_stream());
auto *p = reinterpret_cast<PlannedMeta *>(planned_meta); auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);
auto out_tensor = infinicore::adaptor::to_aten_tensor(p->out); auto out_tensor = infinicore::adaptor::to_aten_tensor(p->out);
......
...@@ -41,18 +41,25 @@ void *plan(Tensor out, ...@@ -41,18 +41,25 @@ void *plan(Tensor out,
void run(void *planned_meta) { void run(void *planned_meta) {
#ifdef ENABLE_FLASH_ATTN #ifdef ENABLE_FLASH_ATTN
c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream()); infinicore::adaptor::TorchStreamGuard guard(infinicore::adaptor::get_cuda_stream());
auto *p = reinterpret_cast<PlannedMeta *>(planned_meta); auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);
auto q = infinicore::adaptor::to_aten_tensor(p->q); auto q = infinicore::adaptor::to_aten_tensor(p->q);
auto k = infinicore::adaptor::to_aten_tensor(p->k); auto k = infinicore::adaptor::to_aten_tensor(p->k).contiguous();
auto v = infinicore::adaptor::to_aten_tensor(p->v); auto v = infinicore::adaptor::to_aten_tensor(p->v).contiguous();
auto out = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->out)); auto out = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->out));
auto cu_seqlens_q = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_q); auto cu_seqlens_q = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_q);
auto cu_seqlens_kv = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_k); auto cu_seqlens_kv = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_k);
auto block_table = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->block_table));
// Flash-attn requires cu_seqlens and block_table on same device as q/k/v.
auto device = q.device();
if (!cu_seqlens_q.is_cuda()) cu_seqlens_q = cu_seqlens_q.to(device);
if (!cu_seqlens_kv.is_cuda()) cu_seqlens_kv = cu_seqlens_kv.to(device);
if (block_table.has_value() && !block_table->is_cuda()) block_table = block_table->to(device);
std::optional<at::Tensor> seqused_k = std::nullopt; std::optional<at::Tensor> seqused_k = std::nullopt;
std::optional<const at::Tensor> leftpad_k = std::nullopt; std::optional<const at::Tensor> leftpad_k = std::nullopt;
auto block_table = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->block_table));
auto max_seqlen_q = p->max_seqlen_q; auto max_seqlen_q = p->max_seqlen_q;
auto max_seqlen_k = p->max_seqlen_k; auto max_seqlen_k = p->max_seqlen_k;
auto alibi_slopes = p->alibi_slopes ? std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes)) : std::nullopt; auto alibi_slopes = p->alibi_slopes ? std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes)) : std::nullopt;
......
...@@ -60,4 +60,283 @@ __device__ void add_rmsnormBlock( ...@@ -60,4 +60,283 @@ __device__ void add_rmsnormBlock(
} }
} }
// dim=4096, block=1024 => 4 elements per thread: full unroll + register-held sums (no 2nd read of residual_out).
//
// Fused add + RMSNorm specialization.  Computes s = a + b element-wise,
// writes s to residual_out (the residual stream), then writes
// y = s * w * rsqrt(mean(s^2) + epsilon).  One thread block processes one
// (batch, head) row of exactly DIM elements; blockIdx.x enumerates
// (batch, head) pairs.  The innermost (dim) axis is indexed with unit
// stride for a, b, y and residual_out — only the batch/nhead strides are
// parameterized, so the dim axis must be contiguous.
// Must be launched with blockDim.x == 1024.
template <typename Tcompute, typename Tdata, typename Tweight>
__device__ void add_rmsnormBlock_dim4096_bs1024(
    Tdata *__restrict__ y,            // output: normalized values
    Tdata *__restrict__ residual_out, // output: raw a + b
    ptrdiff_t stride_y_batch,
    ptrdiff_t stride_y_nhead,
    ptrdiff_t stride_residual_out_batch,
    ptrdiff_t stride_residual_out_nhead,
    const Tdata *__restrict__ a, // input operand 1
    ptrdiff_t stride_a_batch,
    ptrdiff_t stride_a_nhead,
    const Tdata *__restrict__ b, // input operand 2 (residual in)
    ptrdiff_t stride_b_batch,
    ptrdiff_t stride_b_nhead,
    const Tweight *__restrict__ w, // RMSNorm weight, length DIM
    size_t nhead,
    float epsilon) {
    constexpr unsigned int BS = 1024; // must equal blockDim.x
    constexpr size_t DIM = 4096;      // elements per (batch, head) row
    const size_t batch_idx = blockIdx.x / nhead;
    const size_t head_idx = blockIdx.x % nhead;
    Tdata *y_ptr = y + batch_idx * stride_y_batch + head_idx * stride_y_nhead;
    const Tdata *a_ptr = a + batch_idx * stride_a_batch + head_idx * stride_a_nhead;
    const Tdata *b_ptr = b + batch_idx * stride_b_batch + head_idx * stride_b_nhead;
    const Tweight *w_ptr = w;
    Tdata *residual_out_ptr = residual_out + batch_idx * stride_residual_out_batch + head_idx * stride_residual_out_nhead;
    const unsigned int t = threadIdx.x;
    // Each thread owns 4 strided elements (t, t+BS, t+2BS, t+3BS); the sums
    // are kept in registers so the normalization pass below never re-reads
    // residual_out from global memory.
    Tcompute s0 = Tcompute(a_ptr[t]) + Tcompute(b_ptr[t]);
    Tcompute s1 = Tcompute(a_ptr[t + BS]) + Tcompute(b_ptr[t + BS]);
    Tcompute s2 = Tcompute(a_ptr[t + 2 * BS]) + Tcompute(b_ptr[t + 2 * BS]);
    Tcompute s3 = Tcompute(a_ptr[t + 3 * BS]) + Tcompute(b_ptr[t + 3 * BS]);
    residual_out_ptr[t] = Tdata(s0);
    residual_out_ptr[t + BS] = Tdata(s1);
    residual_out_ptr[t + 2 * BS] = Tdata(s2);
    residual_out_ptr[t + 3 * BS] = Tdata(s3);
    Tcompute sum_squared = s0 * s0 + s1 * s1 + s2 * s2 + s3 * s3;
    // Block-wide sum of squares; cub's result is only valid in thread 0.
    using BlockReduce = cub::BlockReduce<Tcompute, BS>;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    sum_squared = BlockReduce(temp_storage).Sum(sum_squared);
    __shared__ Tcompute rms;
    if (t == 0) {
        // rsqrtf is the single-precision intrinsic, so the inverse RMS is
        // computed in float even when Tcompute is wider.
        rms = Tcompute(rsqrtf(sum_squared / Tcompute(DIM) + epsilon));
    }
    __syncthreads(); // publish rms to all threads before the write-back pass
    y_ptr[t] = Tdata(s0 * Tcompute(w_ptr[t]) * rms);
    y_ptr[t + BS] = Tdata(s1 * Tcompute(w_ptr[t + BS]) * rms);
    y_ptr[t + 2 * BS] = Tdata(s2 * Tcompute(w_ptr[t + 2 * BS]) * rms);
    y_ptr[t + 3 * BS] = Tdata(s3 * Tcompute(w_ptr[t + 3 * BS]) * rms);
}
// dim=8192, block=1024 => 8 elements per thread: full unroll + register-held sums (no 2nd read of residual_out).
//
// Same fused add + RMSNorm as the dim=4096 variant, specialized for
// DIM = 8192 (8 register-held partial sums per thread).  One block handles
// one (batch, head) row; the dim axis is indexed with unit stride, so it
// must be contiguous.  Must be launched with blockDim.x == 1024.
template <typename Tcompute, typename Tdata, typename Tweight>
__device__ void add_rmsnormBlock_dim8192_bs1024(
    Tdata *__restrict__ y,            // output: normalized values
    Tdata *__restrict__ residual_out, // output: raw a + b
    ptrdiff_t stride_y_batch,
    ptrdiff_t stride_y_nhead,
    ptrdiff_t stride_residual_out_batch,
    ptrdiff_t stride_residual_out_nhead,
    const Tdata *__restrict__ a, // input operand 1
    ptrdiff_t stride_a_batch,
    ptrdiff_t stride_a_nhead,
    const Tdata *__restrict__ b, // input operand 2 (residual in)
    ptrdiff_t stride_b_batch,
    ptrdiff_t stride_b_nhead,
    const Tweight *__restrict__ w, // RMSNorm weight, length DIM
    size_t nhead,
    float epsilon) {
    constexpr unsigned int BS = 1024; // must equal blockDim.x
    constexpr size_t DIM = 8192;      // elements per (batch, head) row
    const size_t batch_idx = blockIdx.x / nhead;
    const size_t head_idx = blockIdx.x % nhead;
    Tdata *y_ptr = y + batch_idx * stride_y_batch + head_idx * stride_y_nhead;
    const Tdata *a_ptr = a + batch_idx * stride_a_batch + head_idx * stride_a_nhead;
    const Tdata *b_ptr = b + batch_idx * stride_b_batch + head_idx * stride_b_nhead;
    const Tweight *w_ptr = w;
    Tdata *residual_out_ptr = residual_out + batch_idx * stride_residual_out_batch + head_idx * stride_residual_out_nhead;
    const unsigned int t = threadIdx.x;
    // 8 elements per thread, held in registers to avoid a second global read.
    Tcompute s0 = Tcompute(a_ptr[t]) + Tcompute(b_ptr[t]);
    Tcompute s1 = Tcompute(a_ptr[t + BS]) + Tcompute(b_ptr[t + BS]);
    Tcompute s2 = Tcompute(a_ptr[t + 2 * BS]) + Tcompute(b_ptr[t + 2 * BS]);
    Tcompute s3 = Tcompute(a_ptr[t + 3 * BS]) + Tcompute(b_ptr[t + 3 * BS]);
    Tcompute s4 = Tcompute(a_ptr[t + 4 * BS]) + Tcompute(b_ptr[t + 4 * BS]);
    Tcompute s5 = Tcompute(a_ptr[t + 5 * BS]) + Tcompute(b_ptr[t + 5 * BS]);
    Tcompute s6 = Tcompute(a_ptr[t + 6 * BS]) + Tcompute(b_ptr[t + 6 * BS]);
    Tcompute s7 = Tcompute(a_ptr[t + 7 * BS]) + Tcompute(b_ptr[t + 7 * BS]);
    residual_out_ptr[t] = Tdata(s0);
    residual_out_ptr[t + BS] = Tdata(s1);
    residual_out_ptr[t + 2 * BS] = Tdata(s2);
    residual_out_ptr[t + 3 * BS] = Tdata(s3);
    residual_out_ptr[t + 4 * BS] = Tdata(s4);
    residual_out_ptr[t + 5 * BS] = Tdata(s5);
    residual_out_ptr[t + 6 * BS] = Tdata(s6);
    residual_out_ptr[t + 7 * BS] = Tdata(s7);
    Tcompute sum_squared =
        s0 * s0 + s1 * s1 + s2 * s2 + s3 * s3 + s4 * s4 + s5 * s5 + s6 * s6 + s7 * s7;
    // Block-wide sum of squares; cub's result is only valid in thread 0.
    using BlockReduce = cub::BlockReduce<Tcompute, BS>;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    sum_squared = BlockReduce(temp_storage).Sum(sum_squared);
    __shared__ Tcompute rms;
    if (t == 0) {
        // rsqrtf computes in float regardless of Tcompute.
        rms = Tcompute(rsqrtf(sum_squared / Tcompute(DIM) + epsilon));
    }
    __syncthreads(); // publish rms before the write-back pass
    y_ptr[t] = Tdata(s0 * Tcompute(w_ptr[t]) * rms);
    y_ptr[t + BS] = Tdata(s1 * Tcompute(w_ptr[t + BS]) * rms);
    y_ptr[t + 2 * BS] = Tdata(s2 * Tcompute(w_ptr[t + 2 * BS]) * rms);
    y_ptr[t + 3 * BS] = Tdata(s3 * Tcompute(w_ptr[t + 3 * BS]) * rms);
    y_ptr[t + 4 * BS] = Tdata(s4 * Tcompute(w_ptr[t + 4 * BS]) * rms);
    y_ptr[t + 5 * BS] = Tdata(s5 * Tcompute(w_ptr[t + 5 * BS]) * rms);
    y_ptr[t + 6 * BS] = Tdata(s6 * Tcompute(w_ptr[t + 6 * BS]) * rms);
    y_ptr[t + 7 * BS] = Tdata(s7 * Tcompute(w_ptr[t + 7 * BS]) * rms);
}
#endif #endif
//////////////////////////////////////////////////////////////////////
// #ifndef __ADD_RMS_NORM_CUDA_KERNEL_H__
// #define __ADD_RMS_NORM_CUDA_KERNEL_H__
// // 移除 cub 头文件依赖
// // #include <cub/block/block_reduce.cuh>
// template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
// __device__ void add_rmsnormBlock(
// Tdata * y, // 【修复 1】移除 __restrict__ 以支持 In-place
// Tdata * residual_out, // 【修复 1】移除 __restrict__ 以支持 In-place
// ptrdiff_t stride_y_batch,
// ptrdiff_t stride_y_nhead,
// ptrdiff_t stride_residual_out_batch,
// ptrdiff_t stride_residual_out_nhead,
// const Tdata * a, // 【修复 1】移除 __restrict__ 以支持 In-place
// ptrdiff_t stride_a_batch,
// ptrdiff_t stride_a_nhead,
// const Tdata * b, // 【修复 1】移除 __restrict__ 以支持 In-place
// ptrdiff_t stride_b_batch,
// ptrdiff_t stride_b_nhead,
// const Tweight *__restrict__ w, // 权重不被修改,保留 __restrict__ 是安全的
// size_t nhead,
// size_t dim,
// float epsilon) {
// size_t batch_idx = blockIdx.x / nhead;
// size_t head_idx = blockIdx.x % nhead;
// auto y_ptr = y + batch_idx * stride_y_batch + head_idx * stride_y_nhead;
// auto a_ptr = a + batch_idx * stride_a_batch + head_idx * stride_a_nhead;
// auto b_ptr = b + batch_idx * stride_b_batch + head_idx * stride_b_nhead;
// auto w_ptr = w;
// Tdata *residual_out_ptr = residual_out + batch_idx * stride_residual_out_batch + head_idx * stride_residual_out_nhead;
// Tcompute sum_squared = 0;
// for (size_t i = threadIdx.x; i < dim; i += BLOCK_SIZE) {
// Tcompute sum_val = Tcompute(a_ptr[i]) + Tcompute(b_ptr[i]);
// residual_out_ptr[i] = Tdata(sum_val); // Store add result
// sum_squared += sum_val * sum_val;
// }
// // 【修复 2】使用通用且安全的 Shared Memory 手动规约替换 cub::BlockReduce
// // 这样不会受制于特定设备的 Warp Size 差异导致死锁
// __shared__ Tcompute shared_sum[BLOCK_SIZE];
// shared_sum[threadIdx.x] = sum_squared;
// __syncthreads();
// #pragma unroll
// for (unsigned int offset = BLOCK_SIZE / 2; offset > 0; offset /= 2) {
// if (threadIdx.x < offset) {
// shared_sum[threadIdx.x] += shared_sum[threadIdx.x + offset];
// }
// __syncthreads();
// }
// sum_squared = shared_sum[0];
// __shared__ Tcompute rms;
// if (threadIdx.x == 0) {
// rms = Tcompute(rsqrtf(sum_squared / Tcompute(dim) + epsilon));
// }
// __syncthreads();
// // 重新利用算出的 residual_out
// for (size_t i = threadIdx.x; i < dim; i += BLOCK_SIZE) {
// Tcompute sum_val = Tcompute(residual_out_ptr[i]);
// y_ptr[i] = Tdata(sum_val * Tcompute(w_ptr[i]) * rms);
// }
// }
// #endif
////////////////////////////////////////////////////////////////////////////
// #ifndef __ADD_RMS_NORM_CUDA_KERNEL_H__
// #define __ADD_RMS_NORM_CUDA_KERNEL_H__
// #include <cub/block/block_reduce.cuh>
// // 假设每个线程最多处理的元素个数。
// // 例如 70B dim=8192, BLOCK_SIZE=1024,只需 8 个。设为 16 绝对够用。
// #define MAX_ELEMS_PER_THREAD 16
// template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
// __device__ void add_rmsnormBlock(
// Tdata *__restrict__ y,
// Tdata *__restrict__ residual_out,
// ptrdiff_t stride_y_batch,
// ptrdiff_t stride_y_seq, // 🌟 修正命名:通常是按 seq_len 划分,而不是 nhead
// ptrdiff_t stride_residual_out_batch,
// ptrdiff_t stride_residual_out_seq,
// const Tdata *__restrict__ a,
// ptrdiff_t stride_a_batch,
// ptrdiff_t stride_a_seq,
// const Tdata *__restrict__ b,
// ptrdiff_t stride_b_batch,
// ptrdiff_t stride_b_seq,
// const Tweight *__restrict__ w,
// size_t seq_len, // 🌟 修正命名:取代 nhead
// size_t dim,
// float epsilon) {
// // 🌟 一个 Block 处理一个 Token
// size_t batch_idx = blockIdx.x / seq_len;
// size_t seq_idx = blockIdx.x % seq_len;
// auto y_ptr = y + batch_idx * stride_y_batch + seq_idx * stride_y_seq;
// auto a_ptr = a + batch_idx * stride_a_batch + seq_idx * stride_a_seq;
// auto b_ptr = b + batch_idx * stride_b_batch + seq_idx * stride_b_seq;
// Tdata *residual_out_ptr = residual_out + batch_idx * stride_residual_out_batch + seq_idx * stride_residual_out_seq;
// Tcompute sum_squared = 0;
// // 🌟 真融合核心:用寄存器数组缓存当前线程计算的加法结果!
// Tcompute thread_cache[MAX_ELEMS_PER_THREAD];
// int cache_idx = 0;
// for (size_t i = threadIdx.x; i < dim; i += BLOCK_SIZE) {
// Tcompute sum_val = Tcompute(a_ptr[i]) + Tcompute(b_ptr[i]);
// residual_out_ptr[i] = Tdata(sum_val); // 依然写回全局显存供后续 Attention 使用
// thread_cache[cache_idx++] = sum_val; // 🌟 同时保存在极速寄存器中!
// sum_squared += sum_val * sum_val;
// }
// // Block 内规约求平方和
// using BlockReduce = cub::BlockReduce<Tcompute, BLOCK_SIZE>;
// __shared__ typename BlockReduce::TempStorage temp_storage;
// sum_squared = BlockReduce(temp_storage).Sum(sum_squared);
// __shared__ Tcompute rms;
// if (threadIdx.x == 0) {
// rms = Tcompute(rsqrtf(sum_squared / Tcompute(dim) + epsilon));
// }
// __syncthreads();
// // 🌟 第二阶段:直接从寄存器 `thread_cache` 读取,彻底干掉那次致命的显存读取!
// cache_idx = 0;
// for (size_t i = threadIdx.x; i < dim; i += BLOCK_SIZE) {
// // 使用 __ldg (如果框架支持) 读取公共权重,速度拉满
// Tcompute weight_val = Tcompute(__ldg(&w[i]));
// y_ptr[i] = Tdata(thread_cache[cache_idx++] * weight_val * rms);
// }
// }
// #endif
\ No newline at end of file
...@@ -79,7 +79,8 @@ __device__ __forceinline__ float warpReduceMax(float x) { ...@@ -79,7 +79,8 @@ __device__ __forceinline__ float warpReduceMax(float x) {
} }
__device__ __forceinline__ unsigned int cvtaToShared(const void *ptr) { __device__ __forceinline__ unsigned int cvtaToShared(const void *ptr) {
#if defined(ENABLE_ILUVATAR_API) #if defined(ENABLE_ILUVATAR_API) || defined(ENABLE_HYGON_API)
// Iluvatar and Hygon DCU (HIP): use raw pointer cast instead of CUDA intrinsic.
return static_cast<unsigned int>(reinterpret_cast<uintptr_t>(ptr)); return static_cast<unsigned int>(reinterpret_cast<uintptr_t>(ptr));
#else #else
return static_cast<unsigned int>(__cvta_generic_to_shared(ptr)); return static_cast<unsigned int>(__cvta_generic_to_shared(ptr));
...@@ -87,7 +88,8 @@ __device__ __forceinline__ unsigned int cvtaToShared(const void *ptr) { ...@@ -87,7 +88,8 @@ __device__ __forceinline__ unsigned int cvtaToShared(const void *ptr) {
} }
__device__ __forceinline__ void cpAsyncCaSharedGlobal16(void *dst_shared, const void *src_global) { __device__ __forceinline__ void cpAsyncCaSharedGlobal16(void *dst_shared, const void *src_global) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) // cp.async is NVIDIA PTX-only; Hygon DCU (HIP) must use plain loads instead.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && !defined(ENABLE_HYGON_API)
const unsigned int dst = cvtaToShared(dst_shared); const unsigned int dst = cvtaToShared(dst_shared);
asm volatile("cp.async.ca.shared.global [%0], [%1], 16;\n" ::"r"(dst), "l"(src_global)); asm volatile("cp.async.ca.shared.global [%0], [%1], 16;\n" ::"r"(dst), "l"(src_global));
#else #else
...@@ -98,14 +100,14 @@ __device__ __forceinline__ void cpAsyncCaSharedGlobal16(void *dst_shared, const ...@@ -98,14 +100,14 @@ __device__ __forceinline__ void cpAsyncCaSharedGlobal16(void *dst_shared, const
} }
__device__ __forceinline__ void cpAsyncCommit() { __device__ __forceinline__ void cpAsyncCommit() {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && !defined(ENABLE_HYGON_API)
asm volatile("cp.async.commit_group;\n" ::); asm volatile("cp.async.commit_group;\n" ::);
#endif #endif
} }
template <int N> template <int N>
__device__ __forceinline__ void cpAsyncWaitGroup() { __device__ __forceinline__ void cpAsyncWaitGroup() {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && !defined(ENABLE_HYGON_API)
asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
#endif #endif
} }
...@@ -113,7 +115,7 @@ __device__ __forceinline__ void cpAsyncWaitGroup() { ...@@ -113,7 +115,7 @@ __device__ __forceinline__ void cpAsyncWaitGroup() {
// cp.async.wait_group requires a compile-time immediate, so for small fixed // cp.async.wait_group requires a compile-time immediate, so for small fixed
// stage counts we provide a tiny runtime switch. // stage counts we provide a tiny runtime switch.
__device__ __forceinline__ void cpAsyncWaitGroupRt(int n) { __device__ __forceinline__ void cpAsyncWaitGroupRt(int n) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && !defined(ENABLE_HYGON_API)
if (n <= 0) { if (n <= 0) {
cpAsyncWaitGroup<0>(); cpAsyncWaitGroup<0>();
} else if (n == 1) { } else if (n == 1) {
...@@ -1143,8 +1145,7 @@ __device__ void flashAttentionDecodeCtaPipelinedKernel( ...@@ -1143,8 +1145,7 @@ __device__ void flashAttentionDecodeCtaPipelinedKernel(
// Prefetch the very first token. // Prefetch the very first token.
int buf = 0; int buf = 0;
int t_base = 0; (void)0; // t_base, token_in_block removed (unused)
int token_in_block = 0;
int logical_block = 0; int logical_block = 0;
{ {
if (tid == 0) { if (tid == 0) {
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "../../handle.h" #include "../../handle.h"
#include "infiniop/ops/paged_attention.h" #include "infiniop/ops/paged_attention.h"
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API) #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_HYGON_API)
#include "nvidia/paged_attention_nvidia.cuh" #include "nvidia/paged_attention_nvidia.cuh"
#endif #endif
#ifdef ENABLE_MOORE_API #ifdef ENABLE_MOORE_API
...@@ -48,6 +48,9 @@ __INFINI_C infiniStatus_t infiniopCreatePagedAttentionDescriptor( ...@@ -48,6 +48,9 @@ __INFINI_C infiniStatus_t infiniopCreatePagedAttentionDescriptor(
#endif #endif
#ifdef ENABLE_ILUVATAR_API #ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia) CREATE(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_HYGON_API
CREATE(INFINI_DEVICE_HYGON, nvidia)
#endif #endif
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...@@ -78,6 +81,9 @@ __INFINI_C infiniStatus_t infiniopGetPagedAttentionWorkspaceSize( ...@@ -78,6 +81,9 @@ __INFINI_C infiniStatus_t infiniopGetPagedAttentionWorkspaceSize(
#endif #endif
#ifdef ENABLE_ILUVATAR_API #ifdef ENABLE_ILUVATAR_API
GET(INFINI_DEVICE_ILUVATAR, nvidia) GET(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_HYGON_API
GET(INFINI_DEVICE_HYGON, nvidia)
#endif #endif
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...@@ -112,6 +118,9 @@ __INFINI_C infiniStatus_t infiniopPagedAttention( ...@@ -112,6 +118,9 @@ __INFINI_C infiniStatus_t infiniopPagedAttention(
#endif #endif
#ifdef ENABLE_ILUVATAR_API #ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia) CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_HYGON_API
CALCULATE(INFINI_DEVICE_HYGON, nvidia)
#endif #endif
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...@@ -141,6 +150,9 @@ __INFINI_C infiniStatus_t infiniopDestroyPagedAttentionDescriptor( ...@@ -141,6 +150,9 @@ __INFINI_C infiniStatus_t infiniopDestroyPagedAttentionDescriptor(
#endif #endif
#ifdef ENABLE_ILUVATAR_API #ifdef ENABLE_ILUVATAR_API
DESTROY(INFINI_DEVICE_ILUVATAR, nvidia) DESTROY(INFINI_DEVICE_ILUVATAR, nvidia)
#endif
#ifdef ENABLE_HYGON_API
DESTROY(INFINI_DEVICE_HYGON, nvidia)
#endif #endif
default: default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
......
...@@ -2306,9 +2306,10 @@ __device__ void PagedAttentionPrefillWarpCta8MmaHd128Kernel( ...@@ -2306,9 +2306,10 @@ __device__ void PagedAttentionPrefillWarpCta8MmaHd128Kernel(
} }
__syncthreads(); __syncthreads();
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && !defined(ENABLE_HYGON_API)
// WMMA: each warp computes scores for 16 keys (one 16-column slice of the K tile) across all 16 rows. // WMMA: each warp computes scores for 16 keys (one 16-column slice of the K tile) across all 16 rows.
// For kBlockN=64, only the first 4 warps participate in WMMA score computation. // For kBlockN=64, only the first 4 warps participate in WMMA score computation.
// nvcuda::wmma is NVIDIA-only; HIP/ROCm does not support it.
namespace wmma = nvcuda::wmma; namespace wmma = nvcuda::wmma;
constexpr int kNSub = kBlockN / 16; constexpr int kNSub = kBlockN / 16;
if (warp_id < kNSub) { if (warp_id < kNSub) {
......
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API) #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_HYGON_API)
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cstdint> #include <cstdint>
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "../../handle.h" #include "../../handle.h"
#include "infiniop/ops/paged_attention_prefill.h" #include "infiniop/ops/paged_attention_prefill.h"
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API) #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_HYGON_API)
#include "nvidia/paged_attention_prefill_nvidia.cuh" #include "nvidia/paged_attention_prefill_nvidia.cuh"
#endif #endif
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
...@@ -48,6 +48,9 @@ __INFINI_C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor( ...@@ -48,6 +48,9 @@ __INFINI_C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
#ifdef ENABLE_ILUVATAR_API #ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia) CREATE(INFINI_DEVICE_ILUVATAR, nvidia)
#endif #endif
#ifdef ENABLE_HYGON_API
CREATE(INFINI_DEVICE_HYGON, nvidia)
#endif
#ifdef ENABLE_MOORE_API #ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, moore) CREATE(INFINI_DEVICE_MOORE, moore)
#endif #endif
...@@ -78,6 +81,9 @@ __INFINI_C infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize( ...@@ -78,6 +81,9 @@ __INFINI_C infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
#ifdef ENABLE_ILUVATAR_API #ifdef ENABLE_ILUVATAR_API
GET(INFINI_DEVICE_ILUVATAR, nvidia) GET(INFINI_DEVICE_ILUVATAR, nvidia)
#endif #endif
#ifdef ENABLE_HYGON_API
GET(INFINI_DEVICE_HYGON, nvidia)
#endif
#ifdef ENABLE_MOORE_API #ifdef ENABLE_MOORE_API
GET(INFINI_DEVICE_MOORE, moore) GET(INFINI_DEVICE_MOORE, moore)
#endif #endif
...@@ -115,6 +121,9 @@ __INFINI_C infiniStatus_t infiniopPagedAttentionPrefill( ...@@ -115,6 +121,9 @@ __INFINI_C infiniStatus_t infiniopPagedAttentionPrefill(
#ifdef ENABLE_ILUVATAR_API #ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia) CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia)
#endif #endif
#ifdef ENABLE_HYGON_API
CALCULATE(INFINI_DEVICE_HYGON, nvidia)
#endif
#ifdef ENABLE_MOORE_API #ifdef ENABLE_MOORE_API
CALCULATE(INFINI_DEVICE_MOORE, moore) CALCULATE(INFINI_DEVICE_MOORE, moore)
#endif #endif
...@@ -144,6 +153,9 @@ __INFINI_C infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor( ...@@ -144,6 +153,9 @@ __INFINI_C infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor(
#ifdef ENABLE_ILUVATAR_API #ifdef ENABLE_ILUVATAR_API
DESTROY(INFINI_DEVICE_ILUVATAR, nvidia) DESTROY(INFINI_DEVICE_ILUVATAR, nvidia)
#endif #endif
#ifdef ENABLE_HYGON_API
DESTROY(INFINI_DEVICE_HYGON, nvidia)
#endif
#ifdef ENABLE_MOORE_API #ifdef ENABLE_MOORE_API
DESTROY(INFINI_DEVICE_MOORE, moore) DESTROY(INFINI_DEVICE_MOORE, moore)
#endif #endif
......
...@@ -94,6 +94,21 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info, ...@@ -94,6 +94,21 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
k_cache_slot_stride, k_cache_slot_stride,
v_cache_slot_stride); v_cache_slot_stride);
} else if (dtype == INFINI_DTYPE_BF16) { } else if (dtype == INFINI_DTYPE_BF16) {
std::cout<< "NUM_THREADS: " << NUM_THREADS << std::endl;
std::cout<< "grid: " << grid.x << ", " << grid.y << ", " << grid.z << std::endl;
std::cout<< "block: " << block.x << ", " << block.y << ", " << block.z << std::endl;
std::cout<< "shared_mem_size: " << shared_mem_size << std::endl;
std::cout<< "slot_mapping: " << slot_mapping << std::endl;
std::cout<< "head_size: " << head_size << std::endl;
std::cout<< "block_size: " << block_size << std::endl;
std::cout<< "k_src_stride: " << k_src_stride << std::endl;
std::cout<< "v_src_stride: " << v_src_stride << std::endl;
std::cout<< "k_cache_block_stride: " << k_cache_block_stride << std::endl;
std::cout<< "v_cache_block_stride: " << v_cache_block_stride << std::endl;
std::cout<< "k_cache_head_stride: " << k_cache_head_stride << std::endl;
std::cout<< "v_cache_head_stride: " << v_cache_head_stride << std::endl;
std::cout<< "k_cache_slot_stride: " << k_cache_slot_stride << std::endl;
std::cout<< "v_cache_slot_stride: " << v_cache_slot_stride << std::endl;
pagedCaching<__nv_bfloat16, NUM_THREADS> pagedCaching<__nv_bfloat16, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>( <<<grid, block, shared_mem_size, stream>>>(
(__nv_bfloat16 *)k_cache, (__nv_bfloat16 *)k_cache,
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
#include "../../handle.h" #include "../../handle.h"
#include "infiniop/ops/paged_caching.h" #include "infiniop/ops/paged_caching.h"
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API) #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_HYGON_API)
#include "nvidia/paged_caching_nvidia.cuh" #include "nvidia/paged_caching_nvidia.cuh"
#endif #endif
#ifdef ENABLE_METAX_API #ifdef ENABLE_METAX_API
...@@ -41,6 +41,9 @@ __INFINI_C infiniStatus_t infiniopCreatePagedCachingDescriptor( ...@@ -41,6 +41,9 @@ __INFINI_C infiniStatus_t infiniopCreatePagedCachingDescriptor(
#ifdef ENABLE_ILUVATAR_API #ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, nvidia) CREATE(INFINI_DEVICE_ILUVATAR, nvidia)
#endif #endif
#ifdef ENABLE_HYGON_API
CREATE(INFINI_DEVICE_HYGON, nvidia)
#endif
#ifdef ENABLE_MOORE_API #ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, moore) CREATE(INFINI_DEVICE_MOORE, moore)
#endif #endif
...@@ -71,6 +74,9 @@ __INFINI_C infiniStatus_t infiniopGetPagedCachingWorkspaceSize( ...@@ -71,6 +74,9 @@ __INFINI_C infiniStatus_t infiniopGetPagedCachingWorkspaceSize(
#ifdef ENABLE_ILUVATAR_API #ifdef ENABLE_ILUVATAR_API
GET(INFINI_DEVICE_ILUVATAR, nvidia) GET(INFINI_DEVICE_ILUVATAR, nvidia)
#endif #endif
#ifdef ENABLE_HYGON_API
GET(INFINI_DEVICE_HYGON, nvidia)
#endif
#ifdef ENABLE_MOORE_API #ifdef ENABLE_MOORE_API
GET(INFINI_DEVICE_MOORE, moore) GET(INFINI_DEVICE_MOORE, moore)
#endif #endif
...@@ -105,6 +111,9 @@ __INFINI_C infiniStatus_t infiniopPagedCaching( ...@@ -105,6 +111,9 @@ __INFINI_C infiniStatus_t infiniopPagedCaching(
#ifdef ENABLE_ILUVATAR_API #ifdef ENABLE_ILUVATAR_API
CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia) CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia)
#endif #endif
#ifdef ENABLE_HYGON_API
CALCULATE(INFINI_DEVICE_HYGON, nvidia)
#endif
#ifdef ENABLE_MOORE_API #ifdef ENABLE_MOORE_API
CALCULATE(INFINI_DEVICE_MOORE, moore) CALCULATE(INFINI_DEVICE_MOORE, moore)
#endif #endif
...@@ -134,6 +143,9 @@ __INFINI_C infiniStatus_t infiniopDestroyPagedCachingDescriptor( ...@@ -134,6 +143,9 @@ __INFINI_C infiniStatus_t infiniopDestroyPagedCachingDescriptor(
#ifdef ENABLE_ILUVATAR_API #ifdef ENABLE_ILUVATAR_API
DESTROY(INFINI_DEVICE_ILUVATAR, nvidia) DESTROY(INFINI_DEVICE_ILUVATAR, nvidia)
#endif #endif
#ifdef ENABLE_HYGON_API
DESTROY(INFINI_DEVICE_HYGON, nvidia)
#endif
#ifdef ENABLE_MOORE_API #ifdef ENABLE_MOORE_API
DESTROY(INFINI_DEVICE_MOORE, moore) DESTROY(INFINI_DEVICE_MOORE, moore)
#endif #endif
......
...@@ -103,4 +103,50 @@ __device__ void ropeThreadPerItemBlock( ...@@ -103,4 +103,50 @@ __device__ void ropeThreadPerItemBlock(
} }
} }
// grid_dim = dim3(info.seqlen, info.batch, 1);
// dim3 block_dim = dim3(info.table_dim, info.nhead, 1);
// Fast-path RoPE rotation for bf16 tensors, split-half (GPT-NeoX-style) pairing.
// Launch shape (see caller):
//   grid_dim  = dim3(info.seqlen, info.batch, 1)
//   block_dim = dim3(info.table_dim, info.nhead, 1)
// Each thread rotates exactly one (x0, x1) pair of one head: element dim_idx
// pairs with element dim_idx + table_dim. Inputs are read as bf16, rotated in
// Tangle precision, and written back as bf16.
// NOTE(review): only the non-GPT-J layout is implemented; the static_assert
// below rejects any future IsGPTJ=true instantiation instead of silently
// computing the wrong rotation.
template <bool IsGPTJ, typename Tindex, typename Tangle>
__device__ void customropeThreadPerItemBlock(
    cuda_bfloat16 *y_,
    const cuda_bfloat16 *x_,
    const Tindex *__restrict__ pos_ids,
    const Tangle *__restrict__ sin_table,
    const Tangle *__restrict__ cos_table,
    size_t table_dim,
    size_t pos_stride_batch, // Stride for batch dimension in pos_ids (0 if 1D)
    bool pos_has_batch_dim,  // Unused: pos_stride_batch == 0 already encodes a 1D pos_ids
    bool has_batch_dim,      // Unused: batch strides are 0 when tensors are 3D
    ptrdiff_t y_stride_batch,
    ptrdiff_t y_stride_seqlen,
    ptrdiff_t y_stride_nhead,
    ptrdiff_t x_stride_batch,
    ptrdiff_t x_stride_seqlen,
    ptrdiff_t x_stride_nhead) {
    static_assert(!IsGPTJ, "customropeThreadPerItemBlock implements only the split-half (non-GPT-J) layout");
    // Parameters kept for signature parity with ropeThreadPerItemBlock;
    // silence -Wunused-parameter.
    (void)pos_has_batch_dim;
    (void)has_batch_dim;

    const size_t batch_idx = blockIdx.y;
    const size_t seq_idx = blockIdx.x;
    const size_t head_idx = threadIdx.y;
    const size_t dim_idx = threadIdx.x;

    const auto y_offset = batch_idx * y_stride_batch + seq_idx * y_stride_seqlen + head_idx * y_stride_nhead;
    const auto x_offset = batch_idx * x_stride_batch + seq_idx * x_stride_seqlen + head_idx * x_stride_nhead;

    // Position id for this (batch, seq) element; pos_stride_batch is 0 when
    // pos_ids has no batch dimension, collapsing the lookup to pos_ids[seq_idx].
    const size_t pos_offset = batch_idx * pos_stride_batch + seq_idx;
    const size_t pos_id = size_t(pos_ids[pos_offset]);
    const auto table_offset = pos_id * table_dim;
    const Tangle sin_v = sin_table[table_offset + dim_idx];
    const Tangle cos_v = cos_table[table_offset + dim_idx];

    // Split-half pairing: element d rotates with element d + table_dim.
    const size_t pos0 = dim_idx;
    const size_t pos1 = dim_idx + table_dim;
    const Tangle x0 = __bfloat162float(x_[x_offset + pos0]);
    const Tangle x1 = __bfloat162float(x_[x_offset + pos1]);
    y_[y_offset + pos0] = __float2bfloat16(x0 * cos_v - x1 * sin_v);
    y_[y_offset + pos1] = __float2bfloat16(x0 * sin_v + x1 * cos_v);
}
#endif #endif
...@@ -33,6 +33,34 @@ INFINIOP_CUDA_KERNEL ropeThreadPerItemKernel( ...@@ -33,6 +33,34 @@ INFINIOP_CUDA_KERNEL ropeThreadPerItemKernel(
x_stride_batch, x_stride_seqlen, x_stride_nhead); x_stride_batch, x_stride_seqlen, x_stride_nhead);
} }
// Thin global-kernel wrapper around customropeThreadPerItemBlock: forwards
// every argument unchanged so the device-side block function holds all the
// RoPE logic. Launched by calculateRoPE with
//   grid  = (seqlen, batch, 1), block = (table_dim, nhead, 1).
// Signature mirrors ropeThreadPerItemKernel for drop-in dispatch.
template <bool IsGPTJ, typename Tindex, typename Tangle>
INFINIOP_CUDA_KERNEL customropeThreadPerItemKernel(
    cuda_bfloat16 *y_,
    const cuda_bfloat16 *x_,
    const Tindex *__restrict__ pos_ids,
    const Tangle *__restrict__ sin_table,
    const Tangle *__restrict__ cos_table,
    size_t table_dim,
    size_t pos_stride_batch, // Stride for batch dimension in pos_ids
    bool pos_has_batch_dim,  // Whether pos_ids has batch dimension
    bool has_batch_dim,      // Whether tensors have batch dimension
    ptrdiff_t y_stride_batch,
    ptrdiff_t y_stride_seqlen,
    ptrdiff_t y_stride_nhead,
    ptrdiff_t x_stride_batch,
    ptrdiff_t x_stride_seqlen,
    ptrdiff_t x_stride_nhead) {
    customropeThreadPerItemBlock<IsGPTJ>(
        y_, x_, pos_ids,
        sin_table, cos_table,
        table_dim,
        pos_stride_batch,
        pos_has_batch_dim,
        has_batch_dim,
        y_stride_batch, y_stride_seqlen, y_stride_nhead,
        x_stride_batch, x_stride_seqlen, x_stride_nhead);
}
namespace op::rope::nvidia { namespace op::rope::nvidia {
struct Descriptor::Opaque { struct Descriptor::Opaque {
...@@ -96,9 +124,15 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info, ...@@ -96,9 +124,15 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
grid_dim = dim3(dimx, dimy, dimz); grid_dim = dim3(dimx, dimy, dimz);
} else { } else {
// 3D tensors: use 2D grid [seqlen, nhead], batch dimension is 1 // 3D tensors: use 2D grid [seqlen, nhead], batch dimension is 1
grid_dim = dim3(dimx, dimy); grid_dim = dim3(dimx, dimy, 1);
} }
// printf("block_size = %d info.table_dim = %ld has_batch_dim: %d, is_gpt_j: %d pos_has_batch_dim: %d\n",
// block_size, info.table_dim, info.has_batch_dim, is_gpt_j, info.pos_has_batch_dim);
// [batch, seqlen, nhead, dhead, table_len, table_dim, y_stride_batch, y_stride_seqlen, y_stride_nhead, x_stride_batch, x_stride_seqlen,x_stride_nhead]
// printf("[%ld %ld %ld %ld %ld %ld %ld %ld %ld %ld %ld %ld]\n", info.batch,
// info.seqlen, info.nhead, info.dhead, info.table_len, info.table_dim,
// info.y_stride_batch, info.y_stride_seqlen, info.y_stride_nhead,
// info.x_stride_batch, info.x_stride_seqlen, info.x_stride_nhead);
if (is_gpt_j) { if (is_gpt_j) {
ropeThreadPerItemKernel<true><<<grid_dim, nthreads, 0, stream>>>( ropeThreadPerItemKernel<true><<<grid_dim, nthreads, 0, stream>>>(
y, x, pos_ids, sin_table, cos_table, info.table_dim, y, x, pos_ids, sin_table, cos_table, info.table_dim,
...@@ -107,6 +141,27 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info, ...@@ -107,6 +141,27 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
info.has_batch_dim, info.has_batch_dim,
info.y_stride_batch, info.y_stride_seqlen, info.y_stride_nhead, info.y_stride_batch, info.y_stride_seqlen, info.y_stride_nhead,
info.x_stride_batch, info.x_stride_seqlen, info.x_stride_nhead); info.x_stride_batch, info.x_stride_seqlen, info.x_stride_nhead);
} else {
if ((std::is_same<Tdata, cuda_bfloat16>::value) && (info.table_dim == 64) && (info.nhead < 16) && (info.seqlen < 505)) {
auto bf16_y = reinterpret_cast<cuda_bfloat16*>(y);
auto bf16_x = reinterpret_cast<const cuda_bfloat16*>(x);
grid_dim = dim3(info.seqlen, info.batch, 1);
dim3 block_dim = dim3(info.table_dim, info.nhead, 1);
customropeThreadPerItemKernel<false><<<grid_dim, block_dim, 0, stream>>>(
bf16_y, bf16_x, pos_ids, sin_table, cos_table, info.table_dim,
pos_stride_batch,
info.pos_has_batch_dim,
info.has_batch_dim,
info.y_stride_batch, info.y_stride_seqlen, info.y_stride_nhead,
info.x_stride_batch, info.x_stride_seqlen, info.x_stride_nhead);
// ropeThreadPerItemKernel<false><<<grid_dim, nthreads, 0, stream>>>(
// y, x, pos_ids, sin_table, cos_table, info.table_dim,
// pos_stride_batch,
// info.pos_has_batch_dim,
// info.has_batch_dim,
// info.y_stride_batch, info.y_stride_seqlen, info.y_stride_nhead,
// info.x_stride_batch, info.x_stride_seqlen, info.x_stride_nhead);
} else { } else {
ropeThreadPerItemKernel<false><<<grid_dim, nthreads, 0, stream>>>( ropeThreadPerItemKernel<false><<<grid_dim, nthreads, 0, stream>>>(
y, x, pos_ids, sin_table, cos_table, info.table_dim, y, x, pos_ids, sin_table, cos_table, info.table_dim,
...@@ -116,6 +171,7 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info, ...@@ -116,6 +171,7 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
info.y_stride_batch, info.y_stride_seqlen, info.y_stride_nhead, info.y_stride_batch, info.y_stride_seqlen, info.y_stride_nhead,
info.x_stride_batch, info.x_stride_seqlen, info.x_stride_nhead); info.x_stride_batch, info.x_stride_seqlen, info.x_stride_nhead);
} }
}
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment