jerrrrry / infinicore / Commits / 5de45ee6

Commit 5de45ee6
Authored Apr 21, 2026 by yaoht

    Integrate FlashAttention (fa), adapt to the DCU, and optimize the addrmsnorm and rope operators

Parent: 93191613
Pipeline #3510 failed in 0 seconds
Changes: 23 · Pipelines: 1

Showing 20 changed files with 1069 additions and 43 deletions (+1069 −43)
include/infinicore/adaptor/aten_adaptor.hpp                                         +14   −4
scripts/build.sh                                                                    +12   −0
src/infinicore/adaptor/aten_adaptor.cc                                               +7   −2
src/infinicore/adaptor/flash_attn_hygon_wrapper.cc                                 +162   −0
src/infinicore/graph/graph.cc                                                       +27   −8
src/infinicore/nn/embedding.cc                                                       +1   −1
src/infinicore/nn/rmsnorm.cc                                                         +3   −1
src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc                              +1   −1
src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc              +11   −4
src/infiniop/ops/add_rms_norm/cuda/kernel.cuh                                      +279   −0
src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cu                        +376   −1
src/infiniop/ops/paged_attention/cuda/kernel_v2.cuh                                  +8   −7
src/infiniop/ops/paged_attention/operator.cc                                        +13   −1
src/infiniop/ops/paged_attention_prefill/cuda/kernel_v2.cuh                          +2   −1
src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu    +1   −1
src/infiniop/ops/paged_attention_prefill/operator.cc                                +13   −1
src/infiniop/ops/paged_caching/nvidia/paged_caching_nvidia.cu                       +15   −0
src/infiniop/ops/paged_caching/operator.cc                                          +13   −1
src/infiniop/ops/rope/cuda/kernel.cuh                                               +46   −0
src/infiniop/ops/rope/nvidia/rope_nvidia.cu                                         +65   −9
include/infinicore/adaptor/aten_adaptor.hpp
@@ -5,9 +5,12 @@
 #include <ATen/ATen.h>
-#ifdef ENABLE_NVIDIA_API
+#if defined(ENABLE_NVIDIA_API)
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
+#elif defined(ENABLE_HYGON_API)
+#include <ATen/hip/HIPContext.h>
+#include <c10/hip/HIPGuard.h>
 #endif
 namespace infinicore::adaptor {
...
@@ -29,7 +32,8 @@ inline at::ScalarType to_at_dtype(DataType dtype) {
 }
 inline at::Device to_at_device(const Device &device) {
-    if (device.getType() == Device::Type::NVIDIA) {
+    if (device.getType() == Device::Type::NVIDIA
+        || device.getType() == Device::Type::HYGON) {
         return at::Device(at::kCUDA, device.getIndex());
     } else if (device.getType() == Device::Type::CPU) {
         return at::Device(at::kCPU);
...
@@ -40,8 +44,14 @@ inline at::Device to_at_device(const Device &device) {
 at::Tensor to_aten_tensor(const infinicore::Tensor &t);
-#ifdef ENABLE_NVIDIA_API
-c10::cuda::CUDAStream get_cuda_stream();
+#if defined(ENABLE_HYGON_API)
+using TorchStream = c10::hip::HIPStream;
+using TorchStreamGuard = c10::hip::HIPStreamGuard;
+TorchStream get_cuda_stream();
+#elif defined(ENABLE_NVIDIA_API)
+using TorchStream = c10::cuda::CUDAStream;
+using TorchStreamGuard = c10::cuda::CUDAStreamGuard;
+TorchStream get_cuda_stream();
 #endif
 } // namespace infinicore::adaptor
...
scripts/build.sh
new file mode 100755

#!/bin/bash
export CUDA_HOME=/opt/dtk/cuda/cuda
xmake clean --all
xmake f -c --hygon-dcu=y --ccl=y --graph=y --cuda=$CUDA_HOME --aten=y --flash-attn-prebuilt=/usr/local/lib/python3.10/dist-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so
xmake build && xmake install
xmake build _infinicore && xmake install _infinicore
pip install -e . --no-build-isolation
\ No newline at end of file
src/infinicore/adaptor/aten_adaptor.cc
@@ -32,8 +32,13 @@ at::Tensor to_aten_tensor(const infinicore::Tensor &t) {
         options);
 }
-#ifdef ENABLE_NVIDIA_API
-c10::cuda::CUDAStream get_cuda_stream() {
+#if defined(ENABLE_HYGON_API)
+TorchStream get_cuda_stream() {
+    return c10::hip::getStreamFromExternal(
+        hipStream_t(infinicore::context::getStream()),
+        infinicore::context::getDevice().getIndex());
+}
+#elif defined(ENABLE_NVIDIA_API)
+TorchStream get_cuda_stream() {
     return c10::cuda::getStreamFromExternal(
         cudaStream_t(infinicore::context::getStream()),
         infinicore::context::getDevice().getIndex());
 }
...
src/infinicore/adaptor/flash_attn_hygon_wrapper.cc
new file mode 100755

#if defined(ENABLE_FLASH_ATTN) && defined(ENABLE_HYGON_API) && !defined(ENABLE_NVIDIA_API)

#include <ATen/ATen.h>
#include <c10/util/Optional.h>
#include <dlfcn.h>
#include <optional>
#include <stdexcept>
#include <vector>

// ---------------------------------------------------------------------------
// Function pointer types for the extern "C" functions exported by the DCU
// flash_attn shared library (built from flash-attention-cutlass-master).
// We resolve these at runtime via dlsym to avoid hard link-time dependency
// on the prebuilt .so (which requires libtorch_python.so).
// ---------------------------------------------------------------------------
using mha_fwd_kvcache_fn_t = std::vector<at::Tensor> (*)(
    at::Tensor &q,
    const at::Tensor &kcache,
    const at::Tensor &vcache,
    c10::optional<const at::Tensor> &k_,
    c10::optional<const at::Tensor> &v_,
    c10::optional<const at::Tensor> &seqlens_k_,
    c10::optional<const at::Tensor> &rotary_cos_,
    c10::optional<const at::Tensor> &rotary_sin_,
    c10::optional<const at::Tensor> &cache_batch_idx_,
    c10::optional<const at::Tensor> &leftpad_k_,
    c10::optional<at::Tensor> &block_table_,
    c10::optional<at::Tensor> &alibi_slopes_,
    c10::optional<at::Tensor> &out_,
    const float softmax_scale,
    bool is_causal,
    int window_size_left,
    int window_size_right,
    const float softcap,
    bool is_rotary_interleaved,
    int num_splits,
    const c10::optional<at::Tensor> &s_aux_);

using mha_varlen_fwd_fn_t = std::vector<at::Tensor> (*)(
    at::Tensor &q,
    const at::Tensor &k,
    const at::Tensor &v,
    c10::optional<at::Tensor> &out_,
    const at::Tensor &cu_seqlens_q,
    const at::Tensor &cu_seqlens_k,
    c10::optional<at::Tensor> &seqused_k,
    c10::optional<const at::Tensor> &leftpad_k_,
    c10::optional<at::Tensor> &block_table_,
    c10::optional<at::Tensor> &alibi_slopes_,
    int max_seqlen_q,
    const int max_seqlen_k,
    const float p_dropout,
    const float softmax_scale,
    const bool zero_tensors,
    bool is_causal,
    int window_size_left,
    int window_size_right,
    const float softcap,
    const bool return_softmax,
    c10::optional<at::Tensor> q_descale_,
    c10::optional<at::Tensor> k_descale_,
    c10::optional<at::Tensor> v_descale_,
    c10::optional<at::Generator> gen_,
    const c10::optional<at::Tensor> &s_aux_);

static void *resolve_symbol(const char *name) {
    void *sym = dlsym(RTLD_DEFAULT, name);
    if (sym) {
        return sym;
    }
    throw std::runtime_error(
        std::string("flash_attn symbol not found: ") + name
        + ". Ensure flash_attn_2_cuda is loaded before calling this function "
          "(e.g. import torch; import flash_attn_2_cuda).");
}

// ---------------------------------------------------------------------------
// Wrappers in the flash:: namespace.
// These match the signatures declared in
// include/infinicore/adaptor/flash_attention_adaptor.hpp
// and bridge the namespace gap between InfiniCore and the DCU library.
// ---------------------------------------------------------------------------
namespace flash {

std::vector<at::Tensor> mha_fwd_kvcache(
    at::Tensor &q,
    const at::Tensor &kcache,
    const at::Tensor &vcache,
    std::optional<const at::Tensor> &k_,
    std::optional<const at::Tensor> &v_,
    std::optional<const at::Tensor> &seqlens_k_,
    std::optional<const at::Tensor> &rotary_cos_,
    std::optional<const at::Tensor> &rotary_sin_,
    std::optional<const at::Tensor> &cache_batch_idx_,
    std::optional<const at::Tensor> &leftpad_k_,
    std::optional<at::Tensor> &block_table_,
    std::optional<at::Tensor> &alibi_slopes_,
    std::optional<at::Tensor> &out_,
    const float softmax_scale,
    bool is_causal,
    int window_size_left,
    int window_size_right,
    const float softcap,
    bool is_rotary_interleaved,
    int num_splits) {
    static auto fn = reinterpret_cast<mha_fwd_kvcache_fn_t>(resolve_symbol("mha_fwd_kvcache"));
    c10::optional<at::Tensor> s_aux = c10::nullopt;
    return fn(q, kcache, vcache, k_, v_, seqlens_k_, rotary_cos_, rotary_sin_,
              cache_batch_idx_, leftpad_k_, block_table_, alibi_slopes_, out_,
              softmax_scale, is_causal, window_size_left, window_size_right,
              softcap, is_rotary_interleaved, num_splits, s_aux);
}

std::vector<at::Tensor> mha_varlen_fwd(
    at::Tensor &q,
    const at::Tensor &k,
    const at::Tensor &v,
    std::optional<at::Tensor> &out_,
    const at::Tensor &cu_seqlens_q,
    const at::Tensor &cu_seqlens_k,
    std::optional<at::Tensor> &seqused_k,
    std::optional<const at::Tensor> &leftpad_k_,
    std::optional<at::Tensor> &block_table_,
    std::optional<at::Tensor> &alibi_slopes_,
    int max_seqlen_q,
    const int max_seqlen_k,
    const float p_dropout,
    const float softmax_scale,
    const bool zero_tensors,
    bool is_causal,
    int window_size_left,
    int window_size_right,
    const float softcap,
    const bool return_softmax,
    std::optional<at::Generator> gen_) {
    static auto fn = reinterpret_cast<mha_varlen_fwd_fn_t>(resolve_symbol("mha_varlen_fwd"));
    c10::optional<at::Tensor> q_descale = c10::nullopt;
    c10::optional<at::Tensor> k_descale = c10::nullopt;
    c10::optional<at::Tensor> v_descale = c10::nullopt;
    c10::optional<at::Tensor> s_aux = c10::nullopt;
    return fn(q, k, v, out_, cu_seqlens_q, cu_seqlens_k, seqused_k, leftpad_k_,
              block_table_, alibi_slopes_, max_seqlen_q, max_seqlen_k, p_dropout,
              softmax_scale, zero_tensors, is_causal, window_size_left,
              window_size_right, softcap, return_softmax, q_descale, k_descale,
              v_descale, gen_, s_aux);
}

} // namespace flash

#endif // ENABLE_FLASH_ATTN && ENABLE_HYGON_API && !ENABLE_NVIDIA_API
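Note: the wrapper above resolves the two flash-attn entry points lazily with dlsym(RTLD_DEFAULT, ...), so the prebuilt flash_attn_2_cuda library must already be mapped into the process. A minimal, hedged sketch of how a host could preload it explicitly; the library path is only the one used in scripts/build.sh and the helper name is an assumption, not part of this commit:

```cpp
#include <dlfcn.h>
#include <stdexcept>
#include <string>

// Hypothetical helper: map the prebuilt flash-attn .so into the global symbol
// namespace so that a later dlsym(RTLD_DEFAULT, "mha_fwd_kvcache") can find it.
// The default path is an assumption taken from scripts/build.sh; adjust as needed.
static void preload_flash_attn(const std::string &path =
        "/usr/local/lib/python3.10/dist-packages/"
        "flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so") {
    // RTLD_GLOBAL makes the exported symbols visible to later dlsym lookups;
    // RTLD_NOW surfaces unresolved dependencies (e.g. libtorch_python.so) early.
    if (!dlopen(path.c_str(), RTLD_NOW | RTLD_GLOBAL)) {
        const char *err = dlerror();
        throw std::runtime_error(std::string("dlopen failed: ") + (err ? err : "unknown"));
    }
}
```

In the committed code the preload is expected to have happened on the Python side (import torch; import flash_attn_2_cuda), which is why resolve_symbol only throws instead of loading the library itself.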
src/infinicore/graph/graph.cc
@@ -3,6 +3,7 @@
 #include "../utils.hpp"
 #include "infinicore/context/context.hpp"
 #include <infinirt.h>
+#include <spdlog/spdlog.h>
 namespace infinicore::graph {
...
@@ -32,9 +33,11 @@ DispatchableGraphOperator::~DispatchableGraphOperator() {
  * ========================= */
 struct Graph::DeviceGraph {
-    infinirtGraph_t graph;
-    infinirtGraphExec_t exec;
-    infinirtGraphNode_t node;
+    infinirtGraph_t graph = nullptr;
+    infinirtGraphExec_t exec = nullptr;
+    infinirtGraphNode_t node = nullptr;
+    infinirtStream_t capture_stream = nullptr;
+    Device capture_device;
     std::vector<char> log_buffer;
     DeviceGraph() {
...
@@ -51,7 +54,11 @@ struct Graph::DeviceGraph {
     }
     void launch() {
-        INFINICORE_CHECK_ERROR(infinirtGraphLuanch(exec, context::getStream()));
+        // Ensure we are on the correct device before launching the graph
+        if (capture_device != context::getDevice()) {
+            context::setDevice(capture_device);
+        }
+        INFINICORE_CHECK_ERROR(infinirtGraphLuanch(exec, capture_stream));
     }
 };
...
@@ -76,29 +83,41 @@ void Graph::instantiate() {
     // Reset device graph
     device_graph_ = std::make_unique<DeviceGraph>();
-    // warmup
+    // Save the current stream and device; all graph operations must use this stream
+    auto capture_stream = context::getStream();
+    auto capture_device = context::getDevice();
+    // warmup: ensure we are on the correct device and stream
+    context::setDevice(capture_device);
     for (size_t iter = 0; iter < 5; ++iter) {
         this->run();
     }
     infinicore::context::syncStream();
+    // Ensure device is correct before capture (may have been switched during warmup)
+    context::setDevice(capture_device);
-    if (infinirtStreamBeginCapture(context::getStream(), INFINIRT_STREAM_CAPTURE_MODE_RELAXED) != INFINI_STATUS_SUCCESS) {
+    if (infinirtStreamBeginCapture(capture_stream, INFINIRT_STREAM_CAPTURE_MODE_RELAXED) != INFINI_STATUS_SUCCESS) {
         return;
     }
-    // Run and record
+    // Run and record; all ops must use capture_stream
     this->run();
-    if (infinirtStreamEndCapture(context::getStream(), &device_graph_.get()->graph) != INFINI_STATUS_SUCCESS) {
+    if (infinirtStreamEndCapture(capture_stream, &device_graph_.get()->graph) != INFINI_STATUS_SUCCESS) {
         return;
     }
+    // Save the capture stream and device for later launch()
+    device_graph_.get()->capture_stream = capture_stream;
+    device_graph_.get()->capture_device = capture_device;
     if (infinirtGraphInstantiate(&device_graph_.get()->exec,
                                  device_graph_.get()->graph,
...
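The reworked instantiate() pins warm-up, capture, and replay to one saved stream and device, which mirrors the usual CUDA graph stream-capture pattern. A minimal sketch of that pattern with the plain CUDA runtime API, for reference only (the commit itself goes through the infinirt wrappers, and cudaGraphInstantiateWithFlags assumes CUDA 11.4 or newer):

```cpp
#include <cuda_runtime.h>

// Sketch of stream-capture-based graph replay; error checks omitted for brevity.
// All work issued between BeginCapture and EndCapture must target `stream`,
// which is why the code above threads a single capture_stream through run().
void capture_and_replay(cudaStream_t stream, void (*record_work)(cudaStream_t)) {
    cudaGraph_t graph = nullptr;
    cudaGraphExec_t exec = nullptr;

    cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed);
    record_work(stream);                  // enqueue kernels/copies on the capture stream
    cudaStreamEndCapture(stream, &graph); // materialize the captured graph

    cudaGraphInstantiateWithFlags(&exec, graph, 0);
    cudaGraphLaunch(exec, stream);        // replays on the same stream later as well
    cudaStreamSynchronize(stream);

    cudaGraphExecDestroy(exec);
    cudaGraphDestroy(graph);
}
```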
src/infinicore/nn/embedding.cc
@@ -45,7 +45,7 @@ Embedding::Embedding(size_t num_embeddings,
 Tensor Embedding::forward(const Tensor &indices) const {
     // TODO: Implement on-device embedding for all devices, then remove the condition and the classic approach
     auto device_type = device_.getType();
-    if (device_type == Device::Type::NVIDIA || device_type == Device::Type::ILUVATAR || device_type == Device::Type::METAX || device_type == Device::Type::MOORE || device_type == Device::Type::ALI) {
+    if (device_type == Device::Type::NVIDIA || device_type == Device::Type::ILUVATAR || device_type == Device::Type::METAX || device_type == Device::Type::MOORE || device_type == Device::Type::ALI || device_type == Device::Type::HYGON) {
         // Use op::embedding which supports device-side input and batch dimension
         return op::embedding(indices->contiguous()->to(device_), weight_);
     }
...
src/infinicore/nn/rmsnorm.cc
@@ -31,7 +31,9 @@ void RMSNorm::forward_inplace(Tensor &x, Tensor &residual) const {
         || device_.getType() == Device::Type::ILUVATAR
         || device_.getType() == Device::Type::METAX
         || device_.getType() == Device::Type::MOORE
-        || device_.getType() == Device::Type::ALI) {
+        || device_.getType() == Device::Type::ALI
+        || device_.getType() == Device::Type::HYGON) {
+        // ){
         op::add_rms_norm_inplace(x, residual, weight_, static_cast<float>(eps_));
     } else {
         op::add_(residual, x, residual);
...
src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc
@@ -33,7 +33,7 @@ void *plan(Tensor out,
 void run(void *planned_meta) {
 #ifdef ENABLE_FLASH_ATTN
-    c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream());
+    infinicore::adaptor::TorchStreamGuard guard(infinicore::adaptor::get_cuda_stream());
     auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);
     auto out_tensor = infinicore::adaptor::to_aten_tensor(p->out);
...
src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc
@@ -41,18 +41,25 @@ void *plan(Tensor out,
 void run(void *planned_meta) {
 #ifdef ENABLE_FLASH_ATTN
-    c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream());
+    infinicore::adaptor::TorchStreamGuard guard(infinicore::adaptor::get_cuda_stream());
     auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);
     auto q = infinicore::adaptor::to_aten_tensor(p->q);
-    auto k = infinicore::adaptor::to_aten_tensor(p->k);
-    auto v = infinicore::adaptor::to_aten_tensor(p->v);
+    auto k = infinicore::adaptor::to_aten_tensor(p->k).contiguous();
+    auto v = infinicore::adaptor::to_aten_tensor(p->v).contiguous();
     auto out = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->out));
     auto cu_seqlens_q = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_q);
     auto cu_seqlens_kv = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_k);
+    auto block_table = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->block_table));
+    // Flash-attn requires cu_seqlens and block_table on same device as q/k/v.
+    auto device = q.device();
+    if (!cu_seqlens_q.is_cuda()) cu_seqlens_q = cu_seqlens_q.to(device);
+    if (!cu_seqlens_kv.is_cuda()) cu_seqlens_kv = cu_seqlens_kv.to(device);
+    if (block_table.has_value() && !block_table->is_cuda()) block_table = block_table->to(device);
     std::optional<at::Tensor> seqused_k = std::nullopt;
     std::optional<const at::Tensor> leftpad_k = std::nullopt;
-    auto block_table = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->block_table));
     auto max_seqlen_q = p->max_seqlen_q;
     auto max_seqlen_k = p->max_seqlen_k;
     auto alibi_slopes = p->alibi_slopes ? std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes)) : std::nullopt;
...
src/infiniop/ops/add_rms_norm/cuda/kernel.cuh
@@ -60,4 +60,283 @@ __device__ void add_rmsnormBlock(
     }
 }

// dim=4096, block=1024 => 4 elements per thread: full unroll + register-held sums (no 2nd read of residual_out).
template <typename Tcompute, typename Tdata, typename Tweight>
__device__ void add_rmsnormBlock_dim4096_bs1024(
    Tdata *__restrict__ y,
    Tdata *__restrict__ residual_out,
    ptrdiff_t stride_y_batch,
    ptrdiff_t stride_y_nhead,
    ptrdiff_t stride_residual_out_batch,
    ptrdiff_t stride_residual_out_nhead,
    const Tdata *__restrict__ a,
    ptrdiff_t stride_a_batch,
    ptrdiff_t stride_a_nhead,
    const Tdata *__restrict__ b,
    ptrdiff_t stride_b_batch,
    ptrdiff_t stride_b_nhead,
    const Tweight *__restrict__ w,
    size_t nhead,
    float epsilon) {
    constexpr unsigned int BS = 1024;
    constexpr size_t DIM = 4096;
    const size_t batch_idx = blockIdx.x / nhead;
    const size_t head_idx = blockIdx.x % nhead;
    Tdata *y_ptr = y + batch_idx * stride_y_batch + head_idx * stride_y_nhead;
    const Tdata *a_ptr = a + batch_idx * stride_a_batch + head_idx * stride_a_nhead;
    const Tdata *b_ptr = b + batch_idx * stride_b_batch + head_idx * stride_b_nhead;
    const Tweight *w_ptr = w;
    Tdata *residual_out_ptr = residual_out + batch_idx * stride_residual_out_batch + head_idx * stride_residual_out_nhead;
    const unsigned int t = threadIdx.x;
    Tcompute s0 = Tcompute(a_ptr[t]) + Tcompute(b_ptr[t]);
    Tcompute s1 = Tcompute(a_ptr[t + BS]) + Tcompute(b_ptr[t + BS]);
    Tcompute s2 = Tcompute(a_ptr[t + 2 * BS]) + Tcompute(b_ptr[t + 2 * BS]);
    Tcompute s3 = Tcompute(a_ptr[t + 3 * BS]) + Tcompute(b_ptr[t + 3 * BS]);
    residual_out_ptr[t] = Tdata(s0);
    residual_out_ptr[t + BS] = Tdata(s1);
    residual_out_ptr[t + 2 * BS] = Tdata(s2);
    residual_out_ptr[t + 3 * BS] = Tdata(s3);
    Tcompute sum_squared = s0 * s0 + s1 * s1 + s2 * s2 + s3 * s3;
    using BlockReduce = cub::BlockReduce<Tcompute, BS>;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    sum_squared = BlockReduce(temp_storage).Sum(sum_squared);
    __shared__ Tcompute rms;
    if (t == 0) {
        rms = Tcompute(rsqrtf(sum_squared / Tcompute(DIM) + epsilon));
    }
    __syncthreads();
    y_ptr[t] = Tdata(s0 * Tcompute(w_ptr[t]) * rms);
    y_ptr[t + BS] = Tdata(s1 * Tcompute(w_ptr[t + BS]) * rms);
    y_ptr[t + 2 * BS] = Tdata(s2 * Tcompute(w_ptr[t + 2 * BS]) * rms);
    y_ptr[t + 3 * BS] = Tdata(s3 * Tcompute(w_ptr[t + 3 * BS]) * rms);
}

// dim=8192, block=1024 => 8 elements per thread: full unroll + register-held sums (no 2nd read of residual_out).
template <typename Tcompute, typename Tdata, typename Tweight>
__device__ void add_rmsnormBlock_dim8192_bs1024(
    Tdata *__restrict__ y,
    Tdata *__restrict__ residual_out,
    ptrdiff_t stride_y_batch,
    ptrdiff_t stride_y_nhead,
    ptrdiff_t stride_residual_out_batch,
    ptrdiff_t stride_residual_out_nhead,
    const Tdata *__restrict__ a,
    ptrdiff_t stride_a_batch,
    ptrdiff_t stride_a_nhead,
    const Tdata *__restrict__ b,
    ptrdiff_t stride_b_batch,
    ptrdiff_t stride_b_nhead,
    const Tweight *__restrict__ w,
    size_t nhead,
    float epsilon) {
    constexpr unsigned int BS = 1024;
    constexpr size_t DIM = 8192;
    const size_t batch_idx = blockIdx.x / nhead;
    const size_t head_idx = blockIdx.x % nhead;
    Tdata *y_ptr = y + batch_idx * stride_y_batch + head_idx * stride_y_nhead;
    const Tdata *a_ptr = a + batch_idx * stride_a_batch + head_idx * stride_a_nhead;
    const Tdata *b_ptr = b + batch_idx * stride_b_batch + head_idx * stride_b_nhead;
    const Tweight *w_ptr = w;
    Tdata *residual_out_ptr = residual_out + batch_idx * stride_residual_out_batch + head_idx * stride_residual_out_nhead;
    const unsigned int t = threadIdx.x;
    Tcompute s0 = Tcompute(a_ptr[t]) + Tcompute(b_ptr[t]);
    Tcompute s1 = Tcompute(a_ptr[t + BS]) + Tcompute(b_ptr[t + BS]);
    Tcompute s2 = Tcompute(a_ptr[t + 2 * BS]) + Tcompute(b_ptr[t + 2 * BS]);
    Tcompute s3 = Tcompute(a_ptr[t + 3 * BS]) + Tcompute(b_ptr[t + 3 * BS]);
    Tcompute s4 = Tcompute(a_ptr[t + 4 * BS]) + Tcompute(b_ptr[t + 4 * BS]);
    Tcompute s5 = Tcompute(a_ptr[t + 5 * BS]) + Tcompute(b_ptr[t + 5 * BS]);
    Tcompute s6 = Tcompute(a_ptr[t + 6 * BS]) + Tcompute(b_ptr[t + 6 * BS]);
    Tcompute s7 = Tcompute(a_ptr[t + 7 * BS]) + Tcompute(b_ptr[t + 7 * BS]);
    residual_out_ptr[t] = Tdata(s0);
    residual_out_ptr[t + BS] = Tdata(s1);
    residual_out_ptr[t + 2 * BS] = Tdata(s2);
    residual_out_ptr[t + 3 * BS] = Tdata(s3);
    residual_out_ptr[t + 4 * BS] = Tdata(s4);
    residual_out_ptr[t + 5 * BS] = Tdata(s5);
    residual_out_ptr[t + 6 * BS] = Tdata(s6);
    residual_out_ptr[t + 7 * BS] = Tdata(s7);
    Tcompute sum_squared = s0 * s0 + s1 * s1 + s2 * s2 + s3 * s3 + s4 * s4 + s5 * s5 + s6 * s6 + s7 * s7;
    using BlockReduce = cub::BlockReduce<Tcompute, BS>;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    sum_squared = BlockReduce(temp_storage).Sum(sum_squared);
    __shared__ Tcompute rms;
    if (t == 0) {
        rms = Tcompute(rsqrtf(sum_squared / Tcompute(DIM) + epsilon));
    }
    __syncthreads();
    y_ptr[t] = Tdata(s0 * Tcompute(w_ptr[t]) * rms);
    y_ptr[t + BS] = Tdata(s1 * Tcompute(w_ptr[t + BS]) * rms);
    y_ptr[t + 2 * BS] = Tdata(s2 * Tcompute(w_ptr[t + 2 * BS]) * rms);
    y_ptr[t + 3 * BS] = Tdata(s3 * Tcompute(w_ptr[t + 3 * BS]) * rms);
    y_ptr[t + 4 * BS] = Tdata(s4 * Tcompute(w_ptr[t + 4 * BS]) * rms);
    y_ptr[t + 5 * BS] = Tdata(s5 * Tcompute(w_ptr[t + 5 * BS]) * rms);
    y_ptr[t + 6 * BS] = Tdata(s6 * Tcompute(w_ptr[t + 6 * BS]) * rms);
    y_ptr[t + 7 * BS] = Tdata(s7 * Tcompute(w_ptr[t + 7 * BS]) * rms);
}
#endif
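For reference, the fused kernels above compute an element-wise add followed by RMSNorm. A minimal host-side C++ sketch of the same math (hypothetical names, no fusion or blocking) that the specialized dim=4096/8192 paths are optimizing:

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Reference semantics of add_rms_norm: residual_out = a + b,
// y = residual_out * w * rsqrt(mean(residual_out^2) + epsilon).
void add_rms_norm_ref(std::vector<float> &y, std::vector<float> &residual_out,
                      const std::vector<float> &a, const std::vector<float> &b,
                      const std::vector<float> &w, float epsilon) {
    const size_t dim = a.size();
    float sum_squared = 0.0f;
    for (size_t i = 0; i < dim; ++i) {
        residual_out[i] = a[i] + b[i];
        sum_squared += residual_out[i] * residual_out[i];
    }
    const float rms = 1.0f / std::sqrt(sum_squared / static_cast<float>(dim) + epsilon);
    for (size_t i = 0; i < dim; ++i) {
        y[i] = residual_out[i] * w[i] * rms;
    }
}
```

With dim = 4096 and a 1024-thread block, each thread owns exactly 4096 / 1024 = 4 elements at offsets t, t+1024, t+2048, t+3072, which is the unrolled s0..s3 pattern above; the dim = 8192 variant uses the same layout with 8 registers.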
//////////////////////////////////////////////////////////////////////
// #ifndef __ADD_RMS_NORM_CUDA_KERNEL_H__
// #define __ADD_RMS_NORM_CUDA_KERNEL_H__
// // Remove the cub header dependency
// // #include <cub/block/block_reduce.cuh>
// template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
// __device__ void add_rmsnormBlock(
// Tdata * y, // [Fix 1] Remove __restrict__ to support in-place use
// Tdata * residual_out, // [Fix 1] Remove __restrict__ to support in-place use
// ptrdiff_t stride_y_batch,
// ptrdiff_t stride_y_nhead,
// ptrdiff_t stride_residual_out_batch,
// ptrdiff_t stride_residual_out_nhead,
// const Tdata * a, // [Fix 1] Remove __restrict__ to support in-place use
// ptrdiff_t stride_a_batch,
// ptrdiff_t stride_a_nhead,
// const Tdata * b, // [Fix 1] Remove __restrict__ to support in-place use
// ptrdiff_t stride_b_batch,
// ptrdiff_t stride_b_nhead,
// const Tweight *__restrict__ w, // The weights are never modified, so keeping __restrict__ is safe
// size_t nhead,
// size_t dim,
// float epsilon) {
// size_t batch_idx = blockIdx.x / nhead;
// size_t head_idx = blockIdx.x % nhead;
// auto y_ptr = y + batch_idx * stride_y_batch + head_idx * stride_y_nhead;
// auto a_ptr = a + batch_idx * stride_a_batch + head_idx * stride_a_nhead;
// auto b_ptr = b + batch_idx * stride_b_batch + head_idx * stride_b_nhead;
// auto w_ptr = w;
// Tdata *residual_out_ptr = residual_out + batch_idx * stride_residual_out_batch + head_idx * stride_residual_out_nhead;
// Tcompute sum_squared = 0;
// for (size_t i = threadIdx.x; i < dim; i += BLOCK_SIZE) {
// Tcompute sum_val = Tcompute(a_ptr[i]) + Tcompute(b_ptr[i]);
// residual_out_ptr[i] = Tdata(sum_val); // Store add result
// sum_squared += sum_val * sum_val;
// }
// // [Fix 2] Replace cub::BlockReduce with a generic, safe manual shared-memory reduction
// // so that device-specific warp-size differences cannot cause deadlocks
// __shared__ Tcompute shared_sum[BLOCK_SIZE];
// shared_sum[threadIdx.x] = sum_squared;
// __syncthreads();
// #pragma unroll
// for (unsigned int offset = BLOCK_SIZE / 2; offset > 0; offset /= 2) {
// if (threadIdx.x < offset) {
// shared_sum[threadIdx.x] += shared_sum[threadIdx.x + offset];
// }
// __syncthreads();
// }
// sum_squared = shared_sum[0];
// __shared__ Tcompute rms;
// if (threadIdx.x == 0) {
// rms = Tcompute(rsqrtf(sum_squared / Tcompute(dim) + epsilon));
// }
// __syncthreads();
// // Reuse the residual_out values computed above
// for (size_t i = threadIdx.x; i < dim; i += BLOCK_SIZE) {
// Tcompute sum_val = Tcompute(residual_out_ptr[i]);
// y_ptr[i] = Tdata(sum_val * Tcompute(w_ptr[i]) * rms);
// }
// }
// #endif
////////////////////////////////////////////////////////////////////////////
// #ifndef __ADD_RMS_NORM_CUDA_KERNEL_H__
// #define __ADD_RMS_NORM_CUDA_KERNEL_H__
// #include <cub/block/block_reduce.cuh>
// // Maximum number of elements a single thread may process.
// // E.g. for 70B, dim=8192 with BLOCK_SIZE=1024 needs only 8; 16 is more than enough.
// #define MAX_ELEMS_PER_THREAD 16
// template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
// __device__ void add_rmsnormBlock(
// Tdata *__restrict__ y,
// Tdata *__restrict__ residual_out,
// ptrdiff_t stride_y_batch,
// ptrdiff_t stride_y_seq, // Renamed: the split is normally over seq_len, not nhead
// ptrdiff_t stride_residual_out_batch,
// ptrdiff_t stride_residual_out_seq,
// const Tdata *__restrict__ a,
// ptrdiff_t stride_a_batch,
// ptrdiff_t stride_a_seq,
// const Tdata *__restrict__ b,
// ptrdiff_t stride_b_batch,
// ptrdiff_t stride_b_seq,
// const Tweight *__restrict__ w,
// size_t seq_len, // Renamed: replaces nhead
// size_t dim,
// float epsilon) {
// // One block handles one token
// size_t batch_idx = blockIdx.x / seq_len;
// size_t seq_idx = blockIdx.x % seq_len;
// auto y_ptr = y + batch_idx * stride_y_batch + seq_idx * stride_y_seq;
// auto a_ptr = a + batch_idx * stride_a_batch + seq_idx * stride_a_seq;
// auto b_ptr = b + batch_idx * stride_b_batch + seq_idx * stride_b_seq;
// Tdata *residual_out_ptr = residual_out + batch_idx * stride_residual_out_batch + seq_idx * stride_residual_out_seq;
// Tcompute sum_squared = 0;
// // Core of the true fusion: cache this thread's add results in a register array
// Tcompute thread_cache[MAX_ELEMS_PER_THREAD];
// int cache_idx = 0;
// for (size_t i = threadIdx.x; i < dim; i += BLOCK_SIZE) {
// Tcompute sum_val = Tcompute(a_ptr[i]) + Tcompute(b_ptr[i]);
// residual_out_ptr[i] = Tdata(sum_val); // Still written back to global memory for the following attention
// thread_cache[cache_idx++] = sum_val; // Also kept in fast registers
// sum_squared += sum_val * sum_val;
// }
// // Block-level reduction of the sum of squares
// using BlockReduce = cub::BlockReduce<Tcompute, BLOCK_SIZE>;
// __shared__ typename BlockReduce::TempStorage temp_storage;
// sum_squared = BlockReduce(temp_storage).Sum(sum_squared);
// __shared__ Tcompute rms;
// if (threadIdx.x == 0) {
// rms = Tcompute(rsqrtf(sum_squared / Tcompute(dim) + epsilon));
// }
// __syncthreads();
// // Second phase: read directly from the thread_cache registers, eliminating the costly global-memory re-read
// cache_idx = 0;
// for (size_t i = threadIdx.x; i < dim; i += BLOCK_SIZE) {
// // Use __ldg (if the framework supports it) to read the shared weights at full speed
// Tcompute weight_val = Tcompute(__ldg(&w[i]));
// y_ptr[i] = Tdata(thread_cache[cache_idx++] * weight_val * rms);
// }
// }
// #endif
\ No newline at end of file
src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cu
@@ -8,6 +8,217 @@
#include "../cuda/kernel.cuh"
// DIM=4096, block=1024, BF16: nv_bfloat162 + float regs; pair idx = tid + i*1024 (same reduction order as scalar fast path).
// (Contiguous longlong2 tiling changed per-thread partial sums order vs CUB reduce and broke bit-level match with reference runs.)
__device__ void add_rmsnormBlock_dim4096_bs1024_bf162_vec(
    __nv_bfloat16 *__restrict__ y,
    __nv_bfloat16 *__restrict__ residual_out,
    ptrdiff_t stride_y_batch,
    ptrdiff_t stride_y_nhead,
    ptrdiff_t stride_residual_out_batch,
    ptrdiff_t stride_residual_out_nhead,
    const __nv_bfloat16 *__restrict__ a,
    ptrdiff_t stride_a_batch,
    ptrdiff_t stride_a_nhead,
    const __nv_bfloat16 *__restrict__ b,
    ptrdiff_t stride_b_batch,
    ptrdiff_t stride_b_nhead,
    const __nv_bfloat16 *__restrict__ w,
    size_t nhead,
    float epsilon) {
    constexpr unsigned int BS = 1024;
    constexpr float DIM_F = 4096.0f;
    const size_t batch_idx = blockIdx.x / nhead;
    const size_t head_idx = blockIdx.x % nhead;
    const __nv_bfloat16 *a_base = a + batch_idx * stride_a_batch + head_idx * stride_a_nhead;
    const __nv_bfloat16 *b_base = b + batch_idx * stride_b_batch + head_idx * stride_b_nhead;
    __nv_bfloat16 *res_base = residual_out + batch_idx * stride_residual_out_batch + head_idx * stride_residual_out_nhead;
    __nv_bfloat16 *y_base = y + batch_idx * stride_y_batch + head_idx * stride_y_nhead;
    const __nv_bfloat162 *a_ptr2 = reinterpret_cast<const __nv_bfloat162 *>(a_base);
    const __nv_bfloat162 *b_ptr2 = reinterpret_cast<const __nv_bfloat162 *>(b_base);
    const __nv_bfloat162 *w_ptr2 = reinterpret_cast<const __nv_bfloat162 *>(w);
    __nv_bfloat162 *res_ptr2 = reinterpret_cast<__nv_bfloat162 *>(res_base);
    __nv_bfloat162 *y_ptr2 = reinterpret_cast<__nv_bfloat162 *>(y_base);
    float sum_squared = 0.0f;
    float s1_reg[2];
    float s2_reg[2];
#pragma unroll
    for (int i = 0; i < 2; ++i) {
        const int idx = static_cast<int>(threadIdx.x) + i * static_cast<int>(BS);
        const __nv_bfloat162 val_a = a_ptr2[idx];
        const __nv_bfloat162 val_b = b_ptr2[idx];
        const float f_a1 = __low2float(val_a);
        const float f_a2 = __high2float(val_a);
        const float f_b1 = __low2float(val_b);
        const float f_b2 = __high2float(val_b);
        const float t1 = f_a1 + f_b1;
        const float t2 = f_a2 + f_b2;
        s1_reg[i] = t1;
        s2_reg[i] = t2;
        res_ptr2[idx] = __floats2bfloat162_rn(t1, t2);
        sum_squared += t1 * t1 + t2 * t2;
    }
    using BlockReduce = cub::BlockReduce<float, BS>;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    sum_squared = BlockReduce(temp_storage).Sum(sum_squared);
    __shared__ float rms;
    if (threadIdx.x == 0) {
        rms = rsqrtf(sum_squared / DIM_F + epsilon);
    }
    __syncthreads();
#pragma unroll
    for (int i = 0; i < 2; ++i) {
        const int idx = static_cast<int>(threadIdx.x) + i * static_cast<int>(BS);
        const __nv_bfloat162 val_w = w_ptr2[idx];
        const float f_w1 = __low2float(val_w);
        const float f_w2 = __high2float(val_w);
        const float y1 = s1_reg[i] * f_w1 * rms;
        const float y2 = s2_reg[i] * f_w2 * rms;
        y_ptr2[idx] = __floats2bfloat162_rn(y1, y2);
    }
}

INFINIOP_CUDA_KERNEL add_rmsnormKernel_dim4096_bs1024_bf162_vec(
    __nv_bfloat16 *__restrict__ y,
    __nv_bfloat16 *__restrict__ residual_out,
    ptrdiff_t stride_y_batch,
    ptrdiff_t stride_y_nhead,
    ptrdiff_t stride_residual_out_batch,
    ptrdiff_t stride_residual_out_nhead,
    const __nv_bfloat16 *__restrict__ a,
    ptrdiff_t stride_a_batch,
    ptrdiff_t stride_a_nhead,
    const __nv_bfloat16 *__restrict__ b,
    ptrdiff_t stride_b_batch,
    ptrdiff_t stride_b_nhead,
    const __nv_bfloat16 *__restrict__ w,
    size_t nhead,
    float epsilon) {
    add_rmsnormBlock_dim4096_bs1024_bf162_vec(
        y, residual_out,
        stride_y_batch, stride_y_nhead,
        stride_residual_out_batch, stride_residual_out_nhead,
        a, stride_a_batch, stride_a_nhead,
        b, stride_b_batch, stride_b_nhead,
        w, nhead, epsilon);
}

// DIM=8192, block=1024: 4x nv_bfloat162 per thread; pair idx = tid + i*1024 (same as scalar tiling; avoids longlong2 reorder issues).
__device__ void add_rmsnormBlock_dim8192_bs1024_bf162_vec(
    __nv_bfloat16 *__restrict__ y,
    __nv_bfloat16 *__restrict__ residual_out,
    ptrdiff_t stride_y_batch,
    ptrdiff_t stride_y_nhead,
    ptrdiff_t stride_residual_out_batch,
    ptrdiff_t stride_residual_out_nhead,
    const __nv_bfloat16 *__restrict__ a,
    ptrdiff_t stride_a_batch,
    ptrdiff_t stride_a_nhead,
    const __nv_bfloat16 *__restrict__ b,
    ptrdiff_t stride_b_batch,
    ptrdiff_t stride_b_nhead,
    const __nv_bfloat16 *__restrict__ w,
    size_t nhead,
    float epsilon) {
    constexpr unsigned int BS = 1024;
    constexpr float DIM_F = 8192.0f;
    const size_t batch_idx = blockIdx.x / nhead;
    const size_t head_idx = blockIdx.x % nhead;
    const __nv_bfloat16 *a_base = a + batch_idx * stride_a_batch + head_idx * stride_a_nhead;
    const __nv_bfloat16 *b_base = b + batch_idx * stride_b_batch + head_idx * stride_b_nhead;
    __nv_bfloat16 *res_base = residual_out + batch_idx * stride_residual_out_batch + head_idx * stride_residual_out_nhead;
    __nv_bfloat16 *y_base = y + batch_idx * stride_y_batch + head_idx * stride_y_nhead;
    const __nv_bfloat162 *a_ptr2 = reinterpret_cast<const __nv_bfloat162 *>(a_base);
    const __nv_bfloat162 *b_ptr2 = reinterpret_cast<const __nv_bfloat162 *>(b_base);
    const __nv_bfloat162 *w_ptr2 = reinterpret_cast<const __nv_bfloat162 *>(w);
    __nv_bfloat162 *res_ptr2 = reinterpret_cast<__nv_bfloat162 *>(res_base);
    __nv_bfloat162 *y_ptr2 = reinterpret_cast<__nv_bfloat162 *>(y_base);
    float sum_squared = 0.0f;
    float s1_reg[4];
    float s2_reg[4];
#pragma unroll
    for (int i = 0; i < 4; ++i) {
        const int idx = static_cast<int>(threadIdx.x) + i * static_cast<int>(BS);
        const __nv_bfloat162 val_a = a_ptr2[idx];
        const __nv_bfloat162 val_b = b_ptr2[idx];
        const float f_a1 = __low2float(val_a);
        const float f_a2 = __high2float(val_a);
        const float f_b1 = __low2float(val_b);
        const float f_b2 = __high2float(val_b);
        const float t1 = f_a1 + f_b1;
        const float t2 = f_a2 + f_b2;
        s1_reg[i] = t1;
        s2_reg[i] = t2;
        res_ptr2[idx] = __floats2bfloat162_rn(t1, t2);
        sum_squared += t1 * t1 + t2 * t2;
    }
    using BlockReduce = cub::BlockReduce<float, BS>;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    sum_squared = BlockReduce(temp_storage).Sum(sum_squared);
    __shared__ float rms;
    if (threadIdx.x == 0) {
        rms = rsqrtf(sum_squared / DIM_F + epsilon);
    }
    __syncthreads();
#pragma unroll
    for (int i = 0; i < 4; ++i) {
        const int idx = static_cast<int>(threadIdx.x) + i * static_cast<int>(BS);
        const __nv_bfloat162 val_w = w_ptr2[idx];
        const float f_w1 = __low2float(val_w);
        const float f_w2 = __high2float(val_w);
        const float y1 = s1_reg[i] * f_w1 * rms;
        const float y2 = s2_reg[i] * f_w2 * rms;
        y_ptr2[idx] = __floats2bfloat162_rn(y1, y2);
    }
}

INFINIOP_CUDA_KERNEL add_rmsnormKernel_dim8192_bs1024_bf162_vec(
    __nv_bfloat16 *__restrict__ y,
    __nv_bfloat16 *__restrict__ residual_out,
    ptrdiff_t stride_y_batch,
    ptrdiff_t stride_y_nhead,
    ptrdiff_t stride_residual_out_batch,
    ptrdiff_t stride_residual_out_nhead,
    const __nv_bfloat16 *__restrict__ a,
    ptrdiff_t stride_a_batch,
    ptrdiff_t stride_a_nhead,
    const __nv_bfloat16 *__restrict__ b,
    ptrdiff_t stride_b_batch,
    ptrdiff_t stride_b_nhead,
    const __nv_bfloat16 *__restrict__ w,
    size_t nhead,
    float epsilon) {
    add_rmsnormBlock_dim8192_bs1024_bf162_vec(
        y, residual_out,
        stride_y_batch, stride_y_nhead,
        stride_residual_out_batch, stride_residual_out_nhead,
        a, stride_a_batch, stride_a_nhead,
        b, stride_b_batch, stride_b_nhead,
        w, nhead, epsilon);
}

template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_CUDA_KERNEL add_rmsnormKernel(
    Tdata *__restrict__ y,
...
@@ -35,6 +246,58 @@ INFINIOP_CUDA_KERNEL add_rmsnormKernel(
        w,
        nhead,
        dim,
        epsilon);
}

template <typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_CUDA_KERNEL add_rmsnormKernel_dim4096_bs1024(
    Tdata *__restrict__ y,
    Tdata *__restrict__ residual_out,
    ptrdiff_t stride_y_batch,
    ptrdiff_t stride_y_nhead,
    ptrdiff_t stride_residual_out_batch,
    ptrdiff_t stride_residual_out_nhead,
    const Tdata *__restrict__ a,
    ptrdiff_t stride_a_batch,
    ptrdiff_t stride_a_nhead,
    const Tdata *__restrict__ b,
    ptrdiff_t stride_b_batch,
    ptrdiff_t stride_b_nhead,
    const Tweight *__restrict__ w,
    size_t nhead,
    float epsilon) {
    add_rmsnormBlock_dim4096_bs1024<Tcompute, Tdata, Tweight>(
        y, residual_out,
        stride_y_batch, stride_y_nhead,
        stride_residual_out_batch, stride_residual_out_nhead,
        a, stride_a_batch, stride_a_nhead,
        b, stride_b_batch, stride_b_nhead,
        w, nhead, epsilon);
}

template <typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_CUDA_KERNEL add_rmsnormKernel_dim8192_bs1024(
    Tdata *__restrict__ y,
    Tdata *__restrict__ residual_out,
    ptrdiff_t stride_y_batch,
    ptrdiff_t stride_y_nhead,
    ptrdiff_t stride_residual_out_batch,
    ptrdiff_t stride_residual_out_nhead,
    const Tdata *__restrict__ a,
    ptrdiff_t stride_a_batch,
    ptrdiff_t stride_a_nhead,
    const Tdata *__restrict__ b,
    ptrdiff_t stride_b_batch,
    ptrdiff_t stride_b_nhead,
    const Tweight *__restrict__ w,
    size_t nhead,
    float epsilon) {
    add_rmsnormBlock_dim8192_bs1024<Tcompute, Tdata, Tweight>(
        y, residual_out,
        stride_y_batch, stride_y_nhead,
        stride_residual_out_batch, stride_residual_out_nhead,
        a, stride_a_batch, stride_a_nhead,
        b, stride_b_batch, stride_b_nhead,
        w, nhead, epsilon);
}

namespace op::add_rms_norm::nvidia {

struct Descriptor::Opaque {
...
@@ -97,7 +360,115 @@ infiniStatus_t launchKernel(
dim, \
epsilon)
-    if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
#define LAUNCH_KERNEL_DIM4096_BS1024(Tdata, Tweight, Tcompute) \
add_rmsnormKernel_dim4096_bs1024<Tcompute, Tdata, Tweight><<<batch_size * nhead, CUDA_BLOCK_SIZE_1024, 0, cuda_stream>>>( \
reinterpret_cast<Tdata *>(y), \
reinterpret_cast<Tdata *>(residual_out), \
stride_y_batch, \
stride_y_nhead, \
stride_residual_out_batch, \
stride_residual_out_nhead, \
reinterpret_cast<const Tdata *>(a), \
stride_a_batch, \
stride_a_nhead, \
reinterpret_cast<const Tdata *>(b), \
stride_b_batch, \
stride_b_nhead, \
reinterpret_cast<const Tweight *>(w), \
nhead, \
epsilon)
#define LAUNCH_KERNEL_DIM4096_BS1024_BF162_VEC \
add_rmsnormKernel_dim4096_bs1024_bf162_vec<<<batch_size * nhead, CUDA_BLOCK_SIZE_1024, 0, cuda_stream>>>( \
reinterpret_cast<__nv_bfloat16 *>(y), \
reinterpret_cast<__nv_bfloat16 *>(residual_out), \
stride_y_batch, \
stride_y_nhead, \
stride_residual_out_batch, \
stride_residual_out_nhead, \
reinterpret_cast<const __nv_bfloat16 *>(a), \
stride_a_batch, \
stride_a_nhead, \
reinterpret_cast<const __nv_bfloat16 *>(b), \
stride_b_batch, \
stride_b_nhead, \
reinterpret_cast<const __nv_bfloat16 *>(w), \
nhead, \
epsilon)
#define LAUNCH_KERNEL_DIM8192_BS1024(Tdata, Tweight, Tcompute) \
add_rmsnormKernel_dim8192_bs1024<Tcompute, Tdata, Tweight><<<batch_size * nhead, CUDA_BLOCK_SIZE_1024, 0, cuda_stream>>>( \
reinterpret_cast<Tdata *>(y), \
reinterpret_cast<Tdata *>(residual_out), \
stride_y_batch, \
stride_y_nhead, \
stride_residual_out_batch, \
stride_residual_out_nhead, \
reinterpret_cast<const Tdata *>(a), \
stride_a_batch, \
stride_a_nhead, \
reinterpret_cast<const Tdata *>(b), \
stride_b_batch, \
stride_b_nhead, \
reinterpret_cast<const Tweight *>(w), \
nhead, \
epsilon)
#define LAUNCH_KERNEL_DIM8192_BS1024_BF162_VEC \
add_rmsnormKernel_dim8192_bs1024_bf162_vec<<<batch_size * nhead, CUDA_BLOCK_SIZE_1024, 0, cuda_stream>>>( \
reinterpret_cast<__nv_bfloat16 *>(y), \
reinterpret_cast<__nv_bfloat16 *>(residual_out), \
stride_y_batch, \
stride_y_nhead, \
stride_residual_out_batch, \
stride_residual_out_nhead, \
reinterpret_cast<const __nv_bfloat16 *>(a), \
stride_a_batch, \
stride_a_nhead, \
reinterpret_cast<const __nv_bfloat16 *>(b), \
stride_b_batch, \
stride_b_nhead, \
reinterpret_cast<const __nv_bfloat16 *>(w), \
nhead, \
epsilon)
    if (dim == 4096 && BLOCK_SIZE == CUDA_BLOCK_SIZE_1024) {
        if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
            LAUNCH_KERNEL_DIM4096_BS1024(half, half, float);
        } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) {
            LAUNCH_KERNEL_DIM4096_BS1024(half, __nv_bfloat16, float);
        } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
            LAUNCH_KERNEL_DIM4096_BS1024(half, float, float);
        } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
            LAUNCH_KERNEL_DIM4096_BS1024_BF162_VEC;
        } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) {
            LAUNCH_KERNEL_DIM4096_BS1024(__nv_bfloat16, half, float);
        } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
            LAUNCH_KERNEL_DIM4096_BS1024(__nv_bfloat16, float, float);
        } else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
            LAUNCH_KERNEL_DIM4096_BS1024(float, float, float);
        } else {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
    } else if (dim == 8192 && BLOCK_SIZE == CUDA_BLOCK_SIZE_1024) {
        if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
            LAUNCH_KERNEL_DIM8192_BS1024(half, half, float);
        } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) {
            LAUNCH_KERNEL_DIM8192_BS1024(half, __nv_bfloat16, float);
        } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
            LAUNCH_KERNEL_DIM8192_BS1024(half, float, float);
        } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
            LAUNCH_KERNEL_DIM8192_BS1024_BF162_VEC;
        } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) {
            LAUNCH_KERNEL_DIM8192_BS1024(__nv_bfloat16, half, float);
        } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
            LAUNCH_KERNEL_DIM8192_BS1024(__nv_bfloat16, float, float);
        } else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
            LAUNCH_KERNEL_DIM8192_BS1024(float, float, float);
        } else {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
    } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
        LAUNCH_KERNEL(half, half, float);
    } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) {
        LAUNCH_KERNEL(half, __nv_bfloat16, float);
...
@@ -115,6 +486,10 @@ infiniStatus_t launchKernel(
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
#undef LAUNCH_KERNEL_DIM8192_BS1024_BF162_VEC
#undef LAUNCH_KERNEL_DIM8192_BS1024
#undef LAUNCH_KERNEL_DIM4096_BS1024_BF162_VEC
#undef LAUNCH_KERNEL_DIM4096_BS1024
#undef LAUNCH_KERNEL
    return INFINI_STATUS_SUCCESS;
...
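The BF16 fast paths above process a row as packed __nv_bfloat162 pairs: with dim = 4096 a 1024-thread block covers 2048 pairs, i.e. two pairs (four scalars) per thread, and with dim = 8192 four pairs per thread. A small hedged C++ sketch of that index arithmetic (the helper name is hypothetical, host-side only):

```cpp
#include <cstddef>

// For a packed-pair layout, thread t of a 1024-thread block touches pair
// indices t + i * 1024 for i in [0, pairs_per_thread(dim)). This mirrors the
// idx = threadIdx.x + i * BS expression in the bf162 kernels above.
constexpr size_t pairs_per_thread(size_t dim, size_t block_size = 1024) {
    return (dim / 2) / block_size; // dim scalars -> dim/2 bfloat162 pairs
}

static_assert(pairs_per_thread(4096) == 2, "dim=4096 -> 2 pairs (4 scalars) per thread");
static_assert(pairs_per_thread(8192) == 4, "dim=8192 -> 4 pairs (8 scalars) per thread");
```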
src/infiniop/ops/paged_attention/cuda/kernel_v2.cuh
@@ -79,7 +79,8 @@ __device__ __forceinline__ float warpReduceMax(float x) {
 }
 __device__ __forceinline__ unsigned int cvtaToShared(const void *ptr) {
-#if defined(ENABLE_ILUVATAR_API)
+#if defined(ENABLE_ILUVATAR_API) || defined(ENABLE_HYGON_API)
+    // Iluvatar and Hygon DCU (HIP): use raw pointer cast instead of CUDA intrinsic.
     return static_cast<unsigned int>(reinterpret_cast<uintptr_t>(ptr));
 #else
     return static_cast<unsigned int>(__cvta_generic_to_shared(ptr));
...
@@ -87,7 +88,8 @@ __device__ __forceinline__ unsigned int cvtaToShared(const void *ptr) {
 }
 __device__ __forceinline__ void cpAsyncCaSharedGlobal16(void *dst_shared, const void *src_global) {
-#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
+    // cp.async is NVIDIA PTX-only; Hygon DCU (HIP) must use plain loads instead.
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && !defined(ENABLE_HYGON_API)
     const unsigned int dst = cvtaToShared(dst_shared);
     asm volatile("cp.async.ca.shared.global [%0], [%1], 16;\n" ::"r"(dst), "l"(src_global));
 #else
...
@@ -98,14 +100,14 @@ __device__ __forceinline__ void cpAsyncCaSharedGlobal16(void *dst_shared, const
 }
 __device__ __forceinline__ void cpAsyncCommit() {
-#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && !defined(ENABLE_HYGON_API)
     asm volatile("cp.async.commit_group;\n" ::);
 #endif
 }
 template <int N>
 __device__ __forceinline__ void cpAsyncWaitGroup() {
-#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && !defined(ENABLE_HYGON_API)
     asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
 #endif
 }
...
@@ -113,7 +115,7 @@ __device__ __forceinline__ void cpAsyncWaitGroup() {
 // cp.async.wait_group requires a compile-time immediate, so for small fixed
 // stage counts we provide a tiny runtime switch.
 __device__ __forceinline__ void cpAsyncWaitGroupRt(int n) {
-#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && !defined(ENABLE_HYGON_API)
     if (n <= 0) {
         cpAsyncWaitGroup<0>();
     } else if (n == 1) {
...
@@ -1143,8 +1145,7 @@ __device__ void flashAttentionDecodeCtaPipelinedKernel(
     // Prefetch the very first token.
     int buf = 0;
-    int t_base = 0;
-    int token_in_block = 0;
+    (void)0; // t_base, token_in_block removed (unused)
     int logical_block = 0;
     {
         if (tid == 0) {
...
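Guarding the cp.async paths with !defined(ENABLE_HYGON_API) forces the existing #else branch, which copies each 16-byte tile with ordinary loads and stores. A hedged sketch of what such a fallback typically looks like (this helper is illustrative and is not the file's actual #else body):

```cpp
#include <cstring>

// Synchronous 16-byte copy used when cp.async is unavailable (e.g. on HIP).
// A real device-side version would usually go through a 16-byte vector type
// such as uint4; memcpy keeps this host-side sketch portable.
inline void copy16(void *dst_shared, const void *src_global) {
    std::memcpy(dst_shared, src_global, 16);
}
```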
src/infiniop/ops/paged_attention/operator.cc
@@ -2,7 +2,7 @@
 #include "../../handle.h"
 #include "infiniop/ops/paged_attention.h"
-#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API)
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_HYGON_API)
 #include "nvidia/paged_attention_nvidia.cuh"
 #endif
 #ifdef ENABLE_MOORE_API
...
@@ -48,6 +48,9 @@ __INFINI_C infiniStatus_t infiniopCreatePagedAttentionDescriptor(
 #endif
 #ifdef ENABLE_ILUVATAR_API
         CREATE(INFINI_DEVICE_ILUVATAR, nvidia)
 #endif
+#ifdef ENABLE_HYGON_API
+        CREATE(INFINI_DEVICE_HYGON, nvidia)
+#endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...
@@ -78,6 +81,9 @@ __INFINI_C infiniStatus_t infiniopGetPagedAttentionWorkspaceSize(
 #endif
 #ifdef ENABLE_ILUVATAR_API
         GET(INFINI_DEVICE_ILUVATAR, nvidia)
 #endif
+#ifdef ENABLE_HYGON_API
+        GET(INFINI_DEVICE_HYGON, nvidia)
+#endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...
@@ -112,6 +118,9 @@ __INFINI_C infiniStatus_t infiniopPagedAttention(
 #endif
 #ifdef ENABLE_ILUVATAR_API
         CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia)
 #endif
+#ifdef ENABLE_HYGON_API
+        CALCULATE(INFINI_DEVICE_HYGON, nvidia)
+#endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...
@@ -141,6 +150,9 @@ __INFINI_C infiniStatus_t infiniopDestroyPagedAttentionDescriptor(
 #endif
 #ifdef ENABLE_ILUVATAR_API
         DESTROY(INFINI_DEVICE_ILUVATAR, nvidia)
 #endif
+#ifdef ENABLE_HYGON_API
+        DESTROY(INFINI_DEVICE_HYGON, nvidia)
+#endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...
src/infiniop/ops/paged_attention_prefill/cuda/kernel_v2.cuh
@@ -2306,9 +2306,10 @@ __device__ void PagedAttentionPrefillWarpCta8MmaHd128Kernel(
     }
     __syncthreads();
-#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && !defined(ENABLE_HYGON_API)
     // WMMA: each warp computes scores for 16 keys (one 16-column slice of the K tile) across all 16 rows.
     // For kBlockN=64, only the first 4 warps participate in WMMA score computation.
+    // nvcuda::wmma is NVIDIA-only; HIP/ROCm does not support it.
     namespace wmma = nvcuda::wmma;
     constexpr int kNSub = kBlockN / 16;
     if (warp_id < kNSub) {
...
src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu
-#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API)
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_HYGON_API)
 #include <cuda_runtime.h>
 #include <cstdint>
...
src/infiniop/ops/paged_attention_prefill/operator.cc
@@ -2,7 +2,7 @@
 #include "../../handle.h"
 #include "infiniop/ops/paged_attention_prefill.h"
-#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API)
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_HYGON_API)
 #include "nvidia/paged_attention_prefill_nvidia.cuh"
 #endif
 #ifdef ENABLE_METAX_API
...
@@ -48,6 +48,9 @@ __INFINI_C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
 #ifdef ENABLE_ILUVATAR_API
         CREATE(INFINI_DEVICE_ILUVATAR, nvidia)
 #endif
+#ifdef ENABLE_HYGON_API
+        CREATE(INFINI_DEVICE_HYGON, nvidia)
+#endif
 #ifdef ENABLE_MOORE_API
         CREATE(INFINI_DEVICE_MOORE, moore)
 #endif
...
@@ -78,6 +81,9 @@ __INFINI_C infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
 #ifdef ENABLE_ILUVATAR_API
         GET(INFINI_DEVICE_ILUVATAR, nvidia)
 #endif
+#ifdef ENABLE_HYGON_API
+        GET(INFINI_DEVICE_HYGON, nvidia)
+#endif
 #ifdef ENABLE_MOORE_API
         GET(INFINI_DEVICE_MOORE, moore)
 #endif
...
@@ -115,6 +121,9 @@ __INFINI_C infiniStatus_t infiniopPagedAttentionPrefill(
 #ifdef ENABLE_ILUVATAR_API
         CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia)
 #endif
+#ifdef ENABLE_HYGON_API
+        CALCULATE(INFINI_DEVICE_HYGON, nvidia)
+#endif
 #ifdef ENABLE_MOORE_API
         CALCULATE(INFINI_DEVICE_MOORE, moore)
 #endif
...
@@ -144,6 +153,9 @@ __INFINI_C infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor(
 #ifdef ENABLE_ILUVATAR_API
         DESTROY(INFINI_DEVICE_ILUVATAR, nvidia)
 #endif
+#ifdef ENABLE_HYGON_API
+        DESTROY(INFINI_DEVICE_HYGON, nvidia)
+#endif
 #ifdef ENABLE_MOORE_API
         DESTROY(INFINI_DEVICE_MOORE, moore)
 #endif
...
src/infiniop/ops/paged_caching/nvidia/paged_caching_nvidia.cu
@@ -94,6 +94,21 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
             k_cache_slot_stride,
             v_cache_slot_stride);
     } else if (dtype == INFINI_DTYPE_BF16) {
+        std::cout << "NUM_THREADS: " << NUM_THREADS << std::endl;
+        std::cout << "grid: " << grid.x << ", " << grid.y << ", " << grid.z << std::endl;
+        std::cout << "block: " << block.x << ", " << block.y << ", " << block.z << std::endl;
+        std::cout << "shared_mem_size: " << shared_mem_size << std::endl;
+        std::cout << "slot_mapping: " << slot_mapping << std::endl;
+        std::cout << "head_size: " << head_size << std::endl;
+        std::cout << "block_size: " << block_size << std::endl;
+        std::cout << "k_src_stride: " << k_src_stride << std::endl;
+        std::cout << "v_src_stride: " << v_src_stride << std::endl;
+        std::cout << "k_cache_block_stride: " << k_cache_block_stride << std::endl;
+        std::cout << "v_cache_block_stride: " << v_cache_block_stride << std::endl;
+        std::cout << "k_cache_head_stride: " << k_cache_head_stride << std::endl;
+        std::cout << "v_cache_head_stride: " << v_cache_head_stride << std::endl;
+        std::cout << "k_cache_slot_stride: " << k_cache_slot_stride << std::endl;
+        std::cout << "v_cache_slot_stride: " << v_cache_slot_stride << std::endl;
         pagedCaching<__nv_bfloat16, NUM_THREADS><<<grid, block, shared_mem_size, stream>>>(
             (__nv_bfloat16 *)k_cache,
...
src/infiniop/ops/paged_caching/operator.cc
@@ -2,7 +2,7 @@
 #include "../../handle.h"
 #include "infiniop/ops/paged_caching.h"
-#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API)
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ALI_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_HYGON_API)
 #include "nvidia/paged_caching_nvidia.cuh"
 #endif
 #ifdef ENABLE_METAX_API
...
@@ -41,6 +41,9 @@ __INFINI_C infiniStatus_t infiniopCreatePagedCachingDescriptor(
 #ifdef ENABLE_ILUVATAR_API
         CREATE(INFINI_DEVICE_ILUVATAR, nvidia)
 #endif
+#ifdef ENABLE_HYGON_API
+        CREATE(INFINI_DEVICE_HYGON, nvidia)
+#endif
 #ifdef ENABLE_MOORE_API
         CREATE(INFINI_DEVICE_MOORE, moore)
 #endif
...
@@ -71,6 +74,9 @@ __INFINI_C infiniStatus_t infiniopGetPagedCachingWorkspaceSize(
 #ifdef ENABLE_ILUVATAR_API
         GET(INFINI_DEVICE_ILUVATAR, nvidia)
 #endif
+#ifdef ENABLE_HYGON_API
+        GET(INFINI_DEVICE_HYGON, nvidia)
+#endif
 #ifdef ENABLE_MOORE_API
         GET(INFINI_DEVICE_MOORE, moore)
 #endif
...
@@ -105,6 +111,9 @@ __INFINI_C infiniStatus_t infiniopPagedCaching(
 #ifdef ENABLE_ILUVATAR_API
         CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia)
 #endif
+#ifdef ENABLE_HYGON_API
+        CALCULATE(INFINI_DEVICE_HYGON, nvidia)
+#endif
 #ifdef ENABLE_MOORE_API
         CALCULATE(INFINI_DEVICE_MOORE, moore)
 #endif
...
@@ -134,6 +143,9 @@ __INFINI_C infiniStatus_t infiniopDestroyPagedCachingDescriptor(
 #ifdef ENABLE_ILUVATAR_API
         DESTROY(INFINI_DEVICE_ILUVATAR, nvidia)
 #endif
+#ifdef ENABLE_HYGON_API
+        DESTROY(INFINI_DEVICE_HYGON, nvidia)
+#endif
 #ifdef ENABLE_MOORE_API
         DESTROY(INFINI_DEVICE_MOORE, moore)
 #endif
...
src/infiniop/ops/rope/cuda/kernel.cuh
@@ -103,4 +103,50 @@ __device__ void ropeThreadPerItemBlock(
     }
 }

// grid_dim = dim3(info.seqlen, info.batch, 1);
// dim3 block_dim = dim3(info.table_dim, info.nhead, 1);
template <bool IsGPTJ, typename Tindex, typename Tangle>
__device__ void customropeThreadPerItemBlock(
    cuda_bfloat16 *y_,
    const cuda_bfloat16 *x_,
    const Tindex *__restrict__ pos_ids,
    const Tangle *__restrict__ sin_table,
    const Tangle *__restrict__ cos_table,
    size_t table_dim,
    size_t pos_stride_batch, // Stride for batch dimension in pos_ids (0 if 1D)
    bool pos_has_batch_dim,  // Whether pos_ids has batch dimension
    bool has_batch_dim,      // Whether tensors have batch dimension
    ptrdiff_t y_stride_batch,
    ptrdiff_t y_stride_seqlen,
    ptrdiff_t y_stride_nhead,
    ptrdiff_t x_stride_batch,
    ptrdiff_t x_stride_seqlen,
    ptrdiff_t x_stride_nhead) {
    const size_t batch_idx = blockIdx.y;
    const size_t seq_idx = blockIdx.x;
    const size_t head_idx = threadIdx.y;
    const size_t dim_idx = threadIdx.x;
    auto y_offset = batch_idx * y_stride_batch + seq_idx * y_stride_seqlen + head_idx * y_stride_nhead;
    auto x_offset = batch_idx * x_stride_batch + seq_idx * x_stride_seqlen + head_idx * x_stride_nhead;
    size_t pos_offset = batch_idx * pos_stride_batch + seq_idx;
    size_t pos_id = size_t(pos_ids[pos_offset]);
    auto table_offset = pos_id * table_dim;
    Tangle sin__ = sin_table[table_offset + dim_idx];
    Tangle cos__ = cos_table[table_offset + dim_idx];
    size_t pos0 = dim_idx;
    size_t pos1 = dim_idx + table_dim;
    Tangle x0 = __bfloat162float(x_[x_offset + pos0]);
    Tangle x1 = __bfloat162float(x_[x_offset + pos1]);
    Tangle y0 = x0 * cos__ - x1 * sin__;
    Tangle y1 = x0 * sin__ + x1 * cos__;
    y_[y_offset + pos0] = __float2bfloat16(y0);
    y_[y_offset + pos1] = __float2bfloat16(y1);
}
#endif
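Each thread in customropeThreadPerItemBlock rotates one (x0, x1) pair taken from positions dim_idx and dim_idx + table_dim of a head (the half-split layout, since the caller below instantiates IsGPTJ as false). A minimal host-side C++ sketch of that rotation for one pair, with hypothetical names:

```cpp
#include <cmath>
#include <utility>

// Rotate one half-split RoPE pair: given the angle theta for this (position,
// dim) slot, (x0, x1) -> (x0*cos - x1*sin, x0*sin + x1*cos). The kernel above
// reads cos/sin from precomputed tables indexed by pos_id * table_dim + dim_idx.
inline std::pair<float, float> rope_rotate(float x0, float x1, float theta) {
    const float c = std::cos(theta);
    const float s = std::sin(theta);
    return {x0 * c - x1 * s, x0 * s + x1 * c};
}
```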
src/infiniop/ops/rope/nvidia/rope_nvidia.cu
@@ -33,6 +33,34 @@ INFINIOP_CUDA_KERNEL ropeThreadPerItemKernel(
         x_stride_batch,
         x_stride_seqlen,
         x_stride_nhead);
 }

template <bool IsGPTJ, typename Tindex, typename Tangle>
INFINIOP_CUDA_KERNEL customropeThreadPerItemKernel(
    cuda_bfloat16 *y_,
    const cuda_bfloat16 *x_,
    const Tindex *__restrict__ pos_ids,
    const Tangle *__restrict__ sin_table,
    const Tangle *__restrict__ cos_table,
    size_t table_dim,
    size_t pos_stride_batch, // Stride for batch dimension in pos_ids
    bool pos_has_batch_dim,  // Whether pos_ids has batch dimension
    bool has_batch_dim,      // Whether tensors have batch dimension
    ptrdiff_t y_stride_batch,
    ptrdiff_t y_stride_seqlen,
    ptrdiff_t y_stride_nhead,
    ptrdiff_t x_stride_batch,
    ptrdiff_t x_stride_seqlen,
    ptrdiff_t x_stride_nhead) {
    customropeThreadPerItemBlock<IsGPTJ>(
        y_, x_, pos_ids, sin_table, cos_table, table_dim,
        pos_stride_batch, pos_has_batch_dim, has_batch_dim,
        y_stride_batch, y_stride_seqlen, y_stride_nhead,
        x_stride_batch, x_stride_seqlen, x_stride_nhead);
}

namespace op::rope::nvidia {

struct Descriptor::Opaque {
...
@@ -96,9 +124,15 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
         grid_dim = dim3(dimx, dimy, dimz);
     } else {
         // 3D tensors: use 2D grid [seqlen, nhead], batch dimension is 1
-        grid_dim = dim3(dimx, dimy);
+        grid_dim = dim3(dimx, dimy, 1);
     }
+    // printf("block_size = %d info.table_dim = %ld has_batch_dim: %d, is_gpt_j: %d pos_has_batch_dim: %d\n",
+    //        block_size, info.table_dim, info.has_batch_dim, is_gpt_j, info.pos_has_batch_dim);
+    // [batch, seqlen, nhead, dhead, table_len, table_dim, y_stride_batch, y_stride_seqlen, y_stride_nhead, x_stride_batch, x_stride_seqlen, x_stride_nhead]
+    // printf("[%ld %ld %ld %ld %ld %ld %ld %ld %ld %ld %ld %ld]\n", info.batch,
+    //        info.seqlen, info.nhead, info.dhead, info.table_len, info.table_dim,
+    //        info.y_stride_batch, info.y_stride_seqlen, info.y_stride_nhead,
+    //        info.x_stride_batch, info.x_stride_seqlen, info.x_stride_nhead);
     if (is_gpt_j) {
         ropeThreadPerItemKernel<true><<<grid_dim, nthreads, 0, stream>>>(
             y, x, pos_ids, sin_table, cos_table, info.table_dim,
...
@@ -108,13 +142,35 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
             info.y_stride_batch, info.y_stride_seqlen, info.y_stride_nhead,
             info.x_stride_batch, info.x_stride_seqlen, info.x_stride_nhead);
     } else {
-        ropeThreadPerItemKernel<false><<<grid_dim, nthreads, 0, stream>>>(
-            y, x, pos_ids, sin_table, cos_table, info.table_dim,
-            pos_stride_batch,
-            info.pos_has_batch_dim,
-            info.has_batch_dim,
-            info.y_stride_batch, info.y_stride_seqlen, info.y_stride_nhead,
-            info.x_stride_batch, info.x_stride_seqlen, info.x_stride_nhead);
+        if ((std::is_same<Tdata, cuda_bfloat16>::value) && (info.table_dim == 64) && (info.nhead < 16) && (info.seqlen < 505)) {
+            auto bf16_y = reinterpret_cast<cuda_bfloat16 *>(y);
+            auto bf16_x = reinterpret_cast<const cuda_bfloat16 *>(x);
+            grid_dim = dim3(info.seqlen, info.batch, 1);
+            dim3 block_dim = dim3(info.table_dim, info.nhead, 1);
+            customropeThreadPerItemKernel<false><<<grid_dim, block_dim, 0, stream>>>(
+                bf16_y, bf16_x, pos_ids, sin_table, cos_table, info.table_dim,
+                pos_stride_batch,
+                info.pos_has_batch_dim,
+                info.has_batch_dim,
+                info.y_stride_batch, info.y_stride_seqlen, info.y_stride_nhead,
+                info.x_stride_batch, info.x_stride_seqlen, info.x_stride_nhead);
+            // ropeThreadPerItemKernel<false><<<grid_dim, nthreads, 0, stream>>>(
+            //     y, x, pos_ids, sin_table, cos_table, info.table_dim,
+            //     pos_stride_batch,
+            //     info.pos_has_batch_dim,
+            //     info.has_batch_dim,
+            //     info.y_stride_batch, info.y_stride_seqlen, info.y_stride_nhead,
+            //     info.x_stride_batch, info.x_stride_seqlen, info.x_stride_nhead);
+        } else {
+            ropeThreadPerItemKernel<false><<<grid_dim, nthreads, 0, stream>>>(
+                y, x, pos_ids, sin_table, cos_table, info.table_dim,
+                pos_stride_batch,
+                info.pos_has_batch_dim,
+                info.has_batch_dim,
+                info.y_stride_batch, info.y_stride_seqlen, info.y_stride_nhead,
+                info.x_stride_batch, info.x_stride_seqlen, info.x_stride_nhead);
+        }
     }
     return INFINI_STATUS_SUCCESS;
...
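The fast path above launches one block per (seq, batch) cell with a 2-D block of table_dim x nhead threads, so the guards table_dim == 64 and nhead < 16 keep the block within the 1024-thread limit. A small hedged C++ check of that geometry (illustrative helper, not part of the commit):

```cpp
#include <cstddef>

// A CUDA thread block may hold at most 1024 threads; the custom RoPE path
// uses block = (table_dim, nhead, 1), i.e. table_dim * nhead threads.
constexpr bool rope_block_fits(size_t table_dim, size_t nhead, size_t max_threads = 1024) {
    return table_dim * nhead <= max_threads;
}

static_assert(rope_block_fits(64, 15), "table_dim=64 and nhead<16 give at most 960 threads");
```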