Commit b1f6af34 authored by PanZezhong
Browse files

issue/263 fix T2-1-4

parent d1f29df0
...@@ -26,7 +26,7 @@ typedef struct { ...@@ -26,7 +26,7 @@ typedef struct {
qwen3vl_load_layer_fn load_attn_k_norm; qwen3vl_load_layer_fn load_attn_k_norm;
qwen3vl_load_layer_fn load_attn_qkv_proj; qwen3vl_load_layer_fn load_attn_qkv_proj;
qwen3vl_load_layer_fn load_attn_o_proj; qwen3vl_load_layer_fn load_attn_o_proj;
// MLP // MLP
qwen3vl_load_layer_fn load_mlp_norm; qwen3vl_load_layer_fn load_mlp_norm;
qwen3vl_load_layer_fn load_mlp_gate_up; qwen3vl_load_layer_fn load_mlp_gate_up;
...@@ -46,19 +46,19 @@ typedef struct { ...@@ -46,19 +46,19 @@ typedef struct {
qwen3vl_load_layer_fn load_attn_qkv_weight; qwen3vl_load_layer_fn load_attn_qkv_weight;
qwen3vl_load_layer_fn load_attn_qkv_bias; qwen3vl_load_layer_fn load_attn_qkv_bias;
//block mlp // block mlp
qwen3vl_load_layer_fn load_mlp_linear_fc1_weight; qwen3vl_load_layer_fn load_mlp_linear_fc1_weight;
qwen3vl_load_layer_fn load_mlp_linear_fc1_bias; qwen3vl_load_layer_fn load_mlp_linear_fc1_bias;
qwen3vl_load_layer_fn load_mlp_linear_fc2_weight; qwen3vl_load_layer_fn load_mlp_linear_fc2_weight;
qwen3vl_load_layer_fn load_mlp_linear_fc2_bias; qwen3vl_load_layer_fn load_mlp_linear_fc2_bias;
//block norm // block norm
qwen3vl_load_layer_fn load_norm1_weight; qwen3vl_load_layer_fn load_norm1_weight;
qwen3vl_load_layer_fn load_norm1_bias; qwen3vl_load_layer_fn load_norm1_bias;
qwen3vl_load_layer_fn load_norm2_weight; qwen3vl_load_layer_fn load_norm2_weight;
qwen3vl_load_layer_fn load_norm2_bias; qwen3vl_load_layer_fn load_norm2_bias;
//deepstack_merger // deepstack_merger
qwen3vl_load_layer_fn load_deepstack_merger_linear_fc1_weight; qwen3vl_load_layer_fn load_deepstack_merger_linear_fc1_weight;
qwen3vl_load_layer_fn load_deepstack_merger_linear_fc1_bias; qwen3vl_load_layer_fn load_deepstack_merger_linear_fc1_bias;
qwen3vl_load_layer_fn load_deepstack_merger_linear_fc2_weight; qwen3vl_load_layer_fn load_deepstack_merger_linear_fc2_weight;
...@@ -66,7 +66,7 @@ typedef struct { ...@@ -66,7 +66,7 @@ typedef struct {
qwen3vl_load_layer_fn load_deepstack_merger_norm_weight; qwen3vl_load_layer_fn load_deepstack_merger_norm_weight;
qwen3vl_load_layer_fn load_deepstack_merger_norm_bias; qwen3vl_load_layer_fn load_deepstack_merger_norm_bias;
//merger // merger
qwen3vl_load_global_fn load_merger_linear_fc1_weight; qwen3vl_load_global_fn load_merger_linear_fc1_weight;
qwen3vl_load_global_fn load_merger_linear_fc1_bias; qwen3vl_load_global_fn load_merger_linear_fc1_bias;
qwen3vl_load_global_fn load_merger_linear_fc2_weight; qwen3vl_load_global_fn load_merger_linear_fc2_weight;
...@@ -76,7 +76,7 @@ typedef struct { ...@@ -76,7 +76,7 @@ typedef struct {
} Qwen3vlVisWeightLoader; } Qwen3vlVisWeightLoader;
typedef struct { typedef struct {
Qwen3vlLangWeightLoader lang_loader; Qwen3vlLangWeightLoader lang_loader;
Qwen3vlVisWeightLoader vis_loader; Qwen3vlVisWeightLoader vis_loader;
} Qwen3vlWeightLoader; } Qwen3vlWeightLoader;
...@@ -116,7 +116,7 @@ typedef struct { ...@@ -116,7 +116,7 @@ typedef struct {
} Qwen3vlVisMeta; } Qwen3vlVisMeta;
typedef struct { typedef struct {
infiniDtype_t dtype; //INFINI_DTYPE_BF16 infiniDtype_t dtype; // INFINI_DTYPE_BF16
Qwen3vlTextMeta text_meta; Qwen3vlTextMeta text_meta;
Qwen3vlVisMeta vis_meta; Qwen3vlVisMeta vis_meta;
...@@ -132,29 +132,29 @@ typedef struct { ...@@ -132,29 +132,29 @@ typedef struct {
/// @param device 协处理器种类 /// @param device 协处理器种类
/// @param ndev 协处理器数量 /// @param ndev 协处理器数量
/// @param dev_ids 协处理器编号,长度为 ndev /// @param dev_ids 协处理器编号,长度为 ndev
__C __export struct Qwen3vlModel * __INFINI_C __export struct Qwen3vlModel *
createQwen3vlModel(const Qwen3vlMeta *, createQwen3vlModel(const Qwen3vlMeta *,
const Qwen3vlWeights *); const Qwen3vlWeights *);
__C Qwen3vlWeights * __INFINI_C Qwen3vlWeights *
createQwen3vlWeights(const Qwen3vlMeta *meta, createQwen3vlWeights(const Qwen3vlMeta *meta,
infiniDevice_t device, infiniDevice_t device,
int ndev, int ndev,
const int *dev_ids, const int *dev_ids,
bool transpose_weight); bool transpose_weight);
__C __export Qwen3vlWeightLoader * __INFINI_C __export Qwen3vlWeightLoader *
createQwen3vlWeightLoader(); createQwen3vlWeightLoader();
/// @brief 销毁模型 /// @brief 销毁模型
__C __export void destroyQwen3vlModel(struct Qwen3vlModel *); __INFINI_C __export void destroyQwen3vlModel(struct Qwen3vlModel *);
__C __export struct Qwen3vlCache * __INFINI_C __export struct Qwen3vlCache *
createQwen3vlCache(const struct Qwen3vlModel *); createQwen3vlCache(const struct Qwen3vlModel *);
__C __export void __INFINI_C __export void
dropQwen3vlCache(const struct Qwen3vlModel *, dropQwen3vlCache(const struct Qwen3vlModel *,
struct Qwen3vlCache *); struct Qwen3vlCache *);
/// @brief 批次推理一轮,并采样出新的 token /// @brief 批次推理一轮,并采样出新的 token
/// @param tokens 输入 token 地址 /// @param tokens 输入 token 地址
...@@ -167,18 +167,18 @@ dropQwen3vlCache(const struct Qwen3vlModel *, ...@@ -167,18 +167,18 @@ dropQwen3vlCache(const struct Qwen3vlModel *,
/// @param topk 采样 topk(1 表示贪心采样) /// @param topk 采样 topk(1 表示贪心采样)
/// @param topp 采样 topp /// @param topp 采样 topp
/// @param output 输出 token 数组,每个请求一个输出,长度至少为nreq /// @param output 输出 token 数组,每个请求一个输出,长度至少为nreq
__C __export void __INFINI_C __export void
inferBatchQwen3vl(struct Qwen3vlModel *, inferBatchQwen3vl(struct Qwen3vlModel *,
const uint32_t *tokens, uint32_t ntok, const uint32_t *tokens, uint32_t ntok,
void *pixel_values, uint32_t total_patches, void *pixel_values, uint32_t total_patches,
uint32_t *image_grid_thw, uint32_t num_images, uint32_t *image_grid_thw, uint32_t num_images,
void *pixel_values_videos, uint32_t total_patches_videos, void *pixel_values_videos, uint32_t total_patches_videos,
uint32_t *video_grid_thw, uint32_t num_videos, uint32_t *video_grid_thw, uint32_t num_videos,
uint32_t patch_features, uint32_t patch_features,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
struct Qwen3vlCache **caches, struct Qwen3vlCache **caches,
const float *temperature, const uint32_t *topk, const float *topp, const float *temperature, const uint32_t *topk, const float *topp,
uint32_t *output); uint32_t *output);
/// @brief 批次推理一轮,输出 output embedding 后的 logits /// @brief 批次推理一轮,输出 output embedding 后的 logits
/// @param tokens 输入 token 地址 /// @param tokens 输入 token 地址
...@@ -188,7 +188,7 @@ inferBatchQwen3vl(struct Qwen3vlModel *, ...@@ -188,7 +188,7 @@ inferBatchQwen3vl(struct Qwen3vlModel *,
/// @param req_pos 每个请求的起始位置 /// @param req_pos 每个请求的起始位置
/// @param kv_caches 每个请求的 KV Cache /// @param kv_caches 每个请求的 KV Cache
/// @param logits 输出 token 数组,每个请求一个输出,长度至少为nreq /// @param logits 输出 token 数组,每个请求一个输出,长度至少为nreq
__C __export void __INFINI_C __export void
forwardBatchQwen3vl(struct Qwen3vlModel *, forwardBatchQwen3vl(struct Qwen3vlModel *,
const uint32_t *tokens, uint32_t ntok, const uint32_t *tokens, uint32_t ntok,
void *pixel_values, uint32_t total_patches, void *pixel_values, uint32_t total_patches,
......
...@@ -6,6 +6,7 @@ from .deepseek_v3 import ( ...@@ -6,6 +6,7 @@ from .deepseek_v3 import (
DeepSeekV3MetaCStruct, DeepSeekV3MetaCStruct,
DeepSeekV3WeightsCStruct, DeepSeekV3WeightsCStruct,
DeepSeekV3WeightLoaderCStruct, DeepSeekV3WeightLoaderCStruct,
DeepSeekV3CacheCStruct,
) )
from .qwen3vl import ( from .qwen3vl import (
Qwen3vlModel, Qwen3vlModel,
...@@ -33,6 +34,7 @@ __all__ = [ ...@@ -33,6 +34,7 @@ __all__ = [
"DeepSeekV3MetaCStruct", "DeepSeekV3MetaCStruct",
"DeepSeekV3WeightsCStruct", "DeepSeekV3WeightsCStruct",
"DeepSeekV3WeightLoaderCStruct", "DeepSeekV3WeightLoaderCStruct",
"DeepSeekV3CacheCStruct",
"Qwen3vlModel", "Qwen3vlModel",
"Qwen3vlMetaCStruct", "Qwen3vlMetaCStruct",
"TextMetaCStruct", "TextMetaCStruct",
......
...@@ -183,15 +183,15 @@ class Qwen3vlModel(BaseModel): ...@@ -183,15 +183,15 @@ class Qwen3vlModel(BaseModel):
POINTER(Qwen3vlModelCStruct), POINTER(Qwen3vlModelCStruct),
POINTER(c_uint), POINTER(c_uint),
c_uint, c_uint,
c_void_p, # pixel_values, c_void_p, # pixel_values,
c_uint, # total_patches, c_uint, # total_patches,
POINTER(c_uint), # image_grid_thw, POINTER(c_uint), # image_grid_thw,
c_uint, # num_images, c_uint, # num_images,
c_void_p, # pixel_values_videos, c_void_p, # pixel_values_videos,
c_uint, # total_patches_videos, c_uint, # total_patches_videos,
POINTER(c_uint), # video_grid_thw, POINTER(c_uint), # video_grid_thw,
c_uint, # num_videos, c_uint, # num_videos,
c_uint, # patch_features, c_uint, # patch_features,
POINTER(c_uint), POINTER(c_uint),
c_uint, c_uint,
POINTER(c_uint), POINTER(c_uint),
...@@ -206,15 +206,15 @@ class Qwen3vlModel(BaseModel): ...@@ -206,15 +206,15 @@ class Qwen3vlModel(BaseModel):
POINTER(Qwen3vlModelCStruct), POINTER(Qwen3vlModelCStruct),
POINTER(c_uint), POINTER(c_uint),
c_uint, c_uint,
c_void_p, # pixel_values, c_void_p, # pixel_values,
c_uint, # total_patches, c_uint, # total_patches,
POINTER(c_uint), # image_grid_thw, POINTER(c_uint), # image_grid_thw,
c_uint, # num_images, c_uint, # num_images,
c_void_p, # pixel_values_videos, c_void_p, # pixel_values_videos,
c_uint, # total_patches_videos, c_uint, # total_patches_videos,
POINTER(c_uint), # video_grid_thw, POINTER(c_uint), # video_grid_thw,
c_uint, # num_videos, c_uint, # num_videos,
c_uint, # patch_features, c_uint, # patch_features,
POINTER(c_uint), POINTER(c_uint),
c_uint, c_uint,
POINTER(c_uint), POINTER(c_uint),
...@@ -226,7 +226,9 @@ class Qwen3vlModel(BaseModel): ...@@ -226,7 +226,9 @@ class Qwen3vlModel(BaseModel):
return self.lib.createQwen3vlWeightLoader() return self.lib.createQwen3vlWeightLoader()
def create_weights(self, meta, device_type, ndev, dev_ids, transpose_weight): def create_weights(self, meta, device_type, ndev, dev_ids, transpose_weight):
return self.lib.createQwen3vlWeights(meta, device_type, ndev, dev_ids, transpose_weight) return self.lib.createQwen3vlWeights(
meta, device_type, ndev, dev_ids, transpose_weight
)
def create_model(self, meta, weights): def create_model(self, meta, weights):
return self.lib.createQwen3vlModel(meta, weights) return self.lib.createQwen3vlModel(meta, weights)
...@@ -324,4 +326,4 @@ class Qwen3vlModel(BaseModel): ...@@ -324,4 +326,4 @@ class Qwen3vlModel(BaseModel):
req_pos, req_pos,
caches, caches,
logits, logits,
) )
\ No newline at end of file
This diff is collapsed.
...@@ -16,7 +16,7 @@ public: ...@@ -16,7 +16,7 @@ public:
class MemoryPool : public AllocatorBase { class MemoryPool : public AllocatorBase {
public: public:
static constexpr size_t DEFAULT_ALIGNMENT = 256; static constexpr size_t DEFAULT_ALIGNMENT = 512;
explicit MemoryPool(size_t initialSize = 0, size_t alignment = DEFAULT_ALIGNMENT); explicit MemoryPool(size_t initialSize = 0, size_t alignment = DEFAULT_ALIGNMENT);
~MemoryPool(); ~MemoryPool();
......
...@@ -34,24 +34,24 @@ void InferenceContext::add(std::shared_ptr<Tensor> c, ...@@ -34,24 +34,24 @@ void InferenceContext::add(std::shared_ptr<Tensor> c,
} }
void InferenceContext::conv(std::shared_ptr<Tensor> y, void InferenceContext::conv(std::shared_ptr<Tensor> y,
std::shared_ptr<Tensor> x, std::shared_ptr<Tensor> x,
std::shared_ptr<Tensor> w, std::shared_ptr<Tensor> w,
std::shared_ptr<Tensor> bias, std::shared_ptr<Tensor> bias,
void *pads, void *pads,
void *strides, void *strides,
void *dilations, void *dilations,
size_t n) { size_t n) {
size_t key = CacheManager::createDescriptorKey(y, x, w, bias); size_t key = CacheManager::createDescriptorKey(y, x, w, bias);
// Combine additional parameters into the key for unique identification // Combine additional parameters into the key for unique identification
hash_combine(key, std::hash<void*>()(pads)); hash_combine(key, std::hash<void *>()(pads));
hash_combine(key, std::hash<void*>()(strides)); hash_combine(key, std::hash<void *>()(strides));
hash_combine(key, std::hash<void*>()(dilations)); hash_combine(key, std::hash<void *>()(dilations));
hash_combine(key, std::hash<size_t>()(n)); hash_combine(key, std::hash<size_t>()(n));
infiniopConvDescriptor_t desc; infiniopConvDescriptor_t desc;
if (!cache_manager->getConvDescriptor(key, desc)) { if (!cache_manager->getConvDescriptor(key, desc)) {
RUN_INFINI(infiniopCreateConvDescriptor( RUN_INFINI(infiniopCreateConvDescriptor(
op_handle, &desc, y->desc(), x->desc(), w->desc(), op_handle, &desc, y->desc(), x->desc(), w->desc(),
bias ? bias->desc() : nullptr, pads, strides, dilations, n)); bias ? bias->desc() : nullptr, pads, strides, dilations, n));
cache_manager->putConvDescriptor(key, desc); cache_manager->putConvDescriptor(key, desc);
} }
...@@ -63,7 +63,7 @@ void InferenceContext::conv(std::shared_ptr<Tensor> y, ...@@ -63,7 +63,7 @@ void InferenceContext::conv(std::shared_ptr<Tensor> y,
RUN_INFINI(infiniopConv( RUN_INFINI(infiniopConv(
desc, workspace, workspace_size, desc, workspace, workspace_size,
y->data(), x->data(), w->data(), y->data(), x->data(), w->data(),
bias ? bias->data() : nullptr, stream)); bias ? bias->data() : nullptr, stream));
} }
......
...@@ -92,7 +92,7 @@ inline void add(std::shared_ptr<Tensor> c, std::shared_ptr<Tensor> a, std::share ...@@ -92,7 +92,7 @@ inline void add(std::shared_ptr<Tensor> c, std::shared_ptr<Tensor> a, std::share
} }
inline void conv(std::shared_ptr<Tensor> y, std::shared_ptr<Tensor> x, std::shared_ptr<Tensor> w, std::shared_ptr<Tensor> bias, inline void conv(std::shared_ptr<Tensor> y, std::shared_ptr<Tensor> x, std::shared_ptr<Tensor> w, std::shared_ptr<Tensor> bias,
void *pads, void *strides, void *dilations, size_t n) { void *pads, void *strides, void *dilations, size_t n) {
getInferenceContext().conv(y, x, w, bias, pads, strides, dilations, n); getInferenceContext().conv(y, x, w, bias, pads, strides, dilations, n);
} }
......
This diff is collapsed.
#include "qwen3vl_impl.hpp" #include "qwen3vl_impl.hpp"
__C struct Qwen3vlCache * __INFINI_C struct Qwen3vlCache *
createQwen3vlCache(const struct Qwen3vlModel *model) { createQwen3vlCache(const struct Qwen3vlModel *model) {
Qwen3vlCache *cache = new Qwen3vlCache(); Qwen3vlCache *cache = new Qwen3vlCache();
auto ndev = model->dev_resources.size(); auto ndev = model->dev_resources.size();
...@@ -27,9 +27,9 @@ createQwen3vlCache(const struct Qwen3vlModel *model) { ...@@ -27,9 +27,9 @@ createQwen3vlCache(const struct Qwen3vlModel *model) {
//////还有visual deepstack需要cache? //////还有visual deepstack需要cache?
__C void __INFINI_C void
dropQwen3vlCache(const struct Qwen3vlModel *model, dropQwen3vlCache(const struct Qwen3vlModel *model,
struct Qwen3vlCache *cache) { struct Qwen3vlCache *cache) {
auto ndev = model->dev_resources.size(); auto ndev = model->dev_resources.size();
auto nlayer = model->meta.text_meta.num_hidden_layers; auto nlayer = model->meta.text_meta.num_hidden_layers;
for (size_t idev = 0; idev < ndev; idev++) { for (size_t idev = 0; idev < ndev; idev++) {
...@@ -40,4 +40,4 @@ dropQwen3vlCache(const struct Qwen3vlModel *model, ...@@ -40,4 +40,4 @@ dropQwen3vlCache(const struct Qwen3vlModel *model,
} }
} }
delete cache; delete cache;
} }
\ No newline at end of file
...@@ -45,7 +45,6 @@ struct MergerWeight { ...@@ -45,7 +45,6 @@ struct MergerWeight {
std::shared_ptr<Tensor> norm_weight, norm_bias; std::shared_ptr<Tensor> norm_weight, norm_bias;
}; };
struct Qwen3vlVisualEncoderWeight { struct Qwen3vlVisualEncoderWeight {
std::shared_ptr<Tensor> patch_embed_weight, patch_embed_bias, pos_embed_weight; std::shared_ptr<Tensor> patch_embed_weight, patch_embed_bias, pos_embed_weight;
std::vector<Qwen3vlVisBlockWeight> blocks; std::vector<Qwen3vlVisBlockWeight> blocks;
...@@ -53,9 +52,8 @@ struct Qwen3vlVisualEncoderWeight { ...@@ -53,9 +52,8 @@ struct Qwen3vlVisualEncoderWeight {
std::shared_ptr<MergerWeight> merger; std::shared_ptr<MergerWeight> merger;
}; };
struct Qwen3vlDeviceWeights { struct Qwen3vlDeviceWeights {
std::shared_ptr<Tensor> sin_table,cos_table; std::shared_ptr<Tensor> sin_table, cos_table;
std::shared_ptr<Qwen3vlLanguageModelWeight> w_lang; std::shared_ptr<Qwen3vlLanguageModelWeight> w_lang;
std::shared_ptr<Qwen3vlVisualEncoderWeight> w_vis; std::shared_ptr<Qwen3vlVisualEncoderWeight> w_vis;
infiniDevice_t device; infiniDevice_t device;
...@@ -69,10 +67,10 @@ struct Qwen3vlWeights { ...@@ -69,10 +67,10 @@ struct Qwen3vlWeights {
std::vector<std::shared_ptr<Qwen3vlDeviceWeights>> device_weights; std::vector<std::shared_ptr<Qwen3vlDeviceWeights>> device_weights;
Qwen3vlWeights(const Qwen3vlMeta *meta, Qwen3vlWeights(const Qwen3vlMeta *meta,
infiniDevice_t device, infiniDevice_t device,
int ndev, int ndev,
const int *dev_ids, const int *dev_ids,
bool transpose_weight); bool transpose_weight);
}; };
struct Qwen3vlDeviceResource { struct Qwen3vlDeviceResource {
...@@ -137,7 +135,7 @@ struct Qwen3vlModel { ...@@ -137,7 +135,7 @@ struct Qwen3vlModel {
}; };
struct Qwen3vlCache { struct Qwen3vlCache {
std::vector<std::vector<std::shared_ptr<Tensor>>> k_rot, v; std::vector<std::vector<std::shared_ptr<Tensor>>> k_rot, v;
}; };
#endif #endif
\ No newline at end of file
...@@ -23,7 +23,7 @@ inline std::shared_ptr<Tensor> getOutEmbd( ...@@ -23,7 +23,7 @@ inline std::shared_ptr<Tensor> getOutEmbd(
} }
inline void getLayerWeight( inline void getLayerWeight(
const Qwen3vlMeta *meta, Qwen3vlLayerWeight& layer, int ndev) { const Qwen3vlMeta *meta, Qwen3vlLayerWeight &layer, int ndev) {
auto nkvh = meta->text_meta.num_key_value_heads; auto nkvh = meta->text_meta.num_key_value_heads;
auto nh = meta->text_meta.num_attention_heads; auto nh = meta->text_meta.num_attention_heads;
auto dh = meta->text_meta.head_dim; auto dh = meta->text_meta.head_dim;
...@@ -39,7 +39,7 @@ inline void getLayerWeight( ...@@ -39,7 +39,7 @@ inline void getLayerWeight(
layer.attn_qkv_proj = Tensor::weight(nullptr, meta->dtype, qkv_proj_shape); layer.attn_qkv_proj = Tensor::weight(nullptr, meta->dtype, qkv_proj_shape);
auto o_proj_shape = std::vector<size_t>({d, nh / ndev * dh}); auto o_proj_shape = std::vector<size_t>({d, nh / ndev * dh});
layer.attn_o_proj = Tensor::weight(nullptr, meta->dtype, o_proj_shape); layer.attn_o_proj = Tensor::weight(nullptr, meta->dtype, o_proj_shape);
layer.mlp_norm = Tensor::weight(nullptr, meta->dtype, dh_shape); layer.mlp_norm = Tensor::weight(nullptr, meta->dtype, dh_shape);
auto up_shape = std::vector<size_t>({2 * di / ndev, d}); auto up_shape = std::vector<size_t>({2 * di / ndev, d});
layer.mlp_gate_up = Tensor::weight(nullptr, meta->dtype, up_shape); layer.mlp_gate_up = Tensor::weight(nullptr, meta->dtype, up_shape);
...@@ -47,11 +47,10 @@ inline void getLayerWeight( ...@@ -47,11 +47,10 @@ inline void getLayerWeight(
layer.mlp_down = Tensor::weight(nullptr, meta->dtype, down_shape); layer.mlp_down = Tensor::weight(nullptr, meta->dtype, down_shape);
} }
inline void getVisualWeight( inline void getVisualWeight(
const Qwen3vlMeta *meta, std::shared_ptr<Qwen3vlVisualEncoderWeight> w_vis) { const Qwen3vlMeta *meta, std::shared_ptr<Qwen3vlVisualEncoderWeight> w_vis) {
Qwen3vlVisMeta vis_meta = meta->vis_meta; Qwen3vlVisMeta vis_meta = meta->vis_meta;
auto patch_embed_shape = std::vector<size_t>({vis_meta.hidden_size , vis_meta.in_channels, vis_meta.temporal_patch_size, vis_meta.patch_size, vis_meta.patch_size}); auto patch_embed_shape = std::vector<size_t>({vis_meta.hidden_size, vis_meta.in_channels, vis_meta.temporal_patch_size, vis_meta.patch_size, vis_meta.patch_size});
w_vis->patch_embed_weight = Tensor::weight(nullptr, meta->dtype, patch_embed_shape); w_vis->patch_embed_weight = Tensor::weight(nullptr, meta->dtype, patch_embed_shape);
w_vis->patch_embed_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size}); w_vis->patch_embed_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size});
w_vis->pos_embed_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.num_position_embeddings, vis_meta.hidden_size}); w_vis->pos_embed_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.num_position_embeddings, vis_meta.hidden_size});
...@@ -63,11 +62,11 @@ inline void getVisualWeight( ...@@ -63,11 +62,11 @@ inline void getVisualWeight(
w_vis->merger->norm_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size}); w_vis->merger->norm_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size});
w_vis->merger->norm_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size}); w_vis->merger->norm_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size});
w_vis->blocks = std::vector<Qwen3vlVisBlockWeight>(vis_meta.depth); w_vis->blocks = std::vector<Qwen3vlVisBlockWeight>(vis_meta.depth);
for (size_t i = 0; i < vis_meta.depth; i++) { for (size_t i = 0; i < vis_meta.depth; i++) {
w_vis->blocks[i].attn_proj_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size,vis_meta.hidden_size}); w_vis->blocks[i].attn_proj_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size, vis_meta.hidden_size});
w_vis->blocks[i].attn_proj_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size}); w_vis->blocks[i].attn_proj_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size});
w_vis->blocks[i].attn_qkv_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.in_channels*vis_meta.hidden_size,vis_meta.hidden_size}); w_vis->blocks[i].attn_qkv_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.in_channels * vis_meta.hidden_size, vis_meta.hidden_size});
w_vis->blocks[i].attn_qkv_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.in_channels*vis_meta.hidden_size}); w_vis->blocks[i].attn_qkv_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.in_channels * vis_meta.hidden_size});
w_vis->blocks[i].mlp_linear_fc1_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.intermediate_size, vis_meta.hidden_size}); w_vis->blocks[i].mlp_linear_fc1_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.intermediate_size, vis_meta.hidden_size});
w_vis->blocks[i].mlp_linear_fc1_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.intermediate_size}); w_vis->blocks[i].mlp_linear_fc1_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.intermediate_size});
w_vis->blocks[i].mlp_linear_fc2_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size, vis_meta.intermediate_size}); w_vis->blocks[i].mlp_linear_fc2_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size, vis_meta.intermediate_size});
...@@ -78,18 +77,16 @@ inline void getVisualWeight( ...@@ -78,18 +77,16 @@ inline void getVisualWeight(
w_vis->blocks[i].norm2_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size}); w_vis->blocks[i].norm2_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size});
} }
w_vis->deepstack_mergers = std::vector<DeepstackMergerWeight>(3); w_vis->deepstack_mergers = std::vector<DeepstackMergerWeight>(3);
for (size_t i = 0; i < 3; i++){ for (size_t i = 0; i < 3; i++) {
w_vis->deepstack_mergers[i].linear_fc1_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.intermediate_size,vis_meta.intermediate_size}); w_vis->deepstack_mergers[i].linear_fc1_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.intermediate_size, vis_meta.intermediate_size});
w_vis->deepstack_mergers[i].linear_fc2_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.out_hidden_size,vis_meta.intermediate_size}); w_vis->deepstack_mergers[i].linear_fc2_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.out_hidden_size, vis_meta.intermediate_size});
w_vis->deepstack_mergers[i].linear_fc1_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.intermediate_size}); w_vis->deepstack_mergers[i].linear_fc1_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.intermediate_size});
w_vis->deepstack_mergers[i].linear_fc2_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.out_hidden_size}); w_vis->deepstack_mergers[i].linear_fc2_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.out_hidden_size});
w_vis->deepstack_mergers[i].norm_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.intermediate_size}); w_vis->deepstack_mergers[i].norm_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.intermediate_size});
w_vis->deepstack_mergers[i].norm_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.intermediate_size}); w_vis->deepstack_mergers[i].norm_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.intermediate_size});
} }
} }
inline std::shared_ptr<Tensor> getSinTable(const Qwen3vlMeta *meta) { inline std::shared_ptr<Tensor> getSinTable(const Qwen3vlMeta *meta) {
auto half_dh = meta->text_meta.head_dim / 2; auto half_dh = meta->text_meta.head_dim / 2;
auto unit = dsize(meta->dtype); auto unit = dsize(meta->dtype);
...@@ -172,7 +169,6 @@ Qwen3vlWeights::Qwen3vlWeights( ...@@ -172,7 +169,6 @@ Qwen3vlWeights::Qwen3vlWeights(
} }
getVisualWeight(meta, device_weights[dev]->w_vis); getVisualWeight(meta, device_weights[dev]->w_vis);
} }
} }
...@@ -201,8 +197,8 @@ void load_output_embd(Qwen3vlWeights *weights, void *cpu_ptr) { ...@@ -201,8 +197,8 @@ void load_output_embd(Qwen3vlWeights *weights, void *cpu_ptr) {
auto weight = weights->device_weights[dev]; auto weight = weights->device_weights[dev];
RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id));
weight->w_lang->out_embd->load(cpu_ptr, weight->load_stream); weight->w_lang->out_embd->load(cpu_ptr, weight->load_stream);
if(weights->transpose_weight) { if (weights->transpose_weight) {
weight->w_lang->out_embd->permute({1,0}); //[d,voc] weight->w_lang->out_embd->permute({1, 0}); //[d,voc]
} }
} }
} }
...@@ -239,9 +235,8 @@ void load_attn_qkv_proj(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) { ...@@ -239,9 +235,8 @@ void load_attn_qkv_proj(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) {
size_t offset = idev * ((nkvh * 2 + nh) / ndev * dh) * d * dsize(weights->meta->dtype); size_t offset = idev * ((nkvh * 2 + nh) / ndev * dh) * d * dsize(weights->meta->dtype);
RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id));
weight->w_lang->layers[layer].attn_qkv_proj->load((char *)cpu_ptr + offset, weight->load_stream); weight->w_lang->layers[layer].attn_qkv_proj->load((char *)cpu_ptr + offset, weight->load_stream);
if(weights->transpose_weight) { if (weights->transpose_weight) {
weight->w_lang->layers[layer].attn_qkv_proj = weight->w_lang->layers[layer].attn_qkv_proj = weight->w_lang->layers[layer].attn_qkv_proj->permute({1, 0}); //[d, (nh+2*nkvh)*dh]
weight->w_lang->layers[layer].attn_qkv_proj->permute({1,0}); //[d, (nh+2*nkvh)*dh]
} }
} }
} }
...@@ -267,9 +262,8 @@ void load_attn_o_proj(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) { ...@@ -267,9 +262,8 @@ void load_attn_o_proj(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) {
size_t offset = idev * d * (nh / ndev * dh) * dsize(weights->meta->dtype); size_t offset = idev * d * (nh / ndev * dh) * dsize(weights->meta->dtype);
RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id));
weight->w_lang->layers[layer].attn_o_proj->load((char *)cpu_ptr + offset, weight->load_stream); weight->w_lang->layers[layer].attn_o_proj->load((char *)cpu_ptr + offset, weight->load_stream);
if(weights->transpose_weight) { if (weights->transpose_weight) {
weight->w_lang->layers[layer].attn_o_proj = weight->w_lang->layers[layer].attn_o_proj = weight->w_lang->layers[layer].attn_o_proj->permute({1, 0}); //[nh/ndev*dh, d]
weight->w_lang->layers[layer].attn_o_proj->permute({1,0}); //[nh/ndev*dh, d]
} }
} }
} }
...@@ -295,9 +289,8 @@ void load_mlp_gate_up(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) { ...@@ -295,9 +289,8 @@ void load_mlp_gate_up(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) {
size_t offset = idev * (2 * di / ndev) * d * dsize(weights->meta->dtype); size_t offset = idev * (2 * di / ndev) * d * dsize(weights->meta->dtype);
RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id));
weight->w_lang->layers[layer].mlp_gate_up->load((char *)cpu_ptr + offset, weight->load_stream); weight->w_lang->layers[layer].mlp_gate_up->load((char *)cpu_ptr + offset, weight->load_stream);
if(weights->transpose_weight) { if (weights->transpose_weight) {
weight->w_lang->layers[layer].mlp_gate_up = weight->w_lang->layers[layer].mlp_gate_up = weight->w_lang->layers[layer].mlp_gate_up->permute({1, 0}); //[d, 2*di/ndev]
weight->w_lang->layers[layer].mlp_gate_up->permute({1,0}); //[d, 2*di/ndev]
} }
} }
} }
...@@ -313,10 +306,9 @@ void load_mlp_down(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) { ...@@ -313,10 +306,9 @@ void load_mlp_down(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) {
size_t offset = idev * d * (di / ndev) * dsize(weights->meta->dtype); size_t offset = idev * d * (di / ndev) * dsize(weights->meta->dtype);
RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id));
weight->w_lang->layers[layer].mlp_down->load((char *)cpu_ptr + offset, weight->load_stream); weight->w_lang->layers[layer].mlp_down->load((char *)cpu_ptr + offset, weight->load_stream);
if(weights->transpose_weight) { if (weights->transpose_weight) {
weight->w_lang->layers[layer].mlp_down = weight->w_lang->layers[layer].mlp_down = weight->w_lang->layers[layer].mlp_down->permute({1, 0}); //[di/ndev, d]
weight->w_lang->layers[layer].mlp_down->permute({1,0}); //[di/ndev, d] }
}
} }
} }
...@@ -569,7 +561,6 @@ void load_merger_norm_bias(Qwen3vlWeights *weights, void *cpu_ptr) { ...@@ -569,7 +561,6 @@ void load_merger_norm_bias(Qwen3vlWeights *weights, void *cpu_ptr) {
} }
} }
static Qwen3vlWeightLoader weight_loader = { static Qwen3vlWeightLoader weight_loader = {
// Language model loaders // Language model loaders
.lang_loader = { .lang_loader = {
...@@ -614,15 +605,14 @@ static Qwen3vlWeightLoader weight_loader = { ...@@ -614,15 +605,14 @@ static Qwen3vlWeightLoader weight_loader = {
.load_merger_linear_fc2_bias = load_merger_linear_fc2_bias, .load_merger_linear_fc2_bias = load_merger_linear_fc2_bias,
.load_merger_norm_weight = load_merger_norm_weight, .load_merger_norm_weight = load_merger_norm_weight,
.load_merger_norm_bias = load_merger_norm_bias, .load_merger_norm_bias = load_merger_norm_bias,
} }};
};
__C Qwen3vlWeights * __INFINI_C Qwen3vlWeights *
createQwen3vlWeights(const Qwen3vlMeta *meta, createQwen3vlWeights(const Qwen3vlMeta *meta,
infiniDevice_t device, infiniDevice_t device,
int ndev, int ndev,
const int *dev_ids, const int *dev_ids,
bool transpose_weight) { bool transpose_weight) {
printf("=== C++ createQwen3vlWeights ===\n"); printf("=== C++ createQwen3vlWeights ===\n");
printf("sizeof(Qwen3vlTextMeta): %zu\n", sizeof(Qwen3vlTextMeta)); printf("sizeof(Qwen3vlTextMeta): %zu\n", sizeof(Qwen3vlTextMeta));
...@@ -640,7 +630,7 @@ createQwen3vlWeights(const Qwen3vlMeta *meta, ...@@ -640,7 +630,7 @@ createQwen3vlWeights(const Qwen3vlMeta *meta,
return weights; return weights;
}; };
__C Qwen3vlWeightLoader * __INFINI_C Qwen3vlWeightLoader *
createQwen3vlWeightLoader() { createQwen3vlWeightLoader() {
return &weight_loader; return &weight_loader;
} }
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment