Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
784139b9
Unverified
Commit
784139b9
authored
Feb 13, 2026
by
thatPepe
Committed by
GitHub
Feb 13, 2026
Browse files
Merge pull request #990 from InfiniTensor/demo131
Demo-131 Cuda graph with optimized paged attention
parents
3c8fb3c0
1d6527cb
Changes
582
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
186 additions
and
47 deletions
+186
-47
.gitmodules
.gitmodules
+4
-0
README.md
README.md
+6
-4
include/infinicore.h
include/infinicore.h
+1
-0
include/infinicore.hpp
include/infinicore.hpp
+1
-0
include/infinicore/device.hpp
include/infinicore/device.hpp
+1
-0
include/infinicore/graph/graph.hpp
include/infinicore/graph/graph.hpp
+14
-9
include/infinicore/nn/linear.hpp
include/infinicore/nn/linear.hpp
+22
-0
include/infinicore/nn/rmsnorm.hpp
include/infinicore/nn/rmsnorm.hpp
+19
-4
include/infinicore/ops.hpp
include/infinicore/ops.hpp
+4
-0
include/infinicore/ops/add.hpp
include/infinicore/ops/add.hpp
+6
-9
include/infinicore/ops/add_rms_norm.hpp
include/infinicore/ops/add_rms_norm.hpp
+6
-8
include/infinicore/ops/causal_softmax.hpp
include/infinicore/ops/causal_softmax.hpp
+6
-8
include/infinicore/ops/dequantize_awq.hpp
include/infinicore/ops/dequantize_awq.hpp
+10
-0
include/infinicore/ops/distributed/allreduce.hpp
include/infinicore/ops/distributed/allreduce.hpp
+24
-0
include/infinicore/ops/embedding.hpp
include/infinicore/ops/embedding.hpp
+6
-2
include/infinicore/ops/flash_attention.hpp
include/infinicore/ops/flash_attention.hpp
+12
-0
include/infinicore/ops/gemm.hpp
include/infinicore/ops/gemm.hpp
+3
-3
include/infinicore/ops/kv_caching.hpp
include/infinicore/ops/kv_caching.hpp
+16
-0
include/infinicore/ops/linear_w4a16_awq.hpp
include/infinicore/ops/linear_w4a16_awq.hpp
+12
-0
include/infinicore/ops/linear_w8a8i8.hpp
include/infinicore/ops/linear_w8a8i8.hpp
+13
-0
No files found.
.gitmodules
View file @
784139b9
[submodule "third_party/spdlog"]
[submodule "third_party/spdlog"]
path = third_party/spdlog
path = third_party/spdlog
url = https://github.com/gabime/spdlog.git
url = https://github.com/gabime/spdlog.git
[submodule "third_party/nlohmann_json"]
path = third_party/nlohmann_json
url = https://github.com/nlohmann/json.git
branch = master
README.md
View file @
784139b9
...
@@ -20,6 +20,7 @@ InfiniCore 是一个跨平台统一编程工具集,为不同芯片平台的功
...
@@ -20,6 +20,7 @@ InfiniCore 是一个跨平台统一编程工具集,为不同芯片平台的功
-
天数智芯 GPU;
-
天数智芯 GPU;
-
沐曦 GPU;
-
沐曦 GPU;
-
海光 DCU;
-
海光 DCU;
-
阿里 PPU;
-
华为昇腾 NPU;
-
华为昇腾 NPU;
-
寒武纪 MLU;
-
寒武纪 MLU;
-
昆仑芯 XPU;
-
昆仑芯 XPU;
...
@@ -103,6 +104,7 @@ python scripts/install.py [XMAKE_CONFIG_FLAGS]
...
@@ -103,6 +104,7 @@ python scripts/install.py [XMAKE_CONFIG_FLAGS]
|
`--qy-gpu=[y\|n]`
| 是否编译QY GPU 接口实现 | n
|
`--qy-gpu=[y\|n]`
| 是否编译QY GPU 接口实现 | n
|
`--hygon-dcu=[y\|n]`
| 是否编译海光 DCU 接口实现 | n
|
`--hygon-dcu=[y\|n]`
| 是否编译海光 DCU 接口实现 | n
|
`--kunlun-xpu=[y\|n]`
| 是否编译昆仑 XPU 接口实现 | n
|
`--kunlun-xpu=[y\|n]`
| 是否编译昆仑 XPU 接口实现 | n
|
`--ali-ppu=[y\|n]`
| 是否编译阿里 PPU 接口实现 | n
|
`--ninetoothed=[y\|n]`
| 是否编译九齿实现 | n
|
`--ninetoothed=[y\|n]`
| 是否编译九齿实现 | n
|
`--ccl=[y\|n]`
| 是否编译 InfiniCCL 通信库接口实现 | n
|
`--ccl=[y\|n]`
| 是否编译 InfiniCCL 通信库接口实现 | n
...
@@ -187,9 +189,9 @@ pip install -e .
...
@@ -187,9 +189,9 @@ pip install -e .
```
bash
```
bash
# 测试单算子
# 测试单算子
python
test
/infinicore/ops/[operator].py
[
--bench
|
--debug
|
--verbose
]
[
--cpu
|
--nvidia
|
--cambricon
|
--ascend
|
--iluvatar
|
--metax
|
--moore
|
--kunlun
|
--Hygon
]
python
test
/infinicore/ops/[operator].py
[
--bench
|
--debug
|
--verbose
]
[
--cpu
|
--nvidia
|
--cambricon
|
--ascend
|
--iluvatar
|
--metax
|
--moore
|
--kunlun
|
--Hygon
|
--ali
]
# 测试全部算子
# 测试全部算子
python
test
/infinicore/run.py
[
--bench
|
--debug
|
--verbose
]
[
--cpu
|
--nvidia
|
--cambricon
|
--ascend
|
--iluvatar
|
--metax
|
--moore
|
--kunlun
]
python
test
/infinicore/run.py
[
--bench
|
--debug
|
--verbose
]
[
--cpu
|
--nvidia
|
--cambricon
|
--ascend
|
--iluvatar
|
--metax
|
--moore
|
--kunlun
|
--ali
]
```
```
使用 -h 查看更多参数。
使用 -h 查看更多参数。
...
@@ -198,9 +200,9 @@ python test/infinicore/run.py [--bench | --debug | --verbose] [--cpu | --nvidia
...
@@ -198,9 +200,9 @@ python test/infinicore/run.py [--bench | --debug | --verbose] [--cpu | --nvidia
```
shell
```
shell
# 测试单算子
# 测试单算子
python
test
/infiniop/[operator].py
[
--cpu
|
--nvidia
|
--cambricon
|
--ascend
|
--iluvatar
|
--metax
|
--moore
|
--kunlun
|
--Hygon
]
python
test
/infiniop/[operator].py
[
--cpu
|
--nvidia
|
--cambricon
|
--ascend
|
--iluvatar
|
--metax
|
--moore
|
--kunlun
|
--Hygon
|
--ali
]
# 测试全部算子
# 测试全部算子
python scripts/python_test.py
[
--cpu
|
--nvidia
|
--cambricon
|
--ascend
|
--iluvatar
|
--metax
|
--moore
|
--kunlun
|
--Hygon
]
python scripts/python_test.py
[
--cpu
|
--nvidia
|
--cambricon
|
--ascend
|
--iluvatar
|
--metax
|
--moore
|
--kunlun
|
--Hygon
|
--ali
]
```
```
#### 通信库(InfiniCCL)测试
#### 通信库(InfiniCCL)测试
...
...
include/infinicore.h
View file @
784139b9
...
@@ -47,6 +47,7 @@ typedef enum {
...
@@ -47,6 +47,7 @@ typedef enum {
INFINI_DEVICE_KUNLUN
=
7
,
INFINI_DEVICE_KUNLUN
=
7
,
INFINI_DEVICE_HYGON
=
8
,
INFINI_DEVICE_HYGON
=
8
,
INFINI_DEVICE_QY
=
9
,
INFINI_DEVICE_QY
=
9
,
INFINI_DEVICE_ALI
=
10
,
INFINI_DEVICE_TYPE_COUNT
INFINI_DEVICE_TYPE_COUNT
}
infiniDevice_t
;
}
infiniDevice_t
;
...
...
include/infinicore.hpp
View file @
784139b9
...
@@ -3,4 +3,5 @@
...
@@ -3,4 +3,5 @@
#include "infinicore/device_event.hpp"
#include "infinicore/device_event.hpp"
#include "infinicore/nn.hpp"
#include "infinicore/nn.hpp"
#include "infinicore/ops.hpp"
#include "infinicore/ops.hpp"
#include "infinicore/quantization.hpp"
#include "infinicore/tensor.hpp"
#include "infinicore/tensor.hpp"
include/infinicore/device.hpp
View file @
784139b9
...
@@ -22,6 +22,7 @@ public:
...
@@ -22,6 +22,7 @@ public:
KUNLUN
=
INFINI_DEVICE_KUNLUN
,
KUNLUN
=
INFINI_DEVICE_KUNLUN
,
HYGON
=
INFINI_DEVICE_HYGON
,
HYGON
=
INFINI_DEVICE_HYGON
,
QY
=
INFINI_DEVICE_QY
,
QY
=
INFINI_DEVICE_QY
,
ALI
=
INFINI_DEVICE_ALI
,
COUNT
=
INFINI_DEVICE_TYPE_COUNT
,
COUNT
=
INFINI_DEVICE_TYPE_COUNT
,
};
};
...
...
include/infinicore/graph/graph.hpp
View file @
784139b9
...
@@ -15,10 +15,15 @@ public:
...
@@ -15,10 +15,15 @@ public:
};
};
class
GraphOperator
{
class
GraphOperator
{
public:
virtual
void
run
()
const
=
0
;
virtual
~
GraphOperator
()
=
default
;
};
class
DispatchableGraphOperator
:
public
GraphOperator
{
public:
public:
void
run
()
const
;
void
run
()
const
override
;
~
GraphOperator
();
~
Dispatchable
GraphOperator
()
override
;
protected:
protected:
using
run_schema
=
void
(
*
)(
void
*
);
using
run_schema
=
void
(
*
)(
void
*
);
...
@@ -49,7 +54,7 @@ private:
...
@@ -49,7 +54,7 @@ private:
}
// namespace infinicore::graph
}
// namespace infinicore::graph
#define INFINICORE_GRAPH_OP_CLASS(__OP_NAME__, ...) \
#define INFINICORE_GRAPH_OP_CLASS(__OP_NAME__, ...) \
class __OP_NAME__ : public graph::GraphOperator {
\
class __OP_NAME__ : public graph::
Dispatchable
GraphOperator { \
public: \
public: \
using schema = void (*)(__VA_ARGS__); \
using schema = void (*)(__VA_ARGS__); \
using plan_schema = void *(*)(__VA_ARGS__); \
using plan_schema = void *(*)(__VA_ARGS__); \
...
@@ -79,12 +84,12 @@ private:
...
@@ -79,12 +84,12 @@ private:
runner_ = run_dispatcher().lookup(__DEVICE_TYPE__); \
runner_ = run_dispatcher().lookup(__DEVICE_TYPE__); \
deleter_ = cleanup_dispatcher().lookup(__DEVICE_TYPE__);
deleter_ = cleanup_dispatcher().lookup(__DEVICE_TYPE__);
#define INFINICORE_GRAPH_OP_RECORD_OR_RUN(__OP_NAME__, ...) \
#define INFINICORE_GRAPH_OP_RECORD_OR_RUN(__OP_NAME__, ...)
\
auto op = std::make_shared<__OP_NAME__>(__VA_ARGS__);
\
auto
___
op = std::make_shared<__OP_NAME__>(__VA_ARGS__); \
if (context::isGraphRecording()) { \
if (context::isGraphRecording()) {
\
context::addGraphOperator(op);
\
context::addGraphOperator(
___
op); \
} else { \
} else {
\
op->run();
\
___
op->run(); \
}
}
#define INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(__OP_NAME__, __PLAN_F__, __RUN_F__, __CLEANUP_F__) \
#define INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(__OP_NAME__, __PLAN_F__, __RUN_F__, __CLEANUP_F__) \
...
...
include/infinicore/nn/linear.hpp
View file @
784139b9
#pragma once
#pragma once
#include "../ops.hpp"
#include "../ops.hpp"
#include "../quantization.hpp"
#include "module.hpp"
#include "module.hpp"
#include <infiniccl.h>
#include <infiniccl.h>
#include <optional>
namespace
infinicore
::
nn
{
namespace
infinicore
::
nn
{
...
@@ -11,6 +13,9 @@ public:
...
@@ -11,6 +13,9 @@ public:
BaseLinear
(
size_t
in_features
,
size_t
out_features
,
bool
bias
=
true
,
BaseLinear
(
size_t
in_features
,
size_t
out_features
,
bool
bias
=
true
,
const
DataType
&
dtype
=
DataType
::
F32
,
const
Device
&
device
=
Device
());
const
DataType
&
dtype
=
DataType
::
F32
,
const
Device
&
device
=
Device
());
BaseLinear
(
size_t
in_features
,
size_t
out_features
,
std
::
shared_ptr
<
infinicore
::
quantization
::
BaseQuantization
>
quantization
,
bool
bias
=
true
,
const
DataType
&
dtype
=
DataType
::
F32
,
const
Device
&
device
=
Device
());
// Forward pass: output = input @ weight.T + bias
// Forward pass: output = input @ weight.T + bias
Tensor
forward
(
Tensor
&
input
)
const
;
Tensor
forward
(
Tensor
&
input
)
const
;
...
@@ -27,12 +32,17 @@ public:
...
@@ -27,12 +32,17 @@ public:
// Accessors for parameters
// Accessors for parameters
Tensor
weight
()
const
{
return
weight_
;
}
Tensor
weight
()
const
{
return
weight_
;
}
Tensor
bias
()
const
{
return
bias_
;
}
Tensor
bias
()
const
{
return
bias_
;
}
Tensor
weight_scale
()
const
{
return
weight_scale_
;
}
Tensor
weight_zeros
()
const
{
return
weight_zeros_
;
}
protected:
protected:
// Parameters
// Parameters
INFINICORE_NN_PARAMETER
(
weight
);
INFINICORE_NN_PARAMETER
(
weight
);
INFINICORE_NN_PARAMETER
(
bias
);
INFINICORE_NN_PARAMETER
(
bias
);
INFINICORE_NN_PARAMETER
(
weight_scale
);
INFINICORE_NN_PARAMETER
(
weight_zeros
);
protected:
protected:
// Helper method for common forward computation
// Helper method for common forward computation
Tensor
compute_linear
(
Tensor
&
input
)
const
;
Tensor
compute_linear
(
Tensor
&
input
)
const
;
...
@@ -41,6 +51,7 @@ protected:
...
@@ -41,6 +51,7 @@ protected:
size_t
out_features_
;
size_t
out_features_
;
bool
has_bias_
;
bool
has_bias_
;
DataType
dtype_
;
DataType
dtype_
;
std
::
shared_ptr
<
infinicore
::
quantization
::
BaseQuantization
>
quantization_
=
std
::
make_shared
<
infinicore
::
quantization
::
NoneQuantization
>
(
nullptr
);
};
};
}
// namespace infinicore::nn
}
// namespace infinicore::nn
...
@@ -52,6 +63,9 @@ public:
...
@@ -52,6 +63,9 @@ public:
Linear
(
size_t
in_features
,
size_t
out_features
,
bool
bias
=
true
,
Linear
(
size_t
in_features
,
size_t
out_features
,
bool
bias
=
true
,
const
DataType
&
dtype
=
DataType
::
F32
,
const
Device
&
device
=
Device
());
const
DataType
&
dtype
=
DataType
::
F32
,
const
Device
&
device
=
Device
());
Linear
(
size_t
in_features
,
size_t
out_features
,
std
::
shared_ptr
<
infinicore
::
quantization
::
BaseQuantization
>
quantization
,
bool
bias
=
true
,
const
DataType
&
dtype
=
DataType
::
F32
,
const
Device
&
device
=
Device
());
// Forward pass: output = input @ weight.T + bias
// Forward pass: output = input @ weight.T + bias
Tensor
forward
(
Tensor
&
input
)
const
;
Tensor
forward
(
Tensor
&
input
)
const
;
...
@@ -65,6 +79,10 @@ public:
...
@@ -65,6 +79,10 @@ public:
const
DataType
&
dtype
=
DataType
::
F32
,
const
Device
&
device
=
Device
(),
const
DataType
&
dtype
=
DataType
::
F32
,
const
Device
&
device
=
Device
(),
Size
tp_rank
=
0
,
Size
tp_size
=
1
);
Size
tp_rank
=
0
,
Size
tp_size
=
1
);
ColumnParallelLinear
(
size_t
in_features
,
size_t
out_features
,
std
::
shared_ptr
<
infinicore
::
quantization
::
BaseQuantization
>
quantization
,
bool
bias
=
true
,
const
DataType
&
dtype
=
DataType
::
F32
,
const
Device
&
device
=
Device
(),
Size
tp_rank
=
0
,
Size
tp_size
=
1
);
// Forward pass: output = input @ weight.T + bias
// Forward pass: output = input @ weight.T + bias
Tensor
forward
(
Tensor
&
input
)
const
;
Tensor
forward
(
Tensor
&
input
)
const
;
...
@@ -82,6 +100,10 @@ public:
...
@@ -82,6 +100,10 @@ public:
const
DataType
&
dtype
=
DataType
::
F32
,
const
Device
&
device
=
Device
(),
const
DataType
&
dtype
=
DataType
::
F32
,
const
Device
&
device
=
Device
(),
Size
tp_rank
=
0
,
Size
tp_size
=
1
,
infinicclComm_t
communicator
=
nullptr
);
Size
tp_rank
=
0
,
Size
tp_size
=
1
,
infinicclComm_t
communicator
=
nullptr
);
RowParallelLinear
(
size_t
in_features
,
size_t
out_features
,
std
::
shared_ptr
<
infinicore
::
quantization
::
BaseQuantization
>
quantization
,
bool
bias
=
true
,
const
DataType
&
dtype
=
DataType
::
F32
,
const
Device
&
device
=
Device
(),
Size
tp_rank
=
0
,
Size
tp_size
=
1
,
infinicclComm_t
communicator
=
nullptr
);
// Forward pass: output = input @ weight.T + bias
// Forward pass: output = input @ weight.T + bias
Tensor
forward
(
Tensor
&
input
)
const
;
Tensor
forward
(
Tensor
&
input
)
const
;
...
...
include/infinicore/nn/rmsnorm.hpp
View file @
784139b9
#pragma once
#pragma once
#include "module.hpp"
#include "../ops.hpp"
#include "../ops.hpp"
#include "module.hpp"
namespace
infinicore
::
nn
{
namespace
infinicore
::
nn
{
...
@@ -57,6 +57,21 @@ public:
...
@@ -57,6 +57,21 @@ public:
*/
*/
Tensor
forward
(
const
Tensor
&
x
)
const
;
Tensor
forward
(
const
Tensor
&
x
)
const
;
/**
* @brief Forward pass: apply RMSNorm in-place with residual
*
* @param x Input tensor of shape (*, normalized_shape) where * is any number of dimensions.
* Will be modified in-place to the normalized output.
* @param residual Residual tensor to add to input before normalization.
* Will be modified in-place to the sum of input and residual.
*
* The normalization is applied over the last dimension.
* For example:
* Input: [batch, seq_len, hidden_size] -> normalize over hidden_size
* Input: [batch, hidden_size] -> normalize over hidden_size
*/
void
forward_inplace
(
Tensor
&
x
,
Tensor
&
residual
)
const
;
// Module information
// Module information
size_t
normalized_shape
()
const
{
return
normalized_shape_
;
}
size_t
normalized_shape
()
const
{
return
normalized_shape_
;
}
double
eps
()
const
{
return
eps_
;
}
double
eps
()
const
{
return
eps_
;
}
...
@@ -73,9 +88,9 @@ protected:
...
@@ -73,9 +88,9 @@ protected:
INFINICORE_NN_PARAMETER
(
weight
);
INFINICORE_NN_PARAMETER
(
weight
);
private:
private:
size_t
normalized_shape_
;
// Size of the feature dimension
size_t
normalized_shape_
;
// Size of the feature dimension
double
eps_
;
// Epsilon for numerical stability
double
eps_
;
// Epsilon for numerical stability
DataType
dtype_
;
// Data type for weight
DataType
dtype_
;
// Data type for weight
};
};
}
// namespace infinicore::nn
}
// namespace infinicore::nn
include/infinicore/ops.hpp
View file @
784139b9
...
@@ -4,6 +4,9 @@
...
@@ -4,6 +4,9 @@
#include "ops/add_rms_norm.hpp"
#include "ops/add_rms_norm.hpp"
#include "ops/attention.hpp"
#include "ops/attention.hpp"
#include "ops/causal_softmax.hpp"
#include "ops/causal_softmax.hpp"
#include "ops/embedding.hpp"
#include "ops/flash_attention.hpp"
#include "ops/kv_caching.hpp"
#include "ops/matmul.hpp"
#include "ops/matmul.hpp"
#include "ops/ones.hpp"
#include "ops/ones.hpp"
#include "ops/paged_attention.hpp"
#include "ops/paged_attention.hpp"
...
@@ -14,4 +17,5 @@
...
@@ -14,4 +17,5 @@
#include "ops/rms_norm.hpp"
#include "ops/rms_norm.hpp"
#include "ops/rope.hpp"
#include "ops/rope.hpp"
#include "ops/silu.hpp"
#include "ops/silu.hpp"
#include "ops/silu_and_mul.hpp"
#include "ops/swiglu.hpp"
#include "ops/swiglu.hpp"
include/infinicore/ops/add.hpp
View file @
784139b9
#pragma once
#pragma once
#include "../device.hpp"
#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"
#include "common/op.hpp"
namespace
infinicore
::
op
{
namespace
infinicore
::
op
{
class
Add
{
public:
using
schema
=
void
(
*
)(
Tensor
,
Tensor
,
Tensor
);
static
void
execute
(
Tensor
c
,
Tensor
a
,
Tensor
b
);
static
common
::
OpDispatcher
<
schema
>
&
dispatcher
();
};
Tensor
add
(
Tensor
a
,
Tensor
b
);
INFINICORE_GRAPH_OP_CLASS
(
Add
,
Tensor
,
const
Tensor
&
,
const
Tensor
&
);
void
add_
(
Tensor
c
,
Tensor
a
,
Tensor
b
);
Tensor
operator
+
(
Tensor
a
,
Tensor
b
);
Tensor
add
(
const
Tensor
&
a
,
const
Tensor
&
b
);
void
add_
(
Tensor
c
,
const
Tensor
&
a
,
const
Tensor
&
b
);
}
// namespace infinicore::op
}
// namespace infinicore::op
include/infinicore/ops/add_rms_norm.hpp
View file @
784139b9
...
@@ -5,16 +5,14 @@
...
@@ -5,16 +5,14 @@
#include <utility>
#include <utility>
namespace
infinicore
::
op
{
namespace
infinicore
::
op
{
class
AddRMSNorm
{
INFINICORE_GRAPH_OP_CLASS
(
AddRMSNorm
,
Tensor
,
Tensor
,
const
Tensor
&
,
const
Tensor
&
,
const
Tensor
&
,
float
);
public:
using
schema
=
void
(
*
)(
Tensor
,
Tensor
,
Tensor
,
Tensor
,
Tensor
,
float
);
static
void
execute
(
Tensor
y
,
Tensor
residual_out
,
Tensor
a
,
Tensor
b
,
Tensor
weight
,
float
epsilon
=
1e-5
f
);
static
common
::
OpDispatcher
<
schema
>
&
dispatcher
();
};
// Fused Add and RMS Normalization
// Fused Add and RMS Normalization
// Returns: (normalized_result, add_result)
// Returns: (normalized_result, add_result)
// The add_result can be used as residual for subsequent layers
// The add_result can be used as residual for subsequent layers
std
::
pair
<
Tensor
,
Tensor
>
add_rms_norm
(
Tensor
a
,
Tensor
b
,
Tensor
weight
,
float
epsilon
=
1e-5
f
);
std
::
pair
<
Tensor
,
Tensor
>
add_rms_norm
(
const
Tensor
&
a
,
const
Tensor
&
b
,
const
Tensor
&
weight
,
float
epsilon
=
1e-5
f
);
void
add_rms_norm_
(
Tensor
y
,
Tensor
residual_out
,
Tensor
a
,
Tensor
b
,
Tensor
weight
,
float
epsilon
=
1e-5
f
);
void
add_rms_norm_
(
Tensor
out
,
Tensor
residual
,
const
Tensor
&
a
,
const
Tensor
&
b
,
const
Tensor
&
weight
,
float
epsilon
=
1e-5
f
);
// Fused Add and RMS Normalization (inplace)
// normalized_result wil be stored in input, add_result will be stored in residual
void
add_rms_norm_inplace
(
Tensor
input
,
Tensor
residual
,
const
Tensor
&
weight
,
float
epsilon
=
1e-5
f
);
}
// namespace infinicore::op
}
// namespace infinicore::op
include/infinicore/ops/causal_softmax.hpp
View file @
784139b9
#pragma once
#pragma once
#include "../device.hpp"
#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"
#include "common/op.hpp"
namespace
infinicore
::
op
{
namespace
infinicore
::
op
{
class
CausalSoftmax
{
public:
using
schema
=
void
(
*
)(
Tensor
,
Tensor
);
static
void
execute
(
Tensor
output
,
Tensor
input
);
static
common
::
OpDispatcher
<
schema
>
&
dispatcher
();
};
Tensor
causal_softmax
(
Tensor
input
);
INFINICORE_GRAPH_OP_CLASS
(
CausalSoftmax
,
Tensor
,
const
Tensor
&
);
void
causal_softmax_
(
Tensor
output
,
Tensor
input
);
Tensor
causal_softmax
(
const
Tensor
&
input
);
void
causal_softmax_
(
Tensor
output
,
const
Tensor
&
input
);
}
// namespace infinicore::op
}
// namespace infinicore::op
include/infinicore/ops/dequantize_awq.hpp
0 → 100644
View file @
784139b9
#pragma once
#include "../device.hpp"
#include "common/op.hpp"
#include <optional>
namespace
infinicore
::
op
{
INFINICORE_GRAPH_OP_CLASS
(
DequantizeAWQ
,
Tensor
,
const
Tensor
&
,
const
Tensor
&
,
const
Tensor
&
);
void
dequantize_awq_
(
Tensor
x
,
const
Tensor
&
x_packed
,
const
Tensor
&
x_scale
,
const
Tensor
&
x_zeros
);
}
// namespace infinicore::op
include/infinicore/ops/distributed/allreduce.hpp
0 → 100644
View file @
784139b9
#pragma once
#include "../../device.hpp"
#include "../../graph/graph.hpp"
#include "../common/op.hpp"
#include <infiniccl.h>
namespace
infinicore
::
op
::
distributed
{
class
AllReduce
:
public
graph
::
GraphOperator
{
public:
AllReduce
(
Tensor
output
,
const
Tensor
&
input
,
infinicclReduceOp_t
op
,
infinicclComm_t
communicator
);
~
AllReduce
();
void
run
()
const
override
;
static
void
execute
(
Tensor
output
,
const
Tensor
&
input
,
infinicclReduceOp_t
op
,
infinicclComm_t
communicator
);
private:
void
*
planned_meta_
;
};
Tensor
allreduce
(
const
Tensor
&
input
,
infinicclReduceOp_t
op
,
infinicclComm_t
communicator
);
void
allreduce_
(
Tensor
output
,
const
Tensor
&
input
,
infinicclReduceOp_t
op
,
infinicclComm_t
communicator
);
}
// namespace infinicore::op::distributed
include/infinicore/ops/embedding.hpp
View file @
784139b9
#pragma once
#pragma once
#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"
#include "common/op.hpp"
namespace
infinicore
::
op
{
namespace
infinicore
::
op
{
Tensor
embedding
(
Tensor
input
,
Tensor
weight
);
INFINICORE_GRAPH_OP_CLASS
(
Embedding
,
Tensor
,
const
Tensor
&
,
const
Tensor
&
);
void
embedding_
(
Tensor
out
,
Tensor
input
,
Tensor
weight
);
Tensor
embedding
(
const
Tensor
&
input
,
const
Tensor
&
weight
);
void
embedding_
(
Tensor
out
,
const
Tensor
&
input
,
const
Tensor
&
weight
);
}
// namespace infinicore::op
}
// namespace infinicore::op
include/infinicore/ops/flash_attention.hpp
0 → 100644
View file @
784139b9
#pragma once
#include "../device.hpp"
#include "common/op.hpp"
namespace
infinicore
::
op
{
INFINICORE_GRAPH_OP_CLASS
(
FlashAttention
,
Tensor
,
const
Tensor
&
,
const
Tensor
&
,
const
Tensor
&
,
const
Tensor
&
,
float
,
bool
);
Tensor
flash_attention
(
const
Tensor
&
q
,
const
Tensor
&
k
,
const
Tensor
&
v
,
const
Tensor
&
total_kv_len
,
float
scale
,
bool
is_causal
);
void
flash_attention_
(
Tensor
out
,
const
Tensor
&
q
,
const
Tensor
&
k
,
const
Tensor
&
v
,
const
Tensor
&
total_kv_len
,
float
scale
,
bool
is_causal
);
}
// namespace infinicore::op
include/infinicore/ops/gemm.hpp
View file @
784139b9
...
@@ -6,9 +6,9 @@
...
@@ -6,9 +6,9 @@
namespace
infinicore
::
op
{
namespace
infinicore
::
op
{
INFINICORE_GRAPH_OP_CLASS
(
Gemm
,
Tensor
,
Tensor
,
Tensor
,
float
,
float
);
INFINICORE_GRAPH_OP_CLASS
(
Gemm
,
Tensor
,
const
Tensor
&
,
const
Tensor
&
,
float
,
float
);
Tensor
gemm
(
Tensor
a
,
Tensor
b
,
float
alpha
=
1.0
f
,
float
beta
=
0.0
f
);
Tensor
gemm
(
const
Tensor
&
a
,
const
Tensor
&
b
,
float
alpha
=
1.0
f
,
float
beta
=
0.0
f
);
void
gemm_
(
Tensor
c
,
Tensor
a
,
Tensor
b
,
float
alpha
,
float
beta
);
void
gemm_
(
Tensor
c
,
const
Tensor
&
a
,
const
Tensor
&
b
,
float
alpha
,
float
beta
);
}
// namespace infinicore::op
}
// namespace infinicore::op
include/infinicore/ops/kv_caching.hpp
0 → 100644
View file @
784139b9
#pragma once
#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"
namespace
infinicore
::
op
{
INFINICORE_GRAPH_OP_CLASS
(
KVCaching
,
Tensor
,
Tensor
,
const
Tensor
&
,
const
Tensor
&
,
const
Tensor
&
);
void
kv_caching_
(
Tensor
k_cache
,
Tensor
v_cache
,
const
Tensor
&
k
,
const
Tensor
&
v
,
const
Tensor
&
past_kv_lengths
);
}
// namespace infinicore::op
include/infinicore/ops/linear_w4a16_awq.hpp
0 → 100644
View file @
784139b9
#pragma once
#include "common/op.hpp"
#include <optional>
namespace
infinicore
::
op
{
Tensor
linear_w4a16_awq
(
Tensor
input
,
Tensor
weight_packed
,
Tensor
weight_scale
,
Tensor
weight_zeros
,
std
::
optional
<
Tensor
>
bias
);
void
linear_w4a16_awq_
(
Tensor
out
,
Tensor
input
,
Tensor
weight_packed
,
Tensor
weight_scale
,
Tensor
weight_zeros
,
std
::
optional
<
Tensor
>
bias
);
}
// namespace infinicore::op
include/infinicore/ops/linear_w8a8i8.hpp
0 → 100644
View file @
784139b9
#pragma once
#include "../graph/graph.hpp"
#include "common/op.hpp"
#include <optional>
namespace
infinicore
::
op
{
Tensor
linear_w8a8i8
(
Tensor
input
,
Tensor
weight_packed
,
Tensor
weight_scale
,
std
::
optional
<
Tensor
>
bias
);
void
linear_w8a8i8_
(
Tensor
out
,
Tensor
input
,
Tensor
weight_packed
,
Tensor
weight_scale
,
std
::
optional
<
Tensor
>
bias
);
}
// namespace infinicore::op
Prev
1
2
3
4
5
…
30
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment