Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
6a59259c
Commit
6a59259c
authored
Apr 25, 2026
by
zhangyue
Browse files
自定义allreduce初版
parent
71cac971
Changes
12
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
1503 additions
and
9 deletions
+1503
-9
include/infiniccl.h
include/infiniccl.h
+32
-0
src/infiniccl/ascend/infiniccl_ascend.cc
src/infiniccl/ascend/infiniccl_ascend.cc
+2
-1
src/infiniccl/cambricon/infiniccl_cambricon.cc
src/infiniccl/cambricon/infiniccl_cambricon.cc
+2
-1
src/infiniccl/cuda/infiniccl_cuda.cu
src/infiniccl/cuda/infiniccl_cuda.cu
+550
-3
src/infiniccl/cuda/infiniccl_custom_all_reduce.cuh
src/infiniccl/cuda/infiniccl_custom_all_reduce.cuh
+667
-0
src/infiniccl/infiniccl.cc
src/infiniccl/infiniccl.cc
+23
-1
src/infiniccl/infiniccl_impl.h
src/infiniccl/infiniccl_impl.h
+14
-0
src/infiniccl/kunlun/infiniccl_kunlun.cc
src/infiniccl/kunlun/infiniccl_kunlun.cc
+2
-1
src/infiniccl/metax/infiniccl_metax.cc
src/infiniccl/metax/infiniccl_metax.cc
+2
-1
src/infiniccl/moore/infiniccl_moore.cc
src/infiniccl/moore/infiniccl_moore.cc
+2
-1
src/infinicore/ops/mha_kvcache/mha_kvcache_hygon_paged.cc
src/infinicore/ops/mha_kvcache/mha_kvcache_hygon_paged.cc
+111
-0
src/infinicore/ops/multi_head_attention_varlen/mha_varlen_hygon_vllm.cc
.../ops/multi_head_attention_varlen/mha_varlen_hygon_vllm.cc
+96
-0
No files found.
include/infiniccl.h
View file @
6a59259c
...
...
@@ -15,6 +15,15 @@ struct InfinicclComm;
typedef
struct
InfinicclComm
*
infinicclComm_t
;
/**
* Initialize NCCL communicators (one per device). On Hygon DCU builds (ENABLE_HYGON_API), when
* device_type is INFINI_DEVICE_HYGON and ndevice is 2/4/6/8, also allocates per-GPU shared buffers
* (vLLM-style staging + Signal + rank_data) and wires infiniccl_ar::CustomAllreduce automatically;
* otherwise custom path stays disabled until infinicclCommSetHygonCustomAllreduce is used.
*
* Hygon switch: INFINICCL_CUSTOM_ALLREDUCE=0 or off disables that wiring; infinicclAllReduce then
* uses NCCL only for the same process (see infinicclAllReduce).
*/
__INFINI_C
__export
infiniStatus_t
infinicclCommInitAll
(
infiniDevice_t
device_type
,
infinicclComm_t
*
comms
,
...
...
@@ -23,6 +32,29 @@ __INFINI_C __export infiniStatus_t infinicclCommInitAll(
__INFINI_C
__export
infiniStatus_t
infinicclCommDestroy
(
infinicclComm_t
comm
);
/**
* Hygon DCU only: attach an optional custom allreduce (opaque infiniccl_ar::CustomAllreduce*).
* Other device types receive INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED.
* When set, infinicclAllReduce may use it for SUM on f32/f16/bf16 payloads up to 8192 * 1024 bytes
* (8 MiB); larger or unsupported cases use NCCL.
*
* If reg_buffer is non-null, sendbuf is copied to reg_buffer on the same stream before the custom
* kernel (vLLM-style): fixed IPC-registered buffer for CUDA graph or unregistered sendbuf.
* reg_buffer_bytes must be >= payload when reg_buffer is used. Pass custom_allreduce == nullptr to clear.
* Do not call this after commInitAll has already auto-wired Hygon custom allreduce (returns BAD_PARAM).
*/
__INFINI_C
__export
infiniStatus_t
infinicclCommSetHygonCustomAllreduce
(
infinicclComm_t
comm
,
void
*
custom_allreduce
,
void
*
reg_buffer
,
size_t
reg_buffer_bytes
);
/**
* Hygon: optional custom allreduce for small SUM payloads (see comm init / setHygon).
* Runtime switch: INFINICCL_CUSTOM_ALLREDUCE=0 or off forces NCCL even if custom objects were initialized.
* Diagnostics to stderr: INFINICCL_CUSTOM_ALLREDUCE_DEBUG=1 (coarse path hints);
* INFINICCL_CUSTOM_ALLREDUCE_TRACE=1 (first 128 custom kernel invocations and up to 48 NCCL fallbacks after try_custom, per OS process).
*/
__INFINI_C
__export
infiniStatus_t
infinicclAllReduce
(
void
*
sendbuf
,
void
*
recvbuf
,
...
...
src/infiniccl/ascend/infiniccl_ascend.cc
View file @
6a59259c
...
...
@@ -55,6 +55,7 @@ inline HcclReduceOp getHcclRedOp(infinicclReduceOp_t op) {
namespace
infiniccl
::
ascend
{
infiniStatus_t
commInitAll
(
infiniDevice_t
device_type
,
infinicclComm_t
*
comms
,
int
ndevice
,
const
int
*
device_ids
)
{
...
...
@@ -67,7 +68,7 @@ infiniStatus_t commInitAll(
CHECK_HCCL
(
HcclCommInitAll
(
ndevice
,
(
int32_t
*
)
device_ids
,
hccl_comms
.
data
()));
for
(
int
i
=
0
;
i
<
ndevice
;
i
++
)
{
comms
[
i
]
=
new
InfinicclComm
{
INFINI_DEVICE_ASCEND
,
device_ids
[
i
],
(
void
*
)(
hccl_comms
[
i
])};
comms
[
i
]
=
new
InfinicclComm
{
device_type
,
device_ids
[
i
],
(
void
*
)(
hccl_comms
[
i
])
,
nullptr
,
nullptr
,
0
,
nullptr
,
false
};
}
return
INFINI_STATUS_SUCCESS
;
...
...
src/infiniccl/cambricon/infiniccl_cambricon.cc
View file @
6a59259c
...
...
@@ -53,6 +53,7 @@ inline cnclReduceOp_t getCnclRedOp(infinicclReduceOp_t op) {
namespace
infiniccl
::
cambricon
{
infiniStatus_t
commInitAll
(
infiniDevice_t
device_type
,
infinicclComm_t
*
comms
,
int
ndevice
,
const
int
*
device_ids
)
{
...
...
@@ -70,7 +71,7 @@ infiniStatus_t commInitAll(
ndevice
,
nullptr
));
for
(
int
i
=
0
;
i
<
ndevice
;
i
++
)
{
comms
[
i
]
=
new
InfinicclComm
{
INFINI_DEVICE_CAMBRICON
,
device_ids
[
i
],
(
void
*
)(
cncl_comms
[
i
])};
comms
[
i
]
=
new
InfinicclComm
{
device_type
,
device_ids
[
i
],
(
void
*
)(
cncl_comms
[
i
])
,
nullptr
,
nullptr
,
0
,
nullptr
,
false
};
}
return
INFINI_STATUS_SUCCESS
;
...
...
src/infiniccl/cuda/infiniccl_cuda.cu
View file @
6a59259c
This diff is collapsed.
Click to expand it.
src/infiniccl/cuda/infiniccl_custom_all_reduce.cuh
0 → 100644
View file @
6a59259c
This diff is collapsed.
Click to expand it.
src/infiniccl/infiniccl.cc
View file @
6a59259c
...
...
@@ -7,6 +7,11 @@
#include "./metax/infiniccl_metax.h"
#include "./moore/infiniccl_moore.h"
namespace
infiniccl
::
cuda
{
infiniStatus_t
commSetHygonCustomAllreduce
(
infinicclComm_t
comm
,
void
*
custom_allreduce
,
void
*
reg_buffer
,
size_t
reg_buffer_bytes
);
}
__INFINI_C
infiniStatus_t
infinicclCommInitAll
(
infiniDevice_t
device_type
,
infinicclComm_t
*
comms
,
...
...
@@ -15,7 +20,7 @@ __INFINI_C infiniStatus_t infinicclCommInitAll(
#define COMM_INIT_ALL(CASE_, NAMESPACE_) \
case CASE_: \
return infiniccl::NAMESPACE_::commInitAll(comms, ndevice, device_ids)
return infiniccl::NAMESPACE_::commInitAll(
device_type,
comms, ndevice, device_ids)
switch
(
device_type
)
{
COMM_INIT_ALL
(
INFINI_DEVICE_NVIDIA
,
cuda
);
...
...
@@ -61,6 +66,23 @@ __INFINI_C infiniStatus_t infinicclCommDestroy(infinicclComm_t comm) {
#undef COMM_DESTROY
}
__INFINI_C
infiniStatus_t
infinicclCommSetHygonCustomAllreduce
(
infinicclComm_t
comm
,
void
*
custom_allreduce
,
void
*
reg_buffer
,
size_t
reg_buffer_bytes
)
{
if
(
comm
==
nullptr
)
{
return
INFINI_STATUS_NULL_POINTER
;
}
if
(
comm
->
device_type
!=
INFINI_DEVICE_HYGON
)
{
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
}
return
infiniccl
::
cuda
::
commSetHygonCustomAllreduce
(
comm
,
custom_allreduce
,
reg_buffer
,
reg_buffer_bytes
);
}
__INFINI_C
infiniStatus_t
infinicclAllReduce
(
void
*
sendbuf
,
void
*
recvbuf
,
...
...
src/infiniccl/infiniccl_impl.h
View file @
6a59259c
...
...
@@ -3,15 +3,29 @@
#include "infiniccl.h"
#include <cstddef>
struct
InfinicclComm
{
infiniDevice_t
device_type
;
int
device_id
;
// the actual device ID, not rank number
void
*
comm
;
// the actual communicator
/** Optional infiniccl_ar::CustomAllreduce* (Hygon DCU build only); nullptr disables hybrid path. */
void
*
custom_ar
;
/** Optional staging buffer: sendbuf is copied here before custom AR (graph / unregistered send). */
void
*
custom_ar_reg_buf
;
size_t
custom_ar_reg_sz
;
/**
* Hygon: when commInitAll auto-wires custom allreduce, all ranks share this group for teardown order
* (last destroy frees cudaMalloc bases). Opaque HygonArGroup* in cuda .cu.
*/
void
*
hygon_ar_group
;
bool
hygon_custom_owned
;
};
#define INFINICCL_DEVICE_API(NAMSPACE, IMPL) \
namespace infiniccl::NAMSPACE { \
infiniStatus_t commInitAll( \
infiniDevice_t device_type, \
infinicclComm_t *comms, \
int ndevice, \
const int *device_ids) IMPL; \
...
...
src/infiniccl/kunlun/infiniccl_kunlun.cc
View file @
6a59259c
...
...
@@ -57,6 +57,7 @@ inline BKCLOp getBkclRedOp(infinicclReduceOp_t op) {
namespace
infiniccl
::
kunlun
{
infiniStatus_t
commInitAll
(
infiniDevice_t
device_type
,
infinicclComm_t
*
comms
,
int
ndevice
,
const
int
*
device_ids
)
{
...
...
@@ -64,7 +65,7 @@ infiniStatus_t commInitAll(
CHECK_BKCL
(
bkcl_comm_init_all
(
bkcl_comms
.
data
(),
ndevice
,
device_ids
));
for
(
int
i
=
0
;
i
<
ndevice
;
i
++
)
{
comms
[
i
]
=
new
InfinicclComm
{
INFINI_DEVICE_KUNLUN
,
device_ids
[
i
],
(
void
*
)(
bkcl_comms
[
i
])};
comms
[
i
]
=
new
InfinicclComm
{
device_type
,
device_ids
[
i
],
(
void
*
)(
bkcl_comms
[
i
])
,
nullptr
,
nullptr
,
0
,
nullptr
,
false
};
}
return
INFINI_STATUS_SUCCESS
;
...
...
src/infiniccl/metax/infiniccl_metax.cc
View file @
6a59259c
...
...
@@ -61,6 +61,7 @@ inline hcclComm_t getHcclComm(infinicclComm_t comm) {
namespace
infiniccl
::
metax
{
infiniStatus_t
commInitAll
(
infiniDevice_t
device_type
,
infinicclComm_t
*
comms
,
int
ndevice
,
const
int
*
device_ids
)
{
...
...
@@ -69,7 +70,7 @@ infiniStatus_t commInitAll(
CHECK_HCCL
(
hcclCommInitAll
(
hccl_comms
.
data
(),
ndevice
,
(
int
const
*
)
device_ids
));
for
(
int
i
=
0
;
i
<
ndevice
;
i
++
)
{
comms
[
i
]
=
new
InfinicclComm
{
INFINI_DEVICE_METAX
,
device_ids
[
i
],
(
void
*
)(
hccl_comms
[
i
])};
comms
[
i
]
=
new
InfinicclComm
{
device_type
,
device_ids
[
i
],
(
void
*
)(
hccl_comms
[
i
])
,
nullptr
,
nullptr
,
0
,
nullptr
,
false
};
}
return
INFINI_STATUS_SUCCESS
;
...
...
src/infiniccl/moore/infiniccl_moore.cc
View file @
6a59259c
...
...
@@ -60,6 +60,7 @@ inline mcclComm_t getMcclComm(infinicclComm_t comm) {
namespace
infiniccl
::
moore
{
infiniStatus_t
commInitAll
(
infiniDevice_t
device_type
,
infinicclComm_t
*
comms
,
int
ndevice
,
const
int
*
device_ids
)
{
...
...
@@ -68,7 +69,7 @@ infiniStatus_t commInitAll(
CHECK_MCCL
(
mcclCommInitAll
(
mccl_comms
.
data
(),
ndevice
,
(
int
const
*
)
device_ids
));
for
(
int
i
=
0
;
i
<
ndevice
;
i
++
)
{
comms
[
i
]
=
new
InfinicclComm
{
INFINI_DEVICE_MOORE
,
device_ids
[
i
],
(
void
*
)(
mccl_comms
[
i
])};
comms
[
i
]
=
new
InfinicclComm
{
device_type
,
device_ids
[
i
],
(
void
*
)(
mccl_comms
[
i
])
,
nullptr
,
nullptr
,
0
,
nullptr
,
false
};
}
return
INFINI_STATUS_SUCCESS
;
...
...
src/infinicore/ops/mha_kvcache/mha_kvcache_hygon_paged.cc
0 → 100644
View file @
6a59259c
// Hygon DCU decode attention backend.
//
// Overrides the ALLDEVICE flashattn registration so that Hygon uses the
// correct HIP stream guard (TorchStreamGuard) and calls mha_fwd_kvcache.
#if defined(ENABLE_FLASH_ATTN) && defined(ENABLE_HYGON_API) && !defined(ENABLE_NVIDIA_API)
#include "infinicore/ops/mha_kvcache.hpp"
#include "infinicore/adaptor/flash_attention_adaptor.hpp"
#include <stdexcept>
namespace
infinicore
::
op
::
mha_kvcache_impl
::
hygon_paged
{
// Everything captured at plan() time that run() needs to launch the kernel.
struct PlannedMeta {
    // Graph-tensor handles; run() resolves each of these via to_aten_tensor().
    graph::GraphTensor out;
    graph::GraphTensor q;
    graph::GraphTensor k_cache;
    graph::GraphTensor v_cache;
    graph::GraphTensor seqlens_k;
    graph::GraphTensor block_table;
    // Optional ALiBi slopes; absent when the model does not use ALiBi.
    std::optional<graph::GraphTensor> alibi_slopes;
    // Softmax scale forwarded verbatim to mha_fwd_kvcache.
    float scale;
};
/// Capture all kernel arguments as graph tensors for later replay in run().
/// Returns an owning PlannedMeta*; released by cleanup().
void *plan(Tensor out, const Tensor &q, const Tensor &k_cache, const Tensor &v_cache,
           const Tensor &seqlens_k, const Tensor &block_table,
           std::optional<Tensor> alibi_slopes, float scale) {
    // Wrap the optional ALiBi slopes first so the aggregate init stays readable.
    std::optional<graph::GraphTensor> slopes;
    if (alibi_slopes) {
        slopes.emplace(*alibi_slopes);
    }
    return new PlannedMeta{
        graph::GraphTensor(out),
        graph::GraphTensor(q),
        graph::GraphTensor(k_cache),
        graph::GraphTensor(v_cache),
        graph::GraphTensor(seqlens_k),
        graph::GraphTensor(block_table),
        slopes,
        scale,
    };
}
void
run
(
void
*
planned_meta
)
{
infinicore
::
adaptor
::
TorchStreamGuard
guard
(
infinicore
::
adaptor
::
get_cuda_stream
());
auto
*
p
=
reinterpret_cast
<
PlannedMeta
*>
(
planned_meta
);
auto
out_tensor
=
infinicore
::
adaptor
::
to_aten_tensor
(
p
->
out
);
auto
q
=
infinicore
::
adaptor
::
to_aten_tensor
(
p
->
q
);
auto
k_cache
=
infinicore
::
adaptor
::
to_aten_tensor
(
p
->
k_cache
);
auto
v_cache
=
infinicore
::
adaptor
::
to_aten_tensor
(
p
->
v_cache
);
auto
seqlens_k
=
std
::
optional
<
const
at
::
Tensor
>
(
infinicore
::
adaptor
::
to_aten_tensor
(
p
->
seqlens_k
));
auto
block_table
=
std
::
optional
<
at
::
Tensor
>
(
infinicore
::
adaptor
::
to_aten_tensor
(
p
->
block_table
));
auto
alibi_slopes
=
p
->
alibi_slopes
?
std
::
optional
<
at
::
Tensor
>
(
infinicore
::
adaptor
::
to_aten_tensor
(
*
p
->
alibi_slopes
))
:
std
::
nullopt
;
std
::
optional
<
const
at
::
Tensor
>
k_new
=
std
::
nullopt
;
std
::
optional
<
const
at
::
Tensor
>
v_new
=
std
::
nullopt
;
std
::
optional
<
const
at
::
Tensor
>
rotary_cos
=
std
::
nullopt
;
std
::
optional
<
const
at
::
Tensor
>
rotary_sin
=
std
::
nullopt
;
std
::
optional
<
const
at
::
Tensor
>
cache_batch_idx
=
std
::
nullopt
;
std
::
optional
<
const
at
::
Tensor
>
leftpad_k
=
std
::
nullopt
;
const
bool
use_dynamic_out
=
q
.
dim
()
==
4
&&
k_cache
.
dim
()
==
4
&&
q
.
size
(
1
)
==
1
&&
q
.
size
(
2
)
>
k_cache
.
size
(
2
)
&&
q
.
size
(
3
)
%
8
==
0
&&
!
alibi_slopes
.
has_value
();
auto
out
=
use_dynamic_out
?
std
::
optional
<
at
::
Tensor
>
(
std
::
nullopt
)
:
std
::
optional
<
at
::
Tensor
>
(
out_tensor
);
auto
result
=
flash
::
mha_fwd_kvcache
(
q
,
k_cache
,
v_cache
,
k_new
,
v_new
,
seqlens_k
,
rotary_cos
,
rotary_sin
,
cache_batch_idx
,
leftpad_k
,
block_table
,
alibi_slopes
,
out
,
p
->
scale
,
true
,
-
1
,
-
1
,
0.0
f
,
false
,
0
);
if
(
use_dynamic_out
)
{
out_tensor
.
copy_
(
result
[
0
]);
}
}
/// Destroy the PlannedMeta created by plan() and null out the caller's slot.
void cleanup(void **planned_meta_ptr) {
    auto **slot = reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
    delete *slot;
    *slot = nullptr;
}
// Hygon-only registration; the trailing `true` overrides the ALLDEVICE
// flashattn registration for this device.
static bool registered = []() {
    constexpr auto kDevice = Device::Type::HYGON;
    MhaKVCache::plan_dispatcher().registerDevice(kDevice, &plan, true);
    MhaKVCache::run_dispatcher().registerDevice(kDevice, &run, true);
    MhaKVCache::cleanup_dispatcher().registerDevice(kDevice, &cleanup, true);
    return true;
}();
}
// namespace infinicore::op::mha_kvcache_impl::hygon_paged
#endif // ENABLE_FLASH_ATTN && ENABLE_HYGON_API && !ENABLE_NVIDIA_API
src/infinicore/ops/multi_head_attention_varlen/mha_varlen_hygon_vllm.cc
0 → 100644
View file @
6a59259c
// Hygon DCU prefill attention backend.
//
// Overrides the ALLDEVICE flashattn registration so that Hygon uses the
// correct HIP stream guard (TorchStreamGuard) and calls mha_varlen_fwd.
#if defined(ENABLE_FLASH_ATTN) && defined(ENABLE_HYGON_API) && !defined(ENABLE_NVIDIA_API)
#include "infinicore/ops/mha_varlen.hpp"
#include "infinicore/adaptor/flash_attention_adaptor.hpp"
#include <stdexcept>
namespace
infinicore
::
op
::
mha_varlen_impl
::
hygon_vllm
{
// Arguments captured at plan() time for replay by run().
struct PlannedMeta {
    // Graph-tensor handles; run() resolves each via to_aten_tensor().
    graph::GraphTensor out;
    graph::GraphTensor q;
    graph::GraphTensor k;
    graph::GraphTensor v;
    graph::GraphTensor cum_seqlens_q;
    graph::GraphTensor cum_seqlens_k;
    graph::GraphTensor block_table;
    // Varlen batch limits forwarded to mha_varlen_fwd.
    int max_seqlen_q;
    int max_seqlen_k;
    // Optional ALiBi slopes; absent when the model does not use ALiBi.
    std::optional<graph::GraphTensor> alibi_slopes;
    // Softmax scale forwarded verbatim to mha_varlen_fwd.
    float scale;
};
/// Capture all kernel arguments as graph tensors for later replay in run().
/// Returns an owning PlannedMeta*; released by cleanup().
void *plan(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v,
           const Tensor &cum_seqlens_q, const Tensor &cum_seqlens_k,
           const Tensor &block_table, int max_seqlen_q, int max_seqlen_k,
           std::optional<Tensor> alibi_slopes, float scale) {
    // Wrap the optional ALiBi slopes first so the aggregate init stays readable.
    std::optional<graph::GraphTensor> slopes;
    if (alibi_slopes) {
        slopes.emplace(*alibi_slopes);
    }
    return new PlannedMeta{
        graph::GraphTensor(out),
        graph::GraphTensor(q),
        graph::GraphTensor(k),
        graph::GraphTensor(v),
        graph::GraphTensor(cum_seqlens_q),
        graph::GraphTensor(cum_seqlens_k),
        graph::GraphTensor(block_table),
        max_seqlen_q,
        max_seqlen_k,
        slopes,
        scale,
    };
}
void
run
(
void
*
planned_meta
)
{
infinicore
::
adaptor
::
TorchStreamGuard
guard
(
infinicore
::
adaptor
::
get_cuda_stream
());
auto
*
p
=
reinterpret_cast
<
PlannedMeta
*>
(
planned_meta
);
auto
q
=
infinicore
::
adaptor
::
to_aten_tensor
(
p
->
q
);
auto
k
=
infinicore
::
adaptor
::
to_aten_tensor
(
p
->
k
);
auto
v
=
infinicore
::
adaptor
::
to_aten_tensor
(
p
->
v
);
auto
out
=
std
::
optional
<
at
::
Tensor
>
(
infinicore
::
adaptor
::
to_aten_tensor
(
p
->
out
));
auto
cu_seqlens_q
=
infinicore
::
adaptor
::
to_aten_tensor
(
p
->
cum_seqlens_q
);
auto
cu_seqlens_kv
=
infinicore
::
adaptor
::
to_aten_tensor
(
p
->
cum_seqlens_k
);
auto
block_table
=
std
::
optional
<
at
::
Tensor
>
(
infinicore
::
adaptor
::
to_aten_tensor
(
p
->
block_table
));
auto
device
=
q
.
device
();
if
(
!
cu_seqlens_q
.
is_cuda
())
cu_seqlens_q
=
cu_seqlens_q
.
to
(
device
);
if
(
!
cu_seqlens_kv
.
is_cuda
())
cu_seqlens_kv
=
cu_seqlens_kv
.
to
(
device
);
if
(
block_table
.
has_value
()
&&
!
block_table
->
is_cuda
())
block_table
=
block_table
->
to
(
device
);
std
::
optional
<
at
::
Tensor
>
seqused_k
=
std
::
nullopt
;
std
::
optional
<
const
at
::
Tensor
>
leftpad_k
=
std
::
nullopt
;
auto
alibi_slopes
=
p
->
alibi_slopes
?
std
::
optional
<
at
::
Tensor
>
(
infinicore
::
adaptor
::
to_aten_tensor
(
*
p
->
alibi_slopes
))
:
std
::
nullopt
;
flash
::
mha_varlen_fwd
(
q
,
k
,
v
,
out
,
cu_seqlens_q
,
cu_seqlens_kv
,
seqused_k
,
leftpad_k
,
block_table
,
alibi_slopes
,
p
->
max_seqlen_q
,
p
->
max_seqlen_k
,
0.0
f
,
p
->
scale
,
false
,
true
,
-
1
,
-
1
,
0.0
f
,
false
,
std
::
nullopt
);
}
/// Destroy the PlannedMeta created by plan() and null out the caller's slot.
void cleanup(void **planned_meta_ptr) {
    auto **slot = reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
    delete *slot;
    *slot = nullptr;
}
// Hygon-only registration; the trailing `true` overrides the ALLDEVICE
// flashattn registration for this device.
static bool registered = []() {
    constexpr auto kDevice = Device::Type::HYGON;
    MultiheadAttentionVarlen::plan_dispatcher().registerDevice(kDevice, &plan, true);
    MultiheadAttentionVarlen::run_dispatcher().registerDevice(kDevice, &run, true);
    MultiheadAttentionVarlen::cleanup_dispatcher().registerDevice(kDevice, &cleanup, true);
    return true;
}();
}
// namespace infinicore::op::mha_varlen_impl::hygon_vllm
#endif // ENABLE_FLASH_ATTN && ENABLE_HYGON_API && !ENABLE_NVIDIA_API
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment