Unverified commit d09de04c, authored by thatPepe, committed by GitHub

Merge pull request #250 from InfiniTensor/issue/248

Issue/248 support flash-attention
parents f67956fe 5dc85bf4
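What the hunks below add, in brief: both Python test scripts gain an --attn flag (choices "default" and "flash-attn") that is threaded through as attn_backend / attention_backend down to the engine; InferEngine additionally accepts a cu_seqlens tensor, forwards it to the underlying forward call, and switches block_tables and the length/offset tensors from int64 to int32; and the C export macro __C is renamed to __INFINI_C across the public headers and their implementations. Selecting the new backend amounts to appending --attn flash-attn to the existing invocation of either script.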
......@@ -252,6 +252,13 @@ def get_args():
action="store_true",
help="Perform a warmup run before benchmarking/inference.",
)
parser.add_argument(
"--attn",
type=str,
default="default",
choices=["default", "flash-attn"],
help="attention backend to use: 'default' or 'flash-attn'",
)
return parser.parse_args()
......@@ -278,6 +285,7 @@ class TestModel:
skip_load=False,
cache_config=None,
enable_graph=False,
attn_backend="default",
) -> None:
model_path = os.path.expanduser(model_path)
# ---------------------------------------------------------------------------- #
......@@ -289,6 +297,7 @@ class TestModel:
distributed_config=DistConfig(tp),
cache_config=cache_config,
enable_graph_compiling=enable_graph,
attention_backend=attn_backend,
)
# ---------------------------------------------------------------------------- #
......@@ -461,6 +470,7 @@ if __name__ == "__main__":
skip_load=skip_load,
cache_config=cache_config,
enable_graph=enable_graph,
attn_backend=args.attn,
)
# ---------------------------------------------------------------------------- #
......
......@@ -142,6 +142,14 @@ def get_args():
help="sampling temperature",
)
parser.add_argument(
"--attn",
type=str,
default="default",
choices=["default", "flash-attn"],
help="attention backend to use: 'default' or 'flash-attn'",
)
return parser.parse_args()
......@@ -156,6 +164,7 @@ def test(
top_k=1,
top_p=1.0,
temperature=1.0,
attn_backend="default",
):
model_path = os.path.expanduser(model_path)
# ---------------------------------------------------------------------------- #
......@@ -166,6 +175,7 @@ def test(
device=infini_device,
distributed_config=DistConfig(tp),
enable_graph_compiling=enable_graph,
attention_backend=attn_backend,
)
# ---------------------------------------------------------------------------- #
# Load Weights
......@@ -333,4 +343,5 @@ if __name__ == "__main__":
top_k=args.top_k,
top_p=args.top_p,
temperature=args.temperature,
attn_backend=args.attn,
)
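For orientation, a minimal sketch of how the new flag reaches the engine, using the argument names from the hunks above. The constructor is presumably the InferEngine shown further down in this diff, and model_path, infini_device, tp, and enable_graph stand for whatever the script already sets up:

    args = get_args()  # args.attn is "default" or "flash-attn"
    engine = InferEngine(
        model_path,
        device=infini_device,
        distributed_config=DistConfig(tp),
        enable_graph_compiling=enable_graph,
        attention_backend=args.attn,  # new: forwarded to the backend constructor
    )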
......@@ -3,7 +3,7 @@
#include <infinirt.h>
__C __export struct KVCache *createKVCache(
__INFINI_C __export struct KVCache *createKVCache(
size_t nlayers,
size_t max_len,
size_t nkvh_,
......@@ -14,8 +14,8 @@ __C __export struct KVCache *createKVCache(
int *dev_ids,
size_t ndev);
__C __export struct KVCache *duplicateKVCache(const KVCache *kv_cache, size_t seq_len);
__INFINI_C __export struct KVCache *duplicateKVCache(const KVCache *kv_cache, size_t seq_len);
__C __export void dropKVCache(KVCache *kv_cache);
__INFINI_C __export void dropKVCache(KVCache *kv_cache);
#endif /* CACHE_H */
......@@ -103,26 +103,26 @@ typedef struct {
/// @param device coprocessor (device) type
/// @param ndev number of coprocessor devices
/// @param dev_ids coprocessor device IDs, array of length ndev
__C __export struct DeepSeekV3Model *
__INFINI_C __export struct DeepSeekV3Model *
createDeepSeekV3Model(const DeepSeekV3Meta *,
const DeepSeekV3Weights *);
__C DeepSeekV3Weights *
__INFINI_C DeepSeekV3Weights *
createDeepSeekV3Weights(const DeepSeekV3Meta *meta,
infiniDevice_t device,
int ndev,
const int *dev_ids);
__C __export DeepSeekV3WeightLoader *
__INFINI_C __export DeepSeekV3WeightLoader *
createDeepSeekV3WeightLoader();
/// @brief Destroy the model
__C __export void destroyDeepSeekV3Model(struct DeepSeekV3Model *);
__INFINI_C __export void destroyDeepSeekV3Model(struct DeepSeekV3Model *);
__C __export struct DeepSeekV3Cache *
__INFINI_C __export struct DeepSeekV3Cache *
createDeepSeekV3Cache(const struct DeepSeekV3Model *);
__C __export void
__INFINI_C __export void
dropDeepSeekV3Cache(const struct DeepSeekV3Model *,
struct DeepSeekV3Cache *);
......@@ -137,7 +137,7 @@ dropDeepSeekV3Cache(const struct DeepSeekV3Model *,
/// @param topk sampling top-k (1 means greedy sampling)
/// @param topp sampling top-p
/// @param output output token array, one token per request, length at least nreq
__C __export void
__INFINI_C __export void
inferBatchDeepSeekV3(struct DeepSeekV3Model *,
const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
......@@ -153,7 +153,7 @@ inferBatchDeepSeekV3(struct DeepSeekV3Model *,
/// @param req_pos starting position of each request
/// @param kv_caches KV cache of each request
/// @param logits output logits, one entry per request, length at least nreq
__C __export void
__INFINI_C __export void
forwardBatchDeepSeekV3(struct DeepSeekV3Model *,
const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
......
......@@ -54,7 +54,7 @@ typedef struct
/// @param device coprocessor (device) type
/// @param ndev number of coprocessor devices
/// @param dev_ids coprocessor device IDs, array of length ndev
__C __export struct JiugeModel *
__INFINI_C __export struct JiugeModel *
createJiugeModel(const JiugeMeta *,
const JiugeWeights *,
infiniDevice_t device,
......@@ -62,7 +62,7 @@ createJiugeModel(const JiugeMeta *,
const int *dev_ids);
/// @brief Destroy the model
__C __export void
__INFINI_C __export void
destroyJiugeModel(struct JiugeModel *);
/// @brief Run one batched inference step and sample new tokens
......@@ -76,7 +76,7 @@ destroyJiugeModel(struct JiugeModel *);
/// @param topk sampling top-k (1 means greedy sampling)
/// @param topp sampling top-p
/// @param output output token array, one token per request, length at least nreq
__C __export void
__INFINI_C __export void
inferBatchJiuge(struct JiugeModel *,
const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
......@@ -92,7 +92,7 @@ inferBatchJiuge(struct JiugeModel *,
/// @param req_pos starting position of each request
/// @param kv_caches KV cache of each request
/// @param logits output logits, one entry per request, length at least nreq
__C __export void
__INFINI_C __export void
forwardBatchJiuge(struct JiugeModel *,
const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
......
......@@ -25,7 +25,7 @@ typedef struct
} JiugeAWQMeta;
//////////////////// APIs ///////////////////////
__C __export struct ModelWeights *
__INFINI_C __export struct ModelWeights *
createJiugeAWQWeights(const JiugeAWQMeta *,
infiniDevice_t device,
int ndev,
......@@ -34,12 +34,12 @@ createJiugeAWQWeights(const JiugeAWQMeta *,
/// @param device coprocessor (device) type
/// @param ndev number of coprocessor devices
/// @param dev_ids coprocessor device IDs, array of length ndev
__C __export struct JiugeAWQModel *
__INFINI_C __export struct JiugeAWQModel *
createJiugeAWQModel(const JiugeAWQMeta *,
const ModelWeights *);
/// @brief Destroy the model
__C __export void
__INFINI_C __export void
destroyJiugeAWQModel(struct JiugeAWQModel *);
/// @brief Run one batched inference step and sample new tokens
......@@ -53,7 +53,7 @@ destroyJiugeAWQModel(struct JiugeAWQModel *);
/// @param topk sampling top-k (1 means greedy sampling)
/// @param topp sampling top-p
/// @param output output token array, one token per request, length at least nreq
__C __export void
__INFINI_C __export void
inferBatchJiugeAWQ(struct JiugeAWQModel *,
const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
......@@ -69,7 +69,7 @@ inferBatchJiugeAWQ(struct JiugeAWQModel *,
/// @param req_pos starting position of each request
/// @param kv_caches KV cache of each request
/// @param logits output logits, one entry per request, length at least nreq
__C __export void
__INFINI_C __export void
forwardBatchJiugeAWQ(struct JiugeAWQModel *,
const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
......
......@@ -5,10 +5,10 @@
struct ModelWeights;
__C __export void
__INFINI_C __export void
loadModelWeight(struct ModelWeights *weights, const char *name, void *data);
__C __export void
__INFINI_C __export void
loadModelWeightDistributed(struct ModelWeights *weights, const char *name, void *data, int *ranks, int nrank);
#endif // WEIGHTS_LOADER_H
......@@ -29,6 +29,7 @@ class InferEngine(_infinilm.InferEngine):
distributed_config=DistConfig(1),
cache_config=None,
enable_graph_compiling=False,
attention_backend="default",
):
self.config = AutoConfig.from_pretrained(model_path)
......@@ -41,6 +42,7 @@ class InferEngine(_infinilm.InferEngine):
device._underlying.type,
cache_config,
enable_graph_compiling,
attention_backend,
)
self.use_cache = False
......@@ -57,6 +59,7 @@ class InferEngine(_infinilm.InferEngine):
past_kv_lengths=None,
total_kv_lengths=None,
input_offsets=None,
cu_seqlens=None,
block_tables=None,
slot_mapping=None,
temperature=None,
......@@ -74,6 +77,7 @@ class InferEngine(_infinilm.InferEngine):
)
input_offsets = input_offsets._underlying if input_offsets is not None else None
block_tables = block_tables._underlying if block_tables is not None else None
cu_seqlens = cu_seqlens._underlying if cu_seqlens is not None else None
slot_mapping = slot_mapping._underlying if slot_mapping is not None else None
return infinicore.Tensor(
......@@ -85,6 +89,7 @@ class InferEngine(_infinilm.InferEngine):
past_sequence_lengths=past_kv_lengths,
total_sequence_lengths=total_kv_lengths,
input_offsets=input_offsets,
cu_seqlens=cu_seqlens,
block_tables=block_tables,
slot_mapping=slot_mapping,
temperature=temperature,
......@@ -135,7 +140,7 @@ class InferEngine(_infinilm.InferEngine):
]
block_tables = infinicore.from_list(
block_tables_list,
dtype=infinicore.int64,
dtype=infinicore.int32,
)
for iter in range(0, generation_config.max_new_tokens):
......@@ -188,14 +193,17 @@ class InferEngine(_infinilm.InferEngine):
slot_mapping = None
past_kv_lengths = infinicore.from_list(
[past_seq_len] * batch_size, dtype=infinicore.int64
[past_seq_len] * batch_size, dtype=infinicore.int32
)
total_kv_lengths = infinicore.from_list(
[past_seq_len + seq_len] * batch_size, dtype=infinicore.int64
[past_seq_len + seq_len] * batch_size, dtype=infinicore.int32
)
cu_seqlens = infinicore.from_list(
[(past_seq_len + seq_len) * i for i in range(batch_size + 1)],
dtype=infinicore.int32,
)
input_offsets = infinicore.from_list(
[seq_len * i for i in range(batch_size + 1)], dtype=infinicore.int64
[seq_len * i for i in range(batch_size + 1)], dtype=infinicore.int32
)
output_id = self(
......@@ -204,6 +212,7 @@ class InferEngine(_infinilm.InferEngine):
past_kv_lengths=past_kv_lengths,
total_kv_lengths=total_kv_lengths,
input_offsets=input_offsets,
cu_seqlens=cu_seqlens,
block_tables=block_tables,
slot_mapping=slot_mapping,
temperature=generation_config.temperature,
......
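A small worked example of the new cu_seqlens tensor, following the list comprehension in the hunk above; the assumption that the flash-attention path expects int32 cumulative sequence lengths is the usual variable-length-attention convention, not something this diff states explicitly:

    batch_size, past_seq_len, seq_len = 2, 3, 5
    # one boundary per request plus a leading zero, over the total (past + new) length
    cu_seqlens = [(past_seq_len + seq_len) * i for i in range(batch_size + 1)]
    # -> [0, 8, 16]; the engine builds it via infinicore.from_list(..., dtype=infinicore.int32)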
#include "../cache.hpp"
__C struct KVCache *createKVCache(
__INFINI_C struct KVCache *createKVCache(
size_t nlayers,
size_t max_len,
size_t nkvh_,
......@@ -31,7 +31,7 @@ __C struct KVCache *createKVCache(
return cache;
}
__C struct KVCache *duplicateKVCache(const KVCache *kv_cache, size_t seq_len) {
__INFINI_C struct KVCache *duplicateKVCache(const KVCache *kv_cache, size_t seq_len) {
auto ndev = kv_cache->k.size();
auto nlayers = kv_cache->k[0].size();
auto device = kv_cache->k[0][0]->deviceType();
......@@ -65,7 +65,7 @@ __C struct KVCache *duplicateKVCache(const KVCache *kv_cache, size_t seq_len) {
return new_kv_cache;
}
__C void dropKVCache(KVCache *kv_cache) {
__INFINI_C void dropKVCache(KVCache *kv_cache) {
auto ndev = kv_cache->k.size();
auto nlayers = kv_cache->k[0].size();
auto device = kv_cache->k[0][0]->deviceType();
......
......@@ -78,7 +78,7 @@ std::shared_ptr<Tensor> Loader::get(const std::string &name, int rank) {
} // namespace infinicore::weights
__C void
__INFINI_C void
loadModelWeight(struct ModelWeights *weights_, const char *name, void *data) {
std::string name_str(name);
auto weights = reinterpret_cast<infinicore::weights::Loader *>(weights_);
......
......@@ -431,7 +431,7 @@ void inferDeviceBatch(const DeepSeekV3Meta &meta, DeepSeekV3DeviceResource &rsrc
}
}
__C void
__INFINI_C void
inferBatchDeepSeekV3(struct DeepSeekV3Model *model,
const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
......@@ -464,7 +464,7 @@ inferBatchDeepSeekV3(struct DeepSeekV3Model *model,
}
}
__C void
__INFINI_C void
forwardBatchDeepSeekV3(struct DeepSeekV3Model *model,
const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
......@@ -563,14 +563,14 @@ DeepSeekV3Model::DeepSeekV3Model(const DeepSeekV3Meta *_meta, const DeepSeekV3We
}
}
__C struct DeepSeekV3Model *
__INFINI_C struct DeepSeekV3Model *
createDeepSeekV3Model(const DeepSeekV3Meta *_meta,
const DeepSeekV3Weights *weights) {
DeepSeekV3Model *model = new DeepSeekV3Model(_meta, weights);
return model;
}
__C void
__INFINI_C void
destroyDeepSeekV3Model(struct DeepSeekV3Model *model) {
auto ndev = model->dev_resources.size();
......
#include "deepseek_v3_impl.hpp"
__C struct DeepSeekV3Cache *
__INFINI_C struct DeepSeekV3Cache *
createDeepSeekV3Cache(const struct DeepSeekV3Model *model) {
DeepSeekV3Cache *cache = new DeepSeekV3Cache();
auto ndev = model->dev_resources.size();
......@@ -25,7 +25,7 @@ createDeepSeekV3Cache(const struct DeepSeekV3Model *model) {
return cache;
}
__C void
__INFINI_C void
dropDeepSeekV3Cache(const struct DeepSeekV3Model *model,
struct DeepSeekV3Cache *cache) {
auto ndev = model->dev_resources.size();
......
......@@ -436,7 +436,7 @@ static DeepSeekV3WeightLoader weight_loader = {
.load_mlp_experts = load_mlp_experts,
};
__C DeepSeekV3Weights *
__INFINI_C DeepSeekV3Weights *
createDeepSeekV3Weights(const DeepSeekV3Meta *meta,
infiniDevice_t device,
int ndev,
......@@ -445,7 +445,7 @@ createDeepSeekV3Weights(const DeepSeekV3Meta *meta,
return weights;
};
__C DeepSeekV3WeightLoader *
__INFINI_C DeepSeekV3WeightLoader *
createDeepSeekV3WeightLoader() {
return &weight_loader;
}
......@@ -315,7 +315,7 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc,
}
}
__C void
__INFINI_C void
inferBatchJiuge(struct JiugeModel *model,
const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
......@@ -348,7 +348,7 @@ inferBatchJiuge(struct JiugeModel *model,
}
}
__C void
__INFINI_C void
forwardBatchJiuge(struct JiugeModel *model,
const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
......@@ -444,7 +444,7 @@ JiugeModel::JiugeModel(const JiugeMeta *_meta, const JiugeWeights *weights, infi
}
}
__C struct JiugeModel *
__INFINI_C struct JiugeModel *
createJiugeModel(const JiugeMeta *meta,
const JiugeWeights *weights,
infiniDevice_t device,
......@@ -456,7 +456,7 @@ createJiugeModel(const JiugeMeta *meta,
return model;
}
__C void destroyJiugeModel(struct JiugeModel *model) {
__INFINI_C void destroyJiugeModel(struct JiugeModel *model) {
auto ndev = model->dev_resources.size();
for (size_t idev = 0; idev < ndev; idev++) {
......
......@@ -242,7 +242,7 @@ void inferDeviceBatch(const JiugeAWQMeta *meta, DeviceResource &rsrc,
}
}
__C void
__INFINI_C void
inferBatchJiugeAWQ(struct JiugeAWQModel *model,
const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
......@@ -275,7 +275,7 @@ inferBatchJiugeAWQ(struct JiugeAWQModel *model,
}
}
__C void
__INFINI_C void
forwardBatchJiugeAWQ(struct JiugeAWQModel *model,
const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
......@@ -372,14 +372,14 @@ JiugeAWQModel::JiugeAWQModel(const JiugeAWQMeta *meta, const ModelWeights *weigh
}
}
__C struct JiugeAWQModel *
__INFINI_C struct JiugeAWQModel *
createJiugeAWQModel(const JiugeAWQMeta *meta,
const ModelWeights *weights) {
JiugeAWQModel *model = new JiugeAWQModel(meta, weights);
return model;
}
__C void destroyJiugeAWQModel(struct JiugeAWQModel *model) {
__INFINI_C void destroyJiugeAWQModel(struct JiugeAWQModel *model) {
auto ndev = model->dev_resources.size();
for (size_t idev = 0; idev < ndev; idev++) {
......
......@@ -118,7 +118,7 @@ JiugeAWQWeights::JiugeAWQWeights(
#undef REGISTER_LAYER_QUANT_WEIGHT
}
__C struct ModelWeights *
__INFINI_C struct ModelWeights *
createJiugeAWQWeights(const JiugeAWQMeta *meta,
infiniDevice_t device,
int ndev,
......