Commit b1f6af34 authored by PanZezhong's avatar PanZezhong
Browse files

issue/263 fix T2-1-4

parent d1f29df0
...@@ -46,19 +46,19 @@ typedef struct { ...@@ -46,19 +46,19 @@ typedef struct {
qwen3vl_load_layer_fn load_attn_qkv_weight; qwen3vl_load_layer_fn load_attn_qkv_weight;
qwen3vl_load_layer_fn load_attn_qkv_bias; qwen3vl_load_layer_fn load_attn_qkv_bias;
//block mlp // block mlp
qwen3vl_load_layer_fn load_mlp_linear_fc1_weight; qwen3vl_load_layer_fn load_mlp_linear_fc1_weight;
qwen3vl_load_layer_fn load_mlp_linear_fc1_bias; qwen3vl_load_layer_fn load_mlp_linear_fc1_bias;
qwen3vl_load_layer_fn load_mlp_linear_fc2_weight; qwen3vl_load_layer_fn load_mlp_linear_fc2_weight;
qwen3vl_load_layer_fn load_mlp_linear_fc2_bias; qwen3vl_load_layer_fn load_mlp_linear_fc2_bias;
//block norm // block norm
qwen3vl_load_layer_fn load_norm1_weight; qwen3vl_load_layer_fn load_norm1_weight;
qwen3vl_load_layer_fn load_norm1_bias; qwen3vl_load_layer_fn load_norm1_bias;
qwen3vl_load_layer_fn load_norm2_weight; qwen3vl_load_layer_fn load_norm2_weight;
qwen3vl_load_layer_fn load_norm2_bias; qwen3vl_load_layer_fn load_norm2_bias;
//deepstack_merger // deepstack_merger
qwen3vl_load_layer_fn load_deepstack_merger_linear_fc1_weight; qwen3vl_load_layer_fn load_deepstack_merger_linear_fc1_weight;
qwen3vl_load_layer_fn load_deepstack_merger_linear_fc1_bias; qwen3vl_load_layer_fn load_deepstack_merger_linear_fc1_bias;
qwen3vl_load_layer_fn load_deepstack_merger_linear_fc2_weight; qwen3vl_load_layer_fn load_deepstack_merger_linear_fc2_weight;
...@@ -66,7 +66,7 @@ typedef struct { ...@@ -66,7 +66,7 @@ typedef struct {
qwen3vl_load_layer_fn load_deepstack_merger_norm_weight; qwen3vl_load_layer_fn load_deepstack_merger_norm_weight;
qwen3vl_load_layer_fn load_deepstack_merger_norm_bias; qwen3vl_load_layer_fn load_deepstack_merger_norm_bias;
//merger // merger
qwen3vl_load_global_fn load_merger_linear_fc1_weight; qwen3vl_load_global_fn load_merger_linear_fc1_weight;
qwen3vl_load_global_fn load_merger_linear_fc1_bias; qwen3vl_load_global_fn load_merger_linear_fc1_bias;
qwen3vl_load_global_fn load_merger_linear_fc2_weight; qwen3vl_load_global_fn load_merger_linear_fc2_weight;
...@@ -116,7 +116,7 @@ typedef struct { ...@@ -116,7 +116,7 @@ typedef struct {
} Qwen3vlVisMeta; } Qwen3vlVisMeta;
typedef struct { typedef struct {
infiniDtype_t dtype; //INFINI_DTYPE_BF16 infiniDtype_t dtype; // INFINI_DTYPE_BF16
Qwen3vlTextMeta text_meta; Qwen3vlTextMeta text_meta;
Qwen3vlVisMeta vis_meta; Qwen3vlVisMeta vis_meta;
...@@ -132,27 +132,27 @@ typedef struct { ...@@ -132,27 +132,27 @@ typedef struct {
/// @param device 协处理器种类 /// @param device 协处理器种类
/// @param ndev 协处理器数量 /// @param ndev 协处理器数量
/// @param dev_ids 协处理器编号,长度为 ndev /// @param dev_ids 协处理器编号,长度为 ndev
__C __export struct Qwen3vlModel * __INFINI_C __export struct Qwen3vlModel *
createQwen3vlModel(const Qwen3vlMeta *, createQwen3vlModel(const Qwen3vlMeta *,
const Qwen3vlWeights *); const Qwen3vlWeights *);
__C Qwen3vlWeights * __INFINI_C Qwen3vlWeights *
createQwen3vlWeights(const Qwen3vlMeta *meta, createQwen3vlWeights(const Qwen3vlMeta *meta,
infiniDevice_t device, infiniDevice_t device,
int ndev, int ndev,
const int *dev_ids, const int *dev_ids,
bool transpose_weight); bool transpose_weight);
__C __export Qwen3vlWeightLoader * __INFINI_C __export Qwen3vlWeightLoader *
createQwen3vlWeightLoader(); createQwen3vlWeightLoader();
/// @brief 销毁模型 /// @brief 销毁模型
__C __export void destroyQwen3vlModel(struct Qwen3vlModel *); __INFINI_C __export void destroyQwen3vlModel(struct Qwen3vlModel *);
__C __export struct Qwen3vlCache * __INFINI_C __export struct Qwen3vlCache *
createQwen3vlCache(const struct Qwen3vlModel *); createQwen3vlCache(const struct Qwen3vlModel *);
__C __export void __INFINI_C __export void
dropQwen3vlCache(const struct Qwen3vlModel *, dropQwen3vlCache(const struct Qwen3vlModel *,
struct Qwen3vlCache *); struct Qwen3vlCache *);
...@@ -167,7 +167,7 @@ dropQwen3vlCache(const struct Qwen3vlModel *, ...@@ -167,7 +167,7 @@ dropQwen3vlCache(const struct Qwen3vlModel *,
/// @param topk 采样 topk(1 表示贪心采样) /// @param topk 采样 topk(1 表示贪心采样)
/// @param topp 采样 topp /// @param topp 采样 topp
/// @param output 输出 token 数组,每个请求一个输出,长度至少为nreq /// @param output 输出 token 数组,每个请求一个输出,长度至少为nreq
__C __export void __INFINI_C __export void
inferBatchQwen3vl(struct Qwen3vlModel *, inferBatchQwen3vl(struct Qwen3vlModel *,
const uint32_t *tokens, uint32_t ntok, const uint32_t *tokens, uint32_t ntok,
void *pixel_values, uint32_t total_patches, void *pixel_values, uint32_t total_patches,
...@@ -188,7 +188,7 @@ inferBatchQwen3vl(struct Qwen3vlModel *, ...@@ -188,7 +188,7 @@ inferBatchQwen3vl(struct Qwen3vlModel *,
/// @param req_pos 每个请求的起始位置 /// @param req_pos 每个请求的起始位置
/// @param kv_caches 每个请求的 KV Cache /// @param kv_caches 每个请求的 KV Cache
/// @param logits 输出 token 数组,每个请求一个输出,长度至少为nreq /// @param logits 输出 token 数组,每个请求一个输出,长度至少为nreq
__C __export void __INFINI_C __export void
forwardBatchQwen3vl(struct Qwen3vlModel *, forwardBatchQwen3vl(struct Qwen3vlModel *,
const uint32_t *tokens, uint32_t ntok, const uint32_t *tokens, uint32_t ntok,
void *pixel_values, uint32_t total_patches, void *pixel_values, uint32_t total_patches,
......
...@@ -6,6 +6,7 @@ from .deepseek_v3 import ( ...@@ -6,6 +6,7 @@ from .deepseek_v3 import (
DeepSeekV3MetaCStruct, DeepSeekV3MetaCStruct,
DeepSeekV3WeightsCStruct, DeepSeekV3WeightsCStruct,
DeepSeekV3WeightLoaderCStruct, DeepSeekV3WeightLoaderCStruct,
DeepSeekV3CacheCStruct,
) )
from .qwen3vl import ( from .qwen3vl import (
Qwen3vlModel, Qwen3vlModel,
...@@ -33,6 +34,7 @@ __all__ = [ ...@@ -33,6 +34,7 @@ __all__ = [
"DeepSeekV3MetaCStruct", "DeepSeekV3MetaCStruct",
"DeepSeekV3WeightsCStruct", "DeepSeekV3WeightsCStruct",
"DeepSeekV3WeightLoaderCStruct", "DeepSeekV3WeightLoaderCStruct",
"DeepSeekV3CacheCStruct",
"Qwen3vlModel", "Qwen3vlModel",
"Qwen3vlMetaCStruct", "Qwen3vlMetaCStruct",
"TextMetaCStruct", "TextMetaCStruct",
......
...@@ -226,7 +226,9 @@ class Qwen3vlModel(BaseModel): ...@@ -226,7 +226,9 @@ class Qwen3vlModel(BaseModel):
return self.lib.createQwen3vlWeightLoader() return self.lib.createQwen3vlWeightLoader()
def create_weights(self, meta, device_type, ndev, dev_ids, transpose_weight): def create_weights(self, meta, device_type, ndev, dev_ids, transpose_weight):
return self.lib.createQwen3vlWeights(meta, device_type, ndev, dev_ids, transpose_weight) return self.lib.createQwen3vlWeights(
meta, device_type, ndev, dev_ids, transpose_weight
)
def create_model(self, meta, weights): def create_model(self, meta, weights):
return self.lib.createQwen3vlModel(meta, weights) return self.lib.createQwen3vlModel(meta, weights)
......
...@@ -25,6 +25,7 @@ import json ...@@ -25,6 +25,7 @@ import json
import math import math
import torch import torch
import transformers import transformers
torch.set_default_device("cpu") torch.set_default_device("cpu")
...@@ -71,116 +72,155 @@ class Qwen3vlLangWeightsNaming: ...@@ -71,116 +72,155 @@ class Qwen3vlLangWeightsNaming:
def mlp_up(self, i): def mlp_up(self, i):
return f"model.language_model.layers.{i}.mlp.up_proj.weight" return f"model.language_model.layers.{i}.mlp.up_proj.weight"
class Qwen3vlVisWeightsNaming: class Qwen3vlVisWeightsNaming:
def patch_embed_weight(self): def patch_embed_weight(self):
return "model.visual.patch_embed.proj.weight" return "model.visual.patch_embed.proj.weight"
def patch_embed_bias(self): def patch_embed_bias(self):
return "model.visual.patch_embed.proj.bias" return "model.visual.patch_embed.proj.bias"
def pos_embed_weight(self): def pos_embed_weight(self):
return "model.visual.pos_embed.weight" return "model.visual.pos_embed.weight"
def attn_proj_weight(self,i):
def attn_proj_weight(self, i):
return f"model.visual.blocks.{i}.attn.proj.weight" return f"model.visual.blocks.{i}.attn.proj.weight"
def attn_proj_bias(self,i):
def attn_proj_bias(self, i):
return f"model.visual.blocks.{i}.attn.proj.bias" return f"model.visual.blocks.{i}.attn.proj.bias"
def attn_qkv_weight(self,i):
def attn_qkv_weight(self, i):
return f"model.visual.blocks.{i}.attn.qkv.weight" return f"model.visual.blocks.{i}.attn.qkv.weight"
def attn_qkv_bias(self,i):
def attn_qkv_bias(self, i):
return f"model.visual.blocks.{i}.attn.qkv.bias" return f"model.visual.blocks.{i}.attn.qkv.bias"
def mlp_linear_fc1_weight(self,i):
def mlp_linear_fc1_weight(self, i):
return f"model.visual.blocks.{i}.mlp.linear_fc1.weight" return f"model.visual.blocks.{i}.mlp.linear_fc1.weight"
def mlp_linear_fc1_bias(self,i):
def mlp_linear_fc1_bias(self, i):
return f"model.visual.blocks.{i}.mlp.linear_fc1.bias" return f"model.visual.blocks.{i}.mlp.linear_fc1.bias"
def mlp_linear_fc2_weight(self,i):
def mlp_linear_fc2_weight(self, i):
return f"model.visual.blocks.{i}.mlp.linear_fc2.weight" return f"model.visual.blocks.{i}.mlp.linear_fc2.weight"
def mlp_linear_fc2_bias(self,i):
def mlp_linear_fc2_bias(self, i):
return f"model.visual.blocks.{i}.mlp.linear_fc2.bias" return f"model.visual.blocks.{i}.mlp.linear_fc2.bias"
def norm1_weight(self,i):
def norm1_weight(self, i):
return f"model.visual.blocks.{i}.norm1.weight" return f"model.visual.blocks.{i}.norm1.weight"
def norm1_bias(self,i):
def norm1_bias(self, i):
return f"model.visual.blocks.{i}.norm1.bias" return f"model.visual.blocks.{i}.norm1.bias"
def norm2_weight(self,i):
def norm2_weight(self, i):
return f"model.visual.blocks.{i}.norm2.weight" return f"model.visual.blocks.{i}.norm2.weight"
def norm2_bias(self,i):
def norm2_bias(self, i):
return f"model.visual.blocks.{i}.norm2.bias" return f"model.visual.blocks.{i}.norm2.bias"
def deepstack_merger_linear_fc1_weight(self,i):
def deepstack_merger_linear_fc1_weight(self, i):
return f"model.visual.deepstack_merger_list.{i}.linear_fc1.weight" return f"model.visual.deepstack_merger_list.{i}.linear_fc1.weight"
def deepstack_merger_linear_fc1_bias(self,i):
def deepstack_merger_linear_fc1_bias(self, i):
return f"model.visual.deepstack_merger_list.{i}.linear_fc1.bias" return f"model.visual.deepstack_merger_list.{i}.linear_fc1.bias"
def deepstack_merger_linear_fc2_weight(self,i):
def deepstack_merger_linear_fc2_weight(self, i):
return f"model.visual.deepstack_merger_list.{i}.linear_fc2.weight" return f"model.visual.deepstack_merger_list.{i}.linear_fc2.weight"
def deepstack_merger_linear_fc2_bias(self,i):
def deepstack_merger_linear_fc2_bias(self, i):
return f"model.visual.deepstack_merger_list.{i}.linear_fc2.bias" return f"model.visual.deepstack_merger_list.{i}.linear_fc2.bias"
def deepstack_merger_norm_weight(self,i):
def deepstack_merger_norm_weight(self, i):
return f"model.visual.deepstack_merger_list.{i}.norm.weight" return f"model.visual.deepstack_merger_list.{i}.norm.weight"
def deepstack_merger_norm_bias(self,i):
def deepstack_merger_norm_bias(self, i):
return f"model.visual.deepstack_merger_list.{i}.norm.bias" return f"model.visual.deepstack_merger_list.{i}.norm.bias"
def merger_linear_fc1_weight(self): def merger_linear_fc1_weight(self):
return "model.visual.merger.linear_fc1.weight" return "model.visual.merger.linear_fc1.weight"
def merger_linear_fc1_bias(self): def merger_linear_fc1_bias(self):
return "model.visual.merger.linear_fc1.bias" return "model.visual.merger.linear_fc1.bias"
def merger_linear_fc2_weight(self): def merger_linear_fc2_weight(self):
return "model.visual.merger.linear_fc2.weight" return "model.visual.merger.linear_fc2.weight"
def merger_linear_fc2_bias(self): def merger_linear_fc2_bias(self):
return "model.visual.merger.linear_fc2.bias" return "model.visual.merger.linear_fc2.bias"
def merger_norm_weight(self): def merger_norm_weight(self):
return "model.visual.merger.norm.weight" return "model.visual.merger.norm.weight"
def merger_norm_bias(self): def merger_norm_bias(self):
return "model.visual.merger.norm.bias" return "model.visual.merger.norm.bias"
class Qwen3vlMeta(Qwen3vlMetaCStruct): class Qwen3vlMeta(Qwen3vlMetaCStruct):
def __init__(self, config, max_tokens=None): def __init__(self, config, max_tokens=None):
if config["text_config"]["dtype"] == "float16":
if config['text_config']['dtype'] == 'float16':
dt_ = DataType.INFINI_DTYPE_F16 dt_ = DataType.INFINI_DTYPE_F16
self.torch_dtype = torch.float16 self.torch_dtype = torch.float16
elif config['text_config']['dtype'] == 'float32': elif config["text_config"]["dtype"] == "float32":
dt_ = DataType.INFINI_DTYPE_F32 dt_ = DataType.INFINI_DTYPE_F32
self.torch_dtype = torch.float32 self.torch_dtype = torch.float32
elif config['text_config']['dtype'] == 'bfloat16': elif config["text_config"]["dtype"] == "bfloat16":
dt_ = DataType.INFINI_DTYPE_BF16 dt_ = DataType.INFINI_DTYPE_BF16
self.torch_dtype = torch.bfloat16 self.torch_dtype = torch.bfloat16
else: else:
raise ValueError(f"Unsupported text dtype: {config['text_config']['dtype']}") raise ValueError(
f"Unsupported text dtype: {config['text_config']['dtype']}"
)
super().__init__( super().__init__(
dtype = dt_, dtype=dt_,
image_token_id = config['image_token_id'], image_token_id=config["image_token_id"],
video_token_id = config['video_token_id'], video_token_id=config["video_token_id"],
vision_end_token_id = config['vision_end_token_id'], vision_end_token_id=config["vision_end_token_id"],
vision_start_token_id = config['vision_start_token_id'], vision_start_token_id=config["vision_start_token_id"],
text_meta = TextMetaCStruct( text_meta=TextMetaCStruct(
bos_token_id = config['text_config']['bos_token_id'], bos_token_id=config["text_config"]["bos_token_id"],
eos_token_id = config['text_config']['eos_token_id'], eos_token_id=config["text_config"]["eos_token_id"],
head_dim = config['text_config']['head_dim'], head_dim=config["text_config"]["head_dim"],
hidden_size = config['text_config']['hidden_size'], hidden_size=config["text_config"]["hidden_size"],
initializer_range = config['text_config']['initializer_range'], initializer_range=config["text_config"]["initializer_range"],
intermediate_size = config['text_config']['intermediate_size'], intermediate_size=config["text_config"]["intermediate_size"],
max_tokens = (config['text_config']['max_position_embeddings'] if max_tokens is None else max_tokens), max_tokens=(
num_attention_heads = config['text_config']['num_attention_heads'], config["text_config"]["max_position_embeddings"]
num_hidden_layers = config['text_config']['num_hidden_layers'], if max_tokens is None
num_key_value_heads = config['text_config']['num_key_value_heads'], else max_tokens
rms_norm_eps = config['text_config']['rms_norm_eps'], ),
mrope_section = (ctypes.c_ulong * 3)(*config['text_config']['rope_scaling']['mrope_section']), num_attention_heads=config["text_config"]["num_attention_heads"],
rope_theta = config['text_config']['rope_theta'], num_hidden_layers=config["text_config"]["num_hidden_layers"],
vocab_size = config['text_config']['vocab_size'], num_key_value_heads=config["text_config"]["num_key_value_heads"],
rms_norm_eps=config["text_config"]["rms_norm_eps"],
mrope_section=(ctypes.c_ulong * 3)(
*config["text_config"]["rope_scaling"]["mrope_section"]
),
rope_theta=config["text_config"]["rope_theta"],
vocab_size=config["text_config"]["vocab_size"],
),
vis_meta=VisMetaCStruct(
depth=config["vision_config"]["depth"],
deepstack_visual_indexes=(ctypes.c_ulong * 3)(
*config["vision_config"]["deepstack_visual_indexes"]
),
hidden_size=config["vision_config"]["hidden_size"],
in_channels=config["vision_config"]["in_channels"],
initializer_range=config["vision_config"]["initializer_range"],
intermediate_size=config["vision_config"]["intermediate_size"],
num_heads=config["vision_config"]["num_heads"],
num_position_embeddings=config["vision_config"][
"num_position_embeddings"
],
out_hidden_size=config["vision_config"]["out_hidden_size"],
patch_size=config["vision_config"]["patch_size"],
spatial_merge_size=config["vision_config"]["spatial_merge_size"],
temporal_patch_size=config["vision_config"]["temporal_patch_size"],
), ),
vis_meta = VisMetaCStruct(
depth = config['vision_config']['depth'],
deepstack_visual_indexes = (ctypes.c_ulong * 3)(*config['vision_config']['deepstack_visual_indexes']),
hidden_size = config['vision_config']['hidden_size'],
in_channels = config['vision_config']['in_channels'],
initializer_range = config['vision_config']['initializer_range'],
intermediate_size = config['vision_config']['intermediate_size'],
num_heads = config['vision_config']['num_heads'],
num_position_embeddings = config['vision_config']['num_position_embeddings'],
out_hidden_size = config['vision_config']['out_hidden_size'],
patch_size = config['vision_config']['patch_size'],
spatial_merge_size = config['vision_config']['spatial_merge_size'],
temporal_patch_size = config['vision_config']['temporal_patch_size']
)
) )
def load_specific_tensor(model_dir, tensor_name): def load_specific_tensor(model_dir, tensor_name):
""" """
Load a specific tensor from a safetensors model. Load a specific tensor from a safetensors model.
...@@ -206,6 +246,7 @@ def load_specific_tensor(model_dir, tensor_name): ...@@ -206,6 +246,7 @@ def load_specific_tensor(model_dir, tensor_name):
# If we reach here, tensor was not found in any file # If we reach here, tensor was not found in any file
raise KeyError(f"{tensor_name} not found in any .safetensors files") raise KeyError(f"{tensor_name} not found in any .safetensors files")
def load_Qwen3vl_weights( def load_Qwen3vl_weights(
meta: Qwen3vlMeta, meta: Qwen3vlMeta,
weights, weights,
...@@ -233,30 +274,40 @@ def load_Qwen3vl_weights( ...@@ -233,30 +274,40 @@ def load_Qwen3vl_weights(
# ------------------------------- # -------------------------------
# Language_model weights # Language_model weights
# ------------------------------- # -------------------------------
input_embd = load_specific_tensor(model_path, lang_names.input_embd()).to(meta.torch_dtype) input_embd = load_specific_tensor(model_path, lang_names.input_embd()).to(
meta.torch_dtype
)
weight_loader.contents.lang_loader.load_input_embd(weights, input_embd.data_ptr()) weight_loader.contents.lang_loader.load_input_embd(weights, input_embd.data_ptr())
del input_embd del input_embd
output_norm = load_specific_tensor(model_path, lang_names.output_norm()).to(meta.torch_dtype) output_norm = load_specific_tensor(model_path, lang_names.output_norm()).to(
meta.torch_dtype
)
weight_loader.contents.lang_loader.load_output_norm(weights, output_norm.data_ptr()) weight_loader.contents.lang_loader.load_output_norm(weights, output_norm.data_ptr())
del output_norm del output_norm
output_embd = load_specific_tensor(model_path, lang_names.output_embd()).to(meta.torch_dtype) output_embd = load_specific_tensor(model_path, lang_names.output_embd()).to(
meta.torch_dtype
)
weight_loader.contents.lang_loader.load_output_embd(weights, output_embd.data_ptr()) weight_loader.contents.lang_loader.load_output_embd(weights, output_embd.data_ptr())
del output_embd del output_embd
for i in range(meta.text_meta.num_hidden_layers): for i in range(meta.text_meta.num_hidden_layers):
attn_norm = load_specific_tensor(model_path, lang_names.attn_norm(i)).to(meta.torch_dtype) attn_norm = load_specific_tensor(model_path, lang_names.attn_norm(i)).to(
weight_loader.contents.lang_loader.load_attn_norm(weights, attn_norm.data_ptr(), i) meta.torch_dtype
)
weight_loader.contents.lang_loader.load_attn_norm(
weights, attn_norm.data_ptr(), i
)
del attn_norm del attn_norm
attn_q_proj = load_specific_tensor(model_path, lang_names.attn_q_proj(i)) attn_q_proj = load_specific_tensor(model_path, lang_names.attn_q_proj(i))
attn_k_proj = load_specific_tensor(model_path, lang_names.attn_k_proj(i)) attn_k_proj = load_specific_tensor(model_path, lang_names.attn_k_proj(i))
attn_v_proj = load_specific_tensor(model_path, lang_names.attn_v_proj(i)) attn_v_proj = load_specific_tensor(model_path, lang_names.attn_v_proj(i))
_Q = attn_q_proj.reshape(nh,dh,d) _Q = attn_q_proj.reshape(nh, dh, d)
_K = attn_k_proj.reshape(nkvh,dh,d) _K = attn_k_proj.reshape(nkvh, dh, d)
_V = attn_v_proj.reshape(nkvh,dh,d) _V = attn_v_proj.reshape(nkvh, dh, d)
qkv_proj = [] qkv_proj = []
_nh = nh // ndev _nh = nh // ndev
...@@ -267,24 +318,45 @@ def load_Qwen3vl_weights( ...@@ -267,24 +318,45 @@ def load_Qwen3vl_weights(
qkv_proj.append(_V[_idev * _nkvh : (_idev + 1) * _nkvh, :, :]) qkv_proj.append(_V[_idev * _nkvh : (_idev + 1) * _nkvh, :, :])
attn_qkv_proj = torch.cat(qkv_proj, dim=0).to(meta.torch_dtype).contiguous() attn_qkv_proj = torch.cat(qkv_proj, dim=0).to(meta.torch_dtype).contiguous()
weight_loader.contents.lang_loader.load_attn_qkv_proj(weights, attn_qkv_proj.data_ptr(), i) weight_loader.contents.lang_loader.load_attn_qkv_proj(
weights, attn_qkv_proj.data_ptr(), i
)
del attn_qkv_proj del attn_qkv_proj
attn_q_norm = load_specific_tensor(model_path, lang_names.attn_q_norm(i)).to(meta.torch_dtype) attn_q_norm = load_specific_tensor(model_path, lang_names.attn_q_norm(i)).to(
weight_loader.contents.lang_loader.load_attn_q_norm(weights, attn_q_norm.data_ptr(), i) meta.torch_dtype
)
weight_loader.contents.lang_loader.load_attn_q_norm(
weights, attn_q_norm.data_ptr(), i
)
del attn_q_norm del attn_q_norm
attn_k_norm = load_specific_tensor(model_path, lang_names.attn_k_norm(i)).to(meta.torch_dtype) attn_k_norm = load_specific_tensor(model_path, lang_names.attn_k_norm(i)).to(
weight_loader.contents.lang_loader.load_attn_k_norm(weights, attn_k_norm.data_ptr(), i) meta.torch_dtype
)
weight_loader.contents.lang_loader.load_attn_k_norm(
weights, attn_k_norm.data_ptr(), i
)
del attn_k_norm del attn_k_norm
attn_o_proj = load_specific_tensor(model_path, lang_names.attn_o_proj(i)) attn_o_proj = load_specific_tensor(model_path, lang_names.attn_o_proj(i))
attn_o_proj = attn_o_proj.to(meta.torch_dtype).reshape([d, ndev, nh // ndev * dh]).transpose(0, 1).contiguous() attn_o_proj = (
weight_loader.contents.lang_loader.load_attn_o_proj(weights, attn_o_proj.data_ptr(), i) attn_o_proj.to(meta.torch_dtype)
.reshape([d, ndev, nh // ndev * dh])
.transpose(0, 1)
.contiguous()
)
weight_loader.contents.lang_loader.load_attn_o_proj(
weights, attn_o_proj.data_ptr(), i
)
del attn_o_proj del attn_o_proj
mlp_norm = load_specific_tensor(model_path, lang_names.mlp_norm(i)).to(meta.torch_dtype) mlp_norm = load_specific_tensor(model_path, lang_names.mlp_norm(i)).to(
weight_loader.contents.lang_loader.load_mlp_norm(weights, mlp_norm.data_ptr(), i) meta.torch_dtype
)
weight_loader.contents.lang_loader.load_mlp_norm(
weights, mlp_norm.data_ptr(), i
)
del mlp_norm del mlp_norm
mlp_gate = load_specific_tensor(model_path, lang_names.mlp_gate(i)) mlp_gate = load_specific_tensor(model_path, lang_names.mlp_gate(i))
...@@ -299,130 +371,254 @@ def load_Qwen3vl_weights( ...@@ -299,130 +371,254 @@ def load_Qwen3vl_weights(
gate_up.append(mlp_up[_start:_end, :]) gate_up.append(mlp_up[_start:_end, :])
mlp_gate_up = torch.cat(gate_up, dim=0).to(meta.torch_dtype).contiguous() mlp_gate_up = torch.cat(gate_up, dim=0).to(meta.torch_dtype).contiguous()
weight_loader.contents.lang_loader.load_mlp_gate_up(weights, mlp_gate_up.data_ptr(), i) weight_loader.contents.lang_loader.load_mlp_gate_up(
weights, mlp_gate_up.data_ptr(), i
)
del mlp_gate_up del mlp_gate_up
mlp_down = load_specific_tensor(model_path, lang_names.mlp_down(i)) mlp_down = load_specific_tensor(model_path, lang_names.mlp_down(i))
mlp_down = mlp_down.to(meta.torch_dtype).reshape([d, ndev, di // ndev]).transpose(0, 1).contiguous() mlp_down = (
weight_loader.contents.lang_loader.load_mlp_down(weights, mlp_down.data_ptr(), i) mlp_down.to(meta.torch_dtype)
.reshape([d, ndev, di // ndev])
.transpose(0, 1)
.contiguous()
)
weight_loader.contents.lang_loader.load_mlp_down(
weights, mlp_down.data_ptr(), i
)
del mlp_down del mlp_down
# ------------------------------- # -------------------------------
# Vision head weights # Vision head weights
# ------------------------------- # -------------------------------
patch_embed_weight = load_specific_tensor(model_path, vis_names.patch_embed_weight()).to(meta.torch_dtype) patch_embed_weight = load_specific_tensor(
weight_loader.contents.vis_loader.load_patch_embed_weight(weights, patch_embed_weight.data_ptr()) model_path, vis_names.patch_embed_weight()
).to(meta.torch_dtype)
weight_loader.contents.vis_loader.load_patch_embed_weight(
weights, patch_embed_weight.data_ptr()
)
del patch_embed_weight del patch_embed_weight
patch_embed_bias = load_specific_tensor(model_path, vis_names.patch_embed_bias()).to(meta.torch_dtype) patch_embed_bias = load_specific_tensor(
weight_loader.contents.vis_loader.load_patch_embed_bias(weights, patch_embed_bias.data_ptr()) model_path, vis_names.patch_embed_bias()
).to(meta.torch_dtype)
weight_loader.contents.vis_loader.load_patch_embed_bias(
weights, patch_embed_bias.data_ptr()
)
del patch_embed_bias del patch_embed_bias
pos_embed_weight = load_specific_tensor(model_path, vis_names.pos_embed_weight()).to(meta.torch_dtype) pos_embed_weight = load_specific_tensor(
weight_loader.contents.vis_loader.load_pos_embed_weight(weights, pos_embed_weight.data_ptr()) model_path, vis_names.pos_embed_weight()
).to(meta.torch_dtype)
weight_loader.contents.vis_loader.load_pos_embed_weight(
weights, pos_embed_weight.data_ptr()
)
del pos_embed_weight del pos_embed_weight
for i in range(meta.vis_meta.depth): for i in range(meta.vis_meta.depth):
attn_proj_weight = load_specific_tensor(model_path, vis_names.attn_proj_weight(i)).to(meta.torch_dtype) attn_proj_weight = load_specific_tensor(
weight_loader.contents.vis_loader.load_attn_proj_weight(weights, attn_proj_weight.data_ptr(), i) model_path, vis_names.attn_proj_weight(i)
).to(meta.torch_dtype)
weight_loader.contents.vis_loader.load_attn_proj_weight(
weights, attn_proj_weight.data_ptr(), i
)
del attn_proj_weight del attn_proj_weight
attn_proj_bias = load_specific_tensor(model_path, vis_names.attn_proj_bias(i)).to(meta.torch_dtype) attn_proj_bias = load_specific_tensor(
weight_loader.contents.vis_loader.load_attn_proj_bias(weights, attn_proj_bias.data_ptr(), i) model_path, vis_names.attn_proj_bias(i)
).to(meta.torch_dtype)
weight_loader.contents.vis_loader.load_attn_proj_bias(
weights, attn_proj_bias.data_ptr(), i
)
del attn_proj_bias del attn_proj_bias
attn_qkv_weight = load_specific_tensor(model_path, vis_names.attn_qkv_weight(i)).to(meta.torch_dtype) attn_qkv_weight = load_specific_tensor(
weight_loader.contents.vis_loader.load_attn_qkv_weight(weights, attn_qkv_weight.data_ptr(), i) model_path, vis_names.attn_qkv_weight(i)
).to(meta.torch_dtype)
weight_loader.contents.vis_loader.load_attn_qkv_weight(
weights, attn_qkv_weight.data_ptr(), i
)
del attn_qkv_weight del attn_qkv_weight
attn_qkv_bias = load_specific_tensor(model_path, vis_names.attn_qkv_bias(i)).to(meta.torch_dtype) attn_qkv_bias = load_specific_tensor(model_path, vis_names.attn_qkv_bias(i)).to(
weight_loader.contents.vis_loader.load_attn_qkv_bias(weights, attn_qkv_bias.data_ptr(), i) meta.torch_dtype
)
weight_loader.contents.vis_loader.load_attn_qkv_bias(
weights, attn_qkv_bias.data_ptr(), i
)
del attn_qkv_bias del attn_qkv_bias
mlp_linear_fc1_weight = load_specific_tensor(model_path, vis_names.mlp_linear_fc1_weight(i)).to(meta.torch_dtype) mlp_linear_fc1_weight = load_specific_tensor(
weight_loader.contents.vis_loader.load_mlp_linear_fc1_weight(weights, mlp_linear_fc1_weight.data_ptr(), i) model_path, vis_names.mlp_linear_fc1_weight(i)
).to(meta.torch_dtype)
weight_loader.contents.vis_loader.load_mlp_linear_fc1_weight(
weights, mlp_linear_fc1_weight.data_ptr(), i
)
del mlp_linear_fc1_weight del mlp_linear_fc1_weight
mlp_linear_fc1_bias = load_specific_tensor(model_path, vis_names.mlp_linear_fc1_bias(i)).to(meta.torch_dtype) mlp_linear_fc1_bias = load_specific_tensor(
weight_loader.contents.vis_loader.load_mlp_linear_fc1_bias(weights, mlp_linear_fc1_bias.data_ptr(), i) model_path, vis_names.mlp_linear_fc1_bias(i)
).to(meta.torch_dtype)
weight_loader.contents.vis_loader.load_mlp_linear_fc1_bias(
weights, mlp_linear_fc1_bias.data_ptr(), i
)
del mlp_linear_fc1_bias del mlp_linear_fc1_bias
mlp_linear_fc2_weight = load_specific_tensor(model_path, vis_names.mlp_linear_fc2_weight(i)).to(meta.torch_dtype) mlp_linear_fc2_weight = load_specific_tensor(
weight_loader.contents.vis_loader.load_mlp_linear_fc2_weight(weights, mlp_linear_fc2_weight.data_ptr(), i) model_path, vis_names.mlp_linear_fc2_weight(i)
).to(meta.torch_dtype)
weight_loader.contents.vis_loader.load_mlp_linear_fc2_weight(
weights, mlp_linear_fc2_weight.data_ptr(), i
)
del mlp_linear_fc2_weight del mlp_linear_fc2_weight
mlp_linear_fc2_bias = load_specific_tensor(model_path, vis_names.mlp_linear_fc2_bias(i)).to(meta.torch_dtype) mlp_linear_fc2_bias = load_specific_tensor(
weight_loader.contents.vis_loader.load_mlp_linear_fc2_bias(weights, mlp_linear_fc2_bias.data_ptr(), i) model_path, vis_names.mlp_linear_fc2_bias(i)
).to(meta.torch_dtype)
weight_loader.contents.vis_loader.load_mlp_linear_fc2_bias(
weights, mlp_linear_fc2_bias.data_ptr(), i
)
del mlp_linear_fc2_bias del mlp_linear_fc2_bias
norm1_weight = load_specific_tensor(model_path, vis_names.norm1_weight(i)).to(meta.torch_dtype) norm1_weight = load_specific_tensor(model_path, vis_names.norm1_weight(i)).to(
weight_loader.contents.vis_loader.load_norm1_weight(weights, norm1_weight.data_ptr(), i) meta.torch_dtype
)
weight_loader.contents.vis_loader.load_norm1_weight(
weights, norm1_weight.data_ptr(), i
)
del norm1_weight del norm1_weight
norm1_bias = load_specific_tensor(model_path, vis_names.norm1_bias(i)).to(meta.torch_dtype) norm1_bias = load_specific_tensor(model_path, vis_names.norm1_bias(i)).to(
weight_loader.contents.vis_loader.load_norm1_bias(weights, norm1_bias.data_ptr(), i) meta.torch_dtype
)
weight_loader.contents.vis_loader.load_norm1_bias(
weights, norm1_bias.data_ptr(), i
)
del norm1_bias del norm1_bias
norm2_weight = load_specific_tensor(model_path, vis_names.norm2_weight(i)).to(meta.torch_dtype) norm2_weight = load_specific_tensor(model_path, vis_names.norm2_weight(i)).to(
weight_loader.contents.vis_loader.load_norm2_weight(weights, norm2_weight.data_ptr(), i) meta.torch_dtype
)
weight_loader.contents.vis_loader.load_norm2_weight(
weights, norm2_weight.data_ptr(), i
)
del norm2_weight del norm2_weight
norm2_bias = load_specific_tensor(model_path, vis_names.norm2_bias(i)).to(meta.torch_dtype) norm2_bias = load_specific_tensor(model_path, vis_names.norm2_bias(i)).to(
weight_loader.contents.vis_loader.load_norm2_bias(weights, norm2_bias.data_ptr(), i) meta.torch_dtype
)
weight_loader.contents.vis_loader.load_norm2_bias(
weights, norm2_bias.data_ptr(), i
)
del norm2_bias del norm2_bias
for i in range(len(meta.vis_meta.deepstack_visual_indexes)): for i in range(len(meta.vis_meta.deepstack_visual_indexes)):
deepstack_merger_linear_fc1_weight = load_specific_tensor(model_path, vis_names.deepstack_merger_linear_fc1_weight(i)).to(meta.torch_dtype) deepstack_merger_linear_fc1_weight = load_specific_tensor(
weight_loader.contents.vis_loader.load_deepstack_merger_linear_fc1_weight(weights, deepstack_merger_linear_fc1_weight.data_ptr(), i) model_path, vis_names.deepstack_merger_linear_fc1_weight(i)
).to(meta.torch_dtype)
weight_loader.contents.vis_loader.load_deepstack_merger_linear_fc1_weight(
weights, deepstack_merger_linear_fc1_weight.data_ptr(), i
)
del deepstack_merger_linear_fc1_weight del deepstack_merger_linear_fc1_weight
deepstack_merger_linear_fc1_bias = load_specific_tensor(model_path, vis_names.deepstack_merger_linear_fc1_bias(i)).to(meta.torch_dtype) deepstack_merger_linear_fc1_bias = load_specific_tensor(
weight_loader.contents.vis_loader.load_deepstack_merger_linear_fc1_bias(weights, deepstack_merger_linear_fc1_bias.data_ptr(), i) model_path, vis_names.deepstack_merger_linear_fc1_bias(i)
).to(meta.torch_dtype)
weight_loader.contents.vis_loader.load_deepstack_merger_linear_fc1_bias(
weights, deepstack_merger_linear_fc1_bias.data_ptr(), i
)
del deepstack_merger_linear_fc1_bias del deepstack_merger_linear_fc1_bias
deepstack_merger_linear_fc2_weight = load_specific_tensor(model_path, vis_names.deepstack_merger_linear_fc2_weight(i)).to(meta.torch_dtype) deepstack_merger_linear_fc2_weight = load_specific_tensor(
weight_loader.contents.vis_loader.load_deepstack_merger_linear_fc2_weight(weights, deepstack_merger_linear_fc2_weight.data_ptr(), i) model_path, vis_names.deepstack_merger_linear_fc2_weight(i)
).to(meta.torch_dtype)
weight_loader.contents.vis_loader.load_deepstack_merger_linear_fc2_weight(
weights, deepstack_merger_linear_fc2_weight.data_ptr(), i
)
del deepstack_merger_linear_fc2_weight del deepstack_merger_linear_fc2_weight
deepstack_merger_linear_fc2_bias = load_specific_tensor(model_path, vis_names.deepstack_merger_linear_fc2_bias(i)).to(meta.torch_dtype) deepstack_merger_linear_fc2_bias = load_specific_tensor(
weight_loader.contents.vis_loader.load_deepstack_merger_linear_fc2_bias(weights, deepstack_merger_linear_fc2_bias.data_ptr(), i) model_path, vis_names.deepstack_merger_linear_fc2_bias(i)
).to(meta.torch_dtype)
weight_loader.contents.vis_loader.load_deepstack_merger_linear_fc2_bias(
weights, deepstack_merger_linear_fc2_bias.data_ptr(), i
)
del deepstack_merger_linear_fc2_bias del deepstack_merger_linear_fc2_bias
deepstack_merger_norm_weight = load_specific_tensor(model_path, vis_names.deepstack_merger_norm_weight(i)).to(meta.torch_dtype) deepstack_merger_norm_weight = load_specific_tensor(
weight_loader.contents.vis_loader.load_deepstack_merger_norm_weight(weights, deepstack_merger_norm_weight.data_ptr(), i) model_path, vis_names.deepstack_merger_norm_weight(i)
).to(meta.torch_dtype)
weight_loader.contents.vis_loader.load_deepstack_merger_norm_weight(
weights, deepstack_merger_norm_weight.data_ptr(), i
)
del deepstack_merger_norm_weight del deepstack_merger_norm_weight
deepstack_merger_norm_bias = load_specific_tensor(model_path, vis_names.deepstack_merger_norm_bias(i)).to(meta.torch_dtype) deepstack_merger_norm_bias = load_specific_tensor(
weight_loader.contents.vis_loader.load_deepstack_merger_norm_bias(weights, deepstack_merger_norm_bias.data_ptr(), i) model_path, vis_names.deepstack_merger_norm_bias(i)
).to(meta.torch_dtype)
weight_loader.contents.vis_loader.load_deepstack_merger_norm_bias(
weights, deepstack_merger_norm_bias.data_ptr(), i
)
del deepstack_merger_norm_bias del deepstack_merger_norm_bias
merger_linear_fc1_weight = load_specific_tensor(model_path, vis_names.merger_linear_fc1_weight()).to(meta.torch_dtype) merger_linear_fc1_weight = load_specific_tensor(
weight_loader.contents.vis_loader.load_merger_linear_fc1_weight(weights, merger_linear_fc1_weight.data_ptr()) model_path, vis_names.merger_linear_fc1_weight()
).to(meta.torch_dtype)
weight_loader.contents.vis_loader.load_merger_linear_fc1_weight(
weights, merger_linear_fc1_weight.data_ptr()
)
del merger_linear_fc1_weight del merger_linear_fc1_weight
merger_linear_fc1_bias = load_specific_tensor(model_path, vis_names.merger_linear_fc1_bias()).to(meta.torch_dtype) merger_linear_fc1_bias = load_specific_tensor(
weight_loader.contents.vis_loader.load_merger_linear_fc1_bias(weights, merger_linear_fc1_bias.data_ptr()) model_path, vis_names.merger_linear_fc1_bias()
).to(meta.torch_dtype)
weight_loader.contents.vis_loader.load_merger_linear_fc1_bias(
weights, merger_linear_fc1_bias.data_ptr()
)
del merger_linear_fc1_bias del merger_linear_fc1_bias
merger_linear_fc2_weight = load_specific_tensor(model_path, vis_names.merger_linear_fc2_weight()).to(meta.torch_dtype) merger_linear_fc2_weight = load_specific_tensor(
weight_loader.contents.vis_loader.load_merger_linear_fc2_weight(weights, merger_linear_fc2_weight.data_ptr()) model_path, vis_names.merger_linear_fc2_weight()
).to(meta.torch_dtype)
weight_loader.contents.vis_loader.load_merger_linear_fc2_weight(
weights, merger_linear_fc2_weight.data_ptr()
)
del merger_linear_fc2_weight del merger_linear_fc2_weight
merger_linear_fc2_bias = load_specific_tensor(model_path, vis_names.merger_linear_fc2_bias()).to(meta.torch_dtype) merger_linear_fc2_bias = load_specific_tensor(
weight_loader.contents.vis_loader.load_merger_linear_fc2_bias(weights, merger_linear_fc2_bias.data_ptr()) model_path, vis_names.merger_linear_fc2_bias()
).to(meta.torch_dtype)
weight_loader.contents.vis_loader.load_merger_linear_fc2_bias(
weights, merger_linear_fc2_bias.data_ptr()
)
del merger_linear_fc2_bias del merger_linear_fc2_bias
merger_norm_weight = load_specific_tensor(model_path, vis_names.merger_norm_weight()).to(meta.torch_dtype) merger_norm_weight = load_specific_tensor(
weight_loader.contents.vis_loader.load_merger_norm_weight(weights, merger_norm_weight.data_ptr()) model_path, vis_names.merger_norm_weight()
).to(meta.torch_dtype)
weight_loader.contents.vis_loader.load_merger_norm_weight(
weights, merger_norm_weight.data_ptr()
)
del merger_norm_weight del merger_norm_weight
merger_norm_bias = load_specific_tensor(model_path, vis_names.merger_norm_bias()).to(meta.torch_dtype) merger_norm_bias = load_specific_tensor(
weight_loader.contents.vis_loader.load_merger_norm_bias(weights, merger_norm_bias.data_ptr()) model_path, vis_names.merger_norm_bias()
).to(meta.torch_dtype)
weight_loader.contents.vis_loader.load_merger_norm_bias(
weights, merger_norm_bias.data_ptr()
)
del merger_norm_bias del merger_norm_bias
class Qwen3vlBatchedTask: class Qwen3vlBatchedTask:
def __init__(self, tasks: List[InferTask]): def __init__(
self,
tasks: List[InferTask],
all_pixel_values=None,
all_image_grid_thw=None,
all_pixel_values_videos=None,
all_video_grid_thw=None,
):
self.tasks = tasks self.tasks = tasks
self.nreq = len(tasks) self.nreq = len(tasks)
...@@ -443,9 +639,7 @@ class Qwen3vlBatchedTask: ...@@ -443,9 +639,7 @@ class Qwen3vlBatchedTask:
self.tokens = (c_uint * self.ntok)(*flat_tokens) self.tokens = (c_uint * self.ntok)(*flat_tokens)
self.req_lens = (c_uint * self.nreq)(*self.req_lens_list) self.req_lens = (c_uint * self.nreq)(*self.req_lens_list)
self.req_pos = (c_uint * self.nreq)(*self.req_pos_list) self.req_pos = (c_uint * self.nreq)(*self.req_pos_list)
self.kv_caches = (POINTER(Qwen3vlCacheCStruct) * self.nreq)( self.kv_caches = (POINTER(Qwen3vlCacheCStruct) * self.nreq)(*self.kv_cache_ptrs)
*self.kv_cache_ptrs
)
self.temperaturas = (c_float * self.nreq)(*self.temperaturas_list) self.temperaturas = (c_float * self.nreq)(*self.temperaturas_list)
self.topks = (c_uint * self.nreq)(*self.topks_list) self.topks = (c_uint * self.nreq)(*self.topks_list)
self.topps = (c_float * self.nreq)(*self.topps_list) self.topps = (c_float * self.nreq)(*self.topps_list)
...@@ -462,38 +656,60 @@ class Qwen3vlBatchedTask: ...@@ -462,38 +656,60 @@ class Qwen3vlBatchedTask:
self.patch_features = 0 self.patch_features = 0
# Prepare visual encoder inputs # Prepare visual encoder inputs
all_pixel_values = [t.inputs['pixel_values'] for t in tasks if 'pixel_values' in t.inputs] # all_pixel_values = [t.inputs['pixel_values'] for t in tasks if 'pixel_values' in t.inputs]
all_image_grid_thw = [t.inputs['image_grid_thw'] for t in tasks if 'image_grid_thw' in t.inputs] # all_image_grid_thw = [t.inputs['image_grid_thw'] for t in tasks if 'image_grid_thw' in t.inputs]
all_pixel_values_videos = [t.inputs['pixel_values_videos'] for t in tasks if 'pixel_values_videos' in t.inputs] # all_pixel_values_videos = [t.inputs['pixel_values_videos'] for t in tasks if 'pixel_values_videos' in t.inputs]
all_video_grid_thw = [t.inputs['video_grid_thw'] for t in tasks if 'video_grid_thw' in t.inputs] # all_video_grid_thw = [t.inputs['video_grid_thw'] for t in tasks if 'video_grid_thw' in t.inputs]
if all_pixel_values: if all_pixel_values is not None:
concat_pixel_values = torch.cat(all_pixel_values, dim=0) # (total_patches, features) print(all_pixel_values.shape)
concat_pixel_values = (
torch.cat(all_pixel_values, dim=0)
if isinstance(all_pixel_values, list)
else all_pixel_values
) # (total_patches, features)
self.total_patches = concat_pixel_values.shape[0] self.total_patches = concat_pixel_values.shape[0]
self.patch_features = concat_pixel_values.shape[1] self.patch_features = concat_pixel_values.shape[1]
self.flat_pixels = concat_pixel_values.flatten().to(torch.bfloat16).contiguous() self.flat_pixels = (
self.pixel_values = self.flat_pixels.ctypes.data_as(c_void_p) concat_pixel_values.flatten().to(torch.bfloat16).contiguous()
)
if all_image_grid_thw: self.pixel_values = self.flat_pixels.data_ptr()
concat_grid_thw = torch.cat(all_image_grid_thw, dim=0) # (total_images, 3)
if all_image_grid_thw is not None:
concat_grid_thw = (
torch.cat(all_image_grid_thw, dim=0)
if isinstance(all_image_grid_thw, list)
else all_image_grid_thw
) # (total_images, 3)
self.num_images = concat_grid_thw.shape[0] self.num_images = concat_grid_thw.shape[0]
flat_grid = concat_grid_thw.flatten().to(torch.int32).contiguous() self.flat_grid = (
self.image_grid_thw = (c_uint * len(flat_grid))(*flat_grid.tolist()) concat_grid_thw.flatten().to(torch.int32).contiguous().tolist()
)
self.image_grid_thw = (c_uint * len(self.flat_grid))(*self.flat_grid)
if all_pixel_values_videos: if all_pixel_values_videos is not None:
concat_pixel_values_videos = torch.cat(all_pixel_values_videos, dim=0) # (total_patches_videos, features) concat_pixel_values_videos = torch.cat(
all_pixel_values_videos, dim=0
) # (total_patches_videos, features)
self.total_patches_videos = concat_pixel_values_videos.shape[0] self.total_patches_videos = concat_pixel_values_videos.shape[0]
self.patch_features_videos = concat_pixel_values_videos.shape[1] self.patch_features_videos = concat_pixel_values_videos.shape[1]
print(self.patch_features_videos, flush=True) print(self.patch_features_videos, flush=True)
self.flat_pixels_videos = concat_pixel_values_videos.flatten().to(torch.bfloat16).contiguous() self.flat_pixels_videos = (
concat_pixel_values_videos.flatten().to(torch.bfloat16).contiguous()
)
self.pixel_values_videos = self.flat_pixels_videos.ctypes.data_as(c_void_p) self.pixel_values_videos = self.flat_pixels_videos.ctypes.data_as(c_void_p)
if all_video_grid_thw: if all_video_grid_thw is not None:
concat_grid_thw_videos = torch.cat(all_video_grid_thw, dim=0) # (total_videos, 3) concat_grid_thw_videos = torch.cat(
all_video_grid_thw, dim=0
) # (total_videos, 3)
self.num_videos = concat_grid_thw_videos.shape[0] self.num_videos = concat_grid_thw_videos.shape[0]
flat_grid_videos = concat_grid_thw_videos.flatten().to(torch.int32).contiguous() flat_grid_videos = (
self.video_grid_thw = (c_uint * len(flat_grid_videos))(*flat_grid_videos.tolist()) concat_grid_thw_videos.flatten().to(torch.int32).contiguous()
)
self.video_grid_thw = (c_uint * len(flat_grid_videos))(
*flat_grid_videos.tolist()
)
def input_args(self): def input_args(self):
return ( return (
...@@ -517,6 +733,7 @@ class Qwen3vlBatchedTask: ...@@ -517,6 +733,7 @@ class Qwen3vlBatchedTask:
self.topps, self.topps,
) )
# 需要处理 visual encoder的cache 和 image video输入 # 需要处理 visual encoder的cache 和 image video输入
class Qwen3vlForCauslLM: class Qwen3vlForCauslLM:
def __init__( def __init__(
...@@ -533,9 +750,7 @@ class Qwen3vlForCauslLM: ...@@ -533,9 +750,7 @@ class Qwen3vlForCauslLM:
print(model_dir_path) print(model_dir_path)
if "qwen3_vl" == config["model_type"]: if "qwen3_vl" == config["model_type"]:
self.meta = Qwen3vlMeta( self.meta = Qwen3vlMeta(config, max_tokens=max_tokens)
config, max_tokens=max_tokens
)
self.processor = transformers.AutoProcessor.from_pretrained(model_dir_path) self.processor = transformers.AutoProcessor.from_pretrained(model_dir_path)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path) self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path)
else: else:
...@@ -547,11 +762,7 @@ class Qwen3vlForCauslLM: ...@@ -547,11 +762,7 @@ class Qwen3vlForCauslLM:
self.model_instance = Qwen3vlModel() self.model_instance = Qwen3vlModel()
weights = self.model_instance.create_weights( weights = self.model_instance.create_weights(
byref(self.meta), byref(self.meta), device, ndev, dev_ids, c_bool(True)
device,
ndev,
dev_ids,
c_bool(True)
) )
print("Loading weights...") print("Loading weights...")
# Load weights from host # Load weights from host
...@@ -573,9 +784,22 @@ class Qwen3vlForCauslLM: ...@@ -573,9 +784,22 @@ class Qwen3vlForCauslLM:
def drop_kv_cache(self, kv_cache): def drop_kv_cache(self, kv_cache):
self.model_instance.drop_cache(self.model_ptr, kv_cache) self.model_instance.drop_cache(self.model_ptr, kv_cache)
def batch_infer_one_round(self, tasks: List[InferTask]): def batch_infer_one_round(
self,
tasks: List[InferTask],
all_pixel_values=None,
all_image_grid_thw=None,
all_pixel_values_videos=None,
all_video_grid_thw=None,
):
output = (c_uint * len(tasks))() output = (c_uint * len(tasks))()
batch_inputs = Qwen3vlBatchedTask(tasks) batch_inputs = Qwen3vlBatchedTask(
tasks,
all_pixel_values,
all_image_grid_thw,
all_pixel_values_videos,
all_video_grid_thw,
)
self.model_instance.infer_batch( self.model_instance.infer_batch(
self.model_ptr, self.model_ptr,
*(batch_inputs.input_args()), *(batch_inputs.input_args()),
...@@ -583,18 +807,31 @@ class Qwen3vlForCauslLM: ...@@ -583,18 +807,31 @@ class Qwen3vlForCauslLM:
) )
return list(output) return list(output)
def generate(self, input_content, max_steps, topp_=1.0, topk_=1, temperature_=1.0): def generate(
self, input_content, max_steps=0, topp_=1.0, topk_=1, temperature_=1.0
):
inputs = self.processor.apply_chat_template( inputs = self.processor.apply_chat_template(
conversation = [{"role": "user","content": [{"type": "text", "text": input_content}]}], conversation=[{"role": "user", "content": input_content}],
tokenize=True, tokenize=True,
add_generation_prompt=True, add_generation_prompt=True,
return_dict=True, return_dict=True,
return_tensors="pt", return_tensors="pt",
) )
tokens = inputs["input_ids"][0].tolist()
pixel_values = inputs["pixel_values"] if "pixel_values" in inputs else None
image_grid_thw = (
inputs["image_grid_thw"] if "image_grid_thw" in inputs else None
)
pixel_values_videos = (
inputs["pixel_values_videos"] if "pixel_values_videos" in inputs else None
)
video_grid_thw = (
inputs["video_grid_thw"] if "video_grid_thw" in inputs else None
)
infer_task = InferTask( infer_task = InferTask(
0, 0,
inputs, tokens,
self.max_context_len(), self.max_context_len(),
temperature_, temperature_,
topk_, topk_,
...@@ -602,22 +839,32 @@ class Qwen3vlForCauslLM: ...@@ -602,22 +839,32 @@ class Qwen3vlForCauslLM:
self.eos_token_id, self.eos_token_id,
) )
infer_task.bind_kvcache(KVCache(self)) infer_task.bind_kvcache(KVCache(self))
print(input_content, end="", flush=True) print(input_content)
steps = 0 steps = 0
total_time = 0 total_time = 0
output_content = "" output_content = ""
print(inputs['input_ids'][0].tolist(), flush=True) # print(inputs['input_ids'][0].tolist(), flush=True)
for step_i in range(max_steps): for step_i in range(max_steps if max_steps > 0 else self.max_context_len()):
start_time = time.time() start_time = time.time()
output_tokens = self.batch_infer_one_round([infer_task]) output_tokens = self.batch_infer_one_round(
print(output_tokens) [infer_task],
pixel_values,
image_grid_thw,
pixel_values_videos,
video_grid_thw,
)
# print(output_tokens)
end_time = time.time() end_time = time.time()
steps += 1 steps += 1
output_str = self.tokenizer.decode(output_tokens[0]) output_str = self.tokenizer.decode(output_tokens[0])
output_content += output_str output_content += output_str
print(output_str, end="", flush=True) print(output_str, end="", flush=True)
pixel_values = None
image_grid_thw = None
pixel_values_videos = None
video_grid_thw = None
if output_tokens[0] in self.eos_token_id: if output_tokens[0] in self.eos_token_id:
break break
infer_task.next(output_tokens[0]) infer_task.next(output_tokens[0])
...@@ -627,7 +874,7 @@ class Qwen3vlForCauslLM: ...@@ -627,7 +874,7 @@ class Qwen3vlForCauslLM:
print("\n") print("\n")
avg_time = total_time * 1000 / steps if steps > 0 else -1 avg_time = total_time * 1000 / steps if steps > 0 else -1
print(output_content, flush=True) # print(output_content, flush=True)
print(f"Time per step: {avg_time:.3f}ms") print(f"Time per step: {avg_time:.3f}ms")
infer_task._kv_cache.drop(self) infer_task._kv_cache.drop(self)
...@@ -665,10 +912,22 @@ def test(): ...@@ -665,10 +912,22 @@ def test():
"Usage: python qwen3vl.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] <path/to/model_dir> [n_device]" "Usage: python qwen3vl.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] <path/to/model_dir> [n_device]"
) )
sys.exit(1) sys.exit(1)
ndev = int(sys.argv[3]) if len(sys.argv) > 3 else 1 ndev = int(sys.argv[3]) if len(sys.argv) > 3 else 1
img_url = None
if len(sys.argv) > 4:
img_url = sys.argv[4]
model = Qwen3vlForCauslLM(model_path, device_type, ndev, max_tokens=1024) model = Qwen3vlForCauslLM(model_path, device_type, ndev, max_tokens=1024)
model.generate("山东最高的山是?", 200) input_content = (
[
{"type": "text", "text": "Describe this image."},
{"type": "image", "url": img_url},
]
if img_url is not None
else [{"type": "text", "text": "山东最高的山是?"}]
)
model.generate(input_content)
model.destroy_model_instance() model.destroy_model_instance()
......
...@@ -16,7 +16,7 @@ public: ...@@ -16,7 +16,7 @@ public:
class MemoryPool : public AllocatorBase { class MemoryPool : public AllocatorBase {
public: public:
static constexpr size_t DEFAULT_ALIGNMENT = 256; static constexpr size_t DEFAULT_ALIGNMENT = 512;
explicit MemoryPool(size_t initialSize = 0, size_t alignment = DEFAULT_ALIGNMENT); explicit MemoryPool(size_t initialSize = 0, size_t alignment = DEFAULT_ALIGNMENT);
~MemoryPool(); ~MemoryPool();
......
...@@ -43,9 +43,9 @@ void InferenceContext::conv(std::shared_ptr<Tensor> y, ...@@ -43,9 +43,9 @@ void InferenceContext::conv(std::shared_ptr<Tensor> y,
size_t n) { size_t n) {
size_t key = CacheManager::createDescriptorKey(y, x, w, bias); size_t key = CacheManager::createDescriptorKey(y, x, w, bias);
// Combine additional parameters into the key for unique identification // Combine additional parameters into the key for unique identification
hash_combine(key, std::hash<void*>()(pads)); hash_combine(key, std::hash<void *>()(pads));
hash_combine(key, std::hash<void*>()(strides)); hash_combine(key, std::hash<void *>()(strides));
hash_combine(key, std::hash<void*>()(dilations)); hash_combine(key, std::hash<void *>()(dilations));
hash_combine(key, std::hash<size_t>()(n)); hash_combine(key, std::hash<size_t>()(n));
infiniopConvDescriptor_t desc; infiniopConvDescriptor_t desc;
......
...@@ -48,14 +48,14 @@ void releaseDeviceResource(Qwen3vlDeviceResource &res) { ...@@ -48,14 +48,14 @@ void releaseDeviceResource(Qwen3vlDeviceResource &res) {
res.comm = nullptr; res.comm = nullptr;
} }
inline std::shared_ptr<Tensor> get_custom_SinTable(const Qwen3vlMeta &meta, std::vector<std::vector<uint32_t>> &pos_ids ,uint32_t dim, size_t theta) { inline std::shared_ptr<Tensor> get_custom_SinTable(const Qwen3vlMeta &meta, std::vector<std::vector<uint32_t>> &pos_ids, uint32_t dim, size_t theta) {
// pos_ids shape:[seq, dim/2] , pos ids acting on each dim // pos_ids shape:[seq, dim/2] , pos ids acting on each dim
auto unit = dsize(meta.dtype); auto unit = dsize(meta.dtype);
auto half_dim = dim/2; auto half_dim = dim / 2;
size_t len = pos_ids.size(); size_t len = pos_ids.size();
void *table = std::malloc(len * half_dim * unit); void *table = std::malloc(len * half_dim * unit);
for (size_t i = 0; i <len; i++) { for (size_t i = 0; i < len; i++) {
for (size_t j = 0; j < half_dim; j++) { for (size_t j = 0; j < half_dim; j++) {
float _cos = std::sin( float _cos = std::sin(
static_cast<float>(pos_ids[i][j]) / std::pow(theta, static_cast<float>(j) / half_dim)); static_cast<float>(pos_ids[i][j]) / std::pow(theta, static_cast<float>(j) / half_dim));
...@@ -77,14 +77,14 @@ inline std::shared_ptr<Tensor> get_custom_SinTable(const Qwen3vlMeta &meta, std: ...@@ -77,14 +77,14 @@ inline std::shared_ptr<Tensor> get_custom_SinTable(const Qwen3vlMeta &meta, std:
return tensor; return tensor;
} }
inline std::shared_ptr<Tensor> get_custom_CosTable(const Qwen3vlMeta &meta, std::vector<std::vector<uint32_t>> &pos_ids ,uint32_t dim, size_t theta) { inline std::shared_ptr<Tensor> get_custom_CosTable(const Qwen3vlMeta &meta, std::vector<std::vector<uint32_t>> &pos_ids, uint32_t dim, size_t theta) {
// pos_ids shape:[seq, dim/2] , pos ids acting on each dim // pos_ids shape:[seq, dim/2] , pos ids acting on each dim
auto unit = dsize(meta.dtype); auto unit = dsize(meta.dtype);
auto half_dim = dim/2; auto half_dim = dim / 2;
size_t len = pos_ids.size(); size_t len = pos_ids.size();
void *table = std::malloc(len * half_dim * unit); void *table = std::malloc(len * half_dim * unit);
for (size_t i = 0; i <len; i++) { for (size_t i = 0; i < len; i++) {
for (size_t j = 0; j < half_dim; j++) { for (size_t j = 0; j < half_dim; j++) {
float _cos = std::cos( float _cos = std::cos(
static_cast<float>(pos_ids[i][j]) / std::pow(theta, static_cast<float>(j) / half_dim)); static_cast<float>(pos_ids[i][j]) / std::pow(theta, static_cast<float>(j) / half_dim));
...@@ -107,7 +107,7 @@ inline std::shared_ptr<Tensor> get_custom_CosTable(const Qwen3vlMeta &meta, std: ...@@ -107,7 +107,7 @@ inline std::shared_ptr<Tensor> get_custom_CosTable(const Qwen3vlMeta &meta, std:
} }
inline std::shared_ptr<Tensor> fast_pos_embed_interpolate(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc, inline std::shared_ptr<Tensor> fast_pos_embed_interpolate(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc,
uint32_t* grid_thw, uint32_t num_batch, uint32_t total_patches) { uint32_t *grid_thw, uint32_t num_batch, uint32_t total_patches) {
auto dtype = meta.dtype; auto dtype = meta.dtype;
auto num_position_embeddings = meta.vis_meta.num_position_embeddings; auto num_position_embeddings = meta.vis_meta.num_position_embeddings;
auto hidden_size = meta.vis_meta.hidden_size; auto hidden_size = meta.vis_meta.hidden_size;
...@@ -115,7 +115,7 @@ inline std::shared_ptr<Tensor> fast_pos_embed_interpolate(const Qwen3vlMeta &met ...@@ -115,7 +115,7 @@ inline std::shared_ptr<Tensor> fast_pos_embed_interpolate(const Qwen3vlMeta &met
auto num_grid_per_side = static_cast<uint32_t>(sqrt(num_position_embeddings)); auto num_grid_per_side = static_cast<uint32_t>(sqrt(num_position_embeddings));
uint32_t total_pixels_offset = 0; uint32_t total_pixels_offset = 0;
std::shared_ptr<Tensor> patch_pos_embeds = Tensor::buffer(dtype,{total_patches, hidden_size},rsrc.memory_pool); std::shared_ptr<Tensor> patch_pos_embeds = Tensor::buffer(dtype, {total_patches, hidden_size}, rsrc.memory_pool);
auto pos_embed_weight = rsrc.weights->w_vis->pos_embed_weight; auto pos_embed_weight = rsrc.weights->w_vis->pos_embed_weight;
std::vector<std::shared_ptr<Tensor>> pos_embeds(4); std::vector<std::shared_ptr<Tensor>> pos_embeds(4);
...@@ -123,8 +123,8 @@ inline std::shared_ptr<Tensor> fast_pos_embed_interpolate(const Qwen3vlMeta &met ...@@ -123,8 +123,8 @@ inline std::shared_ptr<Tensor> fast_pos_embed_interpolate(const Qwen3vlMeta &met
uint32_t t = grid_thw[i * 3]; uint32_t t = grid_thw[i * 3];
uint32_t h = grid_thw[i * 3 + 1]; uint32_t h = grid_thw[i * 3 + 1];
uint32_t w = grid_thw[i * 3 + 2]; uint32_t w = grid_thw[i * 3 + 2];
auto weight_array = std::vector<uint16_t>(h*w*hidden_size); auto weight_array = std::vector<uint16_t>(h * w * hidden_size);
auto weight_tensor = Tensor::buffer(dtype,{h*w, hidden_size},rsrc.memory_pool); auto weight_tensor = Tensor::buffer(dtype, {h * w, hidden_size}, rsrc.memory_pool);
// 计算插值索引和权重 // 计算插值索引和权重
std::vector<std::vector<uint32_t>> indices(4); std::vector<std::vector<uint32_t>> indices(4);
...@@ -165,64 +165,62 @@ inline std::shared_ptr<Tensor> fast_pos_embed_interpolate(const Qwen3vlMeta &met ...@@ -165,64 +165,62 @@ inline std::shared_ptr<Tensor> fast_pos_embed_interpolate(const Qwen3vlMeta &met
// 查表并加权求和 // 查表并加权求和
for (int j = 0; j < 4; ++j) { for (int j = 0; j < 4; ++j) {
pos_embeds[j] = Tensor::buffer(dtype,{h*w, hidden_size},rsrc.memory_pool); pos_embeds[j] = Tensor::buffer(dtype, {h * w, hidden_size}, rsrc.memory_pool);
// 使用索引和权重获取对应位置嵌入,并乘以权重 // 使用索引和权重获取对应位置嵌入,并乘以权重
for(size_t i = 0; i < h*w; i++){ for (size_t i = 0; i < h * w; i++) {
rearrange(pos_embeds[j]->slice(0,i,1),pos_embed_weight->slice(0,indices[j][i],1)); rearrange(pos_embeds[j]->slice(0, i, 1), pos_embed_weight->slice(0, indices[j][i], 1));
} }
for(size_t i = 0; i < h*w; i++){ for (size_t i = 0; i < h * w; i++) {
uint16_t w_value = f32_to_bf16(weights[j][i]); uint16_t w_value = f32_to_bf16(weights[j][i]);
for(size_t k=0; k < hidden_size; k++){ for (size_t k = 0; k < hidden_size; k++) {
weight_array[i*hidden_size + k] = w_value; weight_array[i * hidden_size + k] = w_value;
} }
} }
RUN_INFINI(infinirtMemcpyAsync(weight_tensor->data(), weight_array.data(), sizeof(uint16_t)*h*w*hidden_size, RUN_INFINI(infinirtMemcpyAsync(weight_tensor->data(), weight_array.data(), sizeof(uint16_t) * h * w * hidden_size,
INFINIRT_MEMCPY_H2D, rsrc.stream)); INFINIRT_MEMCPY_H2D, rsrc.stream));
mul(pos_embeds[j],pos_embeds[j],weight_tensor); mul(pos_embeds[j], pos_embeds[j], weight_tensor);
} }
// 合并四个方向的结果 // 合并四个方向的结果
auto patch_pos_embed = pos_embeds[0]; // [h*w, hidden_size] auto patch_pos_embed = pos_embeds[0]; // [h*w, hidden_size]
for (int j = 1; j < 4; ++j) { for (int j = 1; j < 4; ++j) {
add(patch_pos_embed,patch_pos_embed, pos_embeds[j]); add(patch_pos_embed, patch_pos_embed, pos_embeds[j]);
} }
// 对于视频帧数T>1的情况,重复patch_pos_embed T次 // 对于视频帧数T>1的情况,重复patch_pos_embed T次
if (t > 1) { if (t > 1) {
auto temp_patch_pos_embed = Tensor::buffer(dtype,{t,h*w,hidden_size},rsrc.memory_pool); auto temp_patch_pos_embed = Tensor::buffer(dtype, {t, h * w, hidden_size}, rsrc.memory_pool);
for(size_t i = 0; i < t; i++){ for (size_t i = 0; i < t; i++) {
rearrange(temp_patch_pos_embed->slice(0,i,1), patch_pos_embed); rearrange(temp_patch_pos_embed->slice(0, i, 1), patch_pos_embed);
} }
patch_pos_embed = temp_patch_pos_embed; patch_pos_embed = temp_patch_pos_embed;
} }
printf("merge patch pos embed/n"); printf("merge patch pos embed/n");
fflush(stdout); fflush(stdout);
patch_pos_embed = patch_pos_embed patch_pos_embed = patch_pos_embed
->view({t, h/merge_size, merge_size, w/merge_size, merge_size, hidden_size}) ->view({t, h / merge_size, merge_size, w / merge_size, merge_size, hidden_size})
->permute({0, 1, 3, 2, 4, 5}) ->permute({0, 1, 3, 2, 4, 5})
->view({t*h*w, hidden_size}); //可能因为内存不连续无法再view ->view({t * h * w, hidden_size}); // 可能因为内存不连续无法再view
rearrange(patch_pos_embeds->slice(0,total_pixels_offset,t*h*w), patch_pos_embed); rearrange(patch_pos_embeds->slice(0, total_pixels_offset, t * h * w), patch_pos_embed);
total_pixels_offset += t*h*w; total_pixels_offset += t * h * w;
} }
return patch_pos_embeds; return patch_pos_embeds;
} }
inline auto rot_pos_embed(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc, uint32_t* grid_thw, uint32_t num_batch, uint32_t total_patches) { inline auto rot_pos_embed(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc, uint32_t *grid_thw, uint32_t num_batch, uint32_t total_patches) {
auto dtype = meta.dtype; auto dtype = meta.dtype;
auto hidden_size = meta.vis_meta.hidden_size; auto hidden_size = meta.vis_meta.hidden_size;
auto num_heads = meta.vis_meta.num_heads; auto num_heads = meta.vis_meta.num_heads;
auto head_dim = hidden_size / num_heads; auto head_dim = hidden_size / num_heads;
auto merge_size = meta.vis_meta.spatial_merge_size; auto merge_size = meta.vis_meta.spatial_merge_size;
std::vector<std::vector<uint32_t>> pos_ids_table_y ( std::vector<std::vector<uint32_t>> pos_ids_table_y(
total_patches, total_patches,
std::vector<uint32_t>(head_dim/4) std::vector<uint32_t>(head_dim / 4));
); std::vector<std::vector<uint32_t>> pos_ids_table_x(
std::vector<std::vector<uint32_t>> pos_ids_table_x (
total_patches, total_patches,
std::vector<uint32_t>(head_dim/4) std::vector<uint32_t>(head_dim / 4));
);
for (uint32_t b = 0; b < num_batch; ++b) { for (uint32_t b = 0; b < num_batch; ++b) {
uint32_t offset = b * 3; uint32_t offset = b * 3;
uint32_t num_frames = grid_thw[offset + 0]; uint32_t num_frames = grid_thw[offset + 0];
...@@ -243,7 +241,7 @@ inline auto rot_pos_embed(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc, ...@@ -243,7 +241,7 @@ inline auto rot_pos_embed(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc,
// 如果是多帧,重复 num_frames 次 // 如果是多帧,重复 num_frames 次
for (uint32_t f = 0; f < num_frames; ++f) { for (uint32_t f = 0; f < num_frames; ++f) {
size_t dim_offset = 0; size_t dim_offset = 0;
for(;dim_offset<head_dim/4;dim_offset++){ for (; dim_offset < head_dim / 4; dim_offset++) {
pos_ids_table_y[patch_offset][dim_offset] = row; pos_ids_table_y[patch_offset][dim_offset] = row;
pos_ids_table_x[patch_offset][dim_offset] = col; pos_ids_table_x[patch_offset][dim_offset] = col;
} }
...@@ -254,18 +252,18 @@ inline auto rot_pos_embed(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc, ...@@ -254,18 +252,18 @@ inline auto rot_pos_embed(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc,
} }
} }
} }
auto sin = Tensor::buffer(dtype,{total_patches,head_dim/2},rsrc.memory_pool); auto sin = Tensor::buffer(dtype, {total_patches, head_dim / 2}, rsrc.memory_pool);
auto sin_y = get_custom_SinTable(meta,pos_ids_table_y,head_dim/2,10000); auto sin_y = get_custom_SinTable(meta, pos_ids_table_y, head_dim / 2, 10000);
rearrange(sin->slice(1,0,head_dim/4),sin_y); rearrange(sin->slice(1, 0, head_dim / 4), sin_y);
auto sin_x = get_custom_SinTable(meta,pos_ids_table_x,head_dim/2,10000); auto sin_x = get_custom_SinTable(meta, pos_ids_table_x, head_dim / 2, 10000);
rearrange(sin->slice(1,head_dim/4,head_dim/2),sin_y); rearrange(sin->slice(1, head_dim / 4, head_dim / 2), sin_y);
auto cos = Tensor::buffer(dtype,{total_patches,head_dim/2},rsrc.memory_pool); auto cos = Tensor::buffer(dtype, {total_patches, head_dim / 2}, rsrc.memory_pool);
auto cos_y = get_custom_CosTable(meta,pos_ids_table_y,head_dim/2,10000); auto cos_y = get_custom_CosTable(meta, pos_ids_table_y, head_dim / 2, 10000);
rearrange(cos->slice(1,0,head_dim/4),cos_y); rearrange(cos->slice(1, 0, head_dim / 4), cos_y);
auto cos_x = get_custom_CosTable(meta,pos_ids_table_x,head_dim/2,10000); auto cos_x = get_custom_CosTable(meta, pos_ids_table_x, head_dim / 2, 10000);
rearrange(cos->slice(1,head_dim/4,head_dim/2),cos_y); rearrange(cos->slice(1, head_dim / 4, head_dim / 2), cos_y);
return std::pair{sin,cos}; return std::pair{sin, cos};
} }
void inferDeviceBatchVision(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc, void inferDeviceBatchVision(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc,
...@@ -276,20 +274,20 @@ void inferDeviceBatchVision(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc ...@@ -276,20 +274,20 @@ void inferDeviceBatchVision(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc
uint32_t num_images = req.num_images; uint32_t num_images = req.num_images;
void *pixel_values_videos = req.pixel_values_videos; void *pixel_values_videos = req.pixel_values_videos;
uint32_t total_patches_videos = req.total_patches_videos; uint32_t total_patches_videos = req.total_patches_videos;
//uint32_t *video_grid_thw = req.video_grid_thw; // uint32_t *video_grid_thw = req.video_grid_thw;
//uint32_t num_videos = req.num_videos; // uint32_t num_videos = req.num_videos;
//uint32_t patch_features = req.patch_features; // uint32_t patch_features = req.patch_features;
auto dtype = meta.dtype; auto dtype = meta.dtype;
auto d = meta.vis_meta.hidden_size; auto d = meta.vis_meta.hidden_size;
auto channels = meta.vis_meta.in_channels; auto channels = meta.vis_meta.in_channels;
auto patch_size = meta.vis_meta.patch_size; auto patch_size = meta.vis_meta.patch_size;
auto temporal_patch_size = meta.vis_meta.temporal_patch_size; auto temporal_patch_size = meta.vis_meta.temporal_patch_size;
//auto stream = rsrc.stream; // auto stream = rsrc.stream;
auto weights = rsrc.weights; auto weights = rsrc.weights;
auto image_tensor = Tensor::weight(pixel_values, dtype, {total_patches, channels*temporal_patch_size*patch_size*patch_size}); auto image_tensor = Tensor::weight(pixel_values, dtype, {total_patches, channels * temporal_patch_size * patch_size * patch_size});
auto video_tensor = Tensor::weight(pixel_values_videos, dtype, {total_patches_videos, channels*temporal_patch_size*patch_size*patch_size}); auto video_tensor = Tensor::weight(pixel_values_videos, dtype, {total_patches_videos, channels * temporal_patch_size * patch_size * patch_size});
auto hidden_states = Tensor::buffer(dtype, {total_patches, d, 1, 1, 1}, rsrc.memory_pool); auto hidden_states = Tensor::buffer(dtype, {total_patches, d, 1, 1, 1}, rsrc.memory_pool);
std::vector<size_t> pads = {0, 0, 0}; std::vector<size_t> pads = {0, 0, 0};
...@@ -299,12 +297,10 @@ void inferDeviceBatchVision(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc ...@@ -299,12 +297,10 @@ void inferDeviceBatchVision(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc
pads.data(), strides.data(), dilations.data(), 3); pads.data(), strides.data(), dilations.data(), 3);
hidden_states = hidden_states->view({total_patches, d}); hidden_states = hidden_states->view({total_patches, d});
auto pos_embeds = fast_pos_embed_interpolate(meta,rsrc,image_grid_thw,num_images,total_patches); auto pos_embeds = fast_pos_embed_interpolate(meta, rsrc, image_grid_thw, num_images, total_patches);
add(hidden_states,hidden_states,pos_embeds); add(hidden_states, hidden_states, pos_embeds);
auto [sin, cos] = rot_pos_embed(meta,rsrc,image_grid_thw,num_images,total_patches);
auto [sin, cos] = rot_pos_embed(meta, rsrc, image_grid_thw, num_images, total_patches);
} }
void inferDeviceBatchText(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc, void inferDeviceBatchText(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc,
...@@ -337,14 +333,14 @@ void inferDeviceBatchText(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc, ...@@ -337,14 +333,14 @@ void inferDeviceBatchText(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc,
auto stream = rsrc.stream; auto stream = rsrc.stream;
auto weights = rsrc.weights; auto weights = rsrc.weights;
//Allocate buffers // Allocate buffers
auto logits_in = Tensor::buffer(dtype, {ntok, d}, rsrc.memory_pool); auto logits_in = Tensor::buffer(dtype, {ntok, d}, rsrc.memory_pool);
auto logits_out = Tensor::buffer(dtype, {ntok, d}, rsrc.memory_pool); auto logits_out = Tensor::buffer(dtype, {ntok, d}, rsrc.memory_pool);
//所有请求的当前token // 所有请求的当前token
auto qkv_buf = Tensor::buffer(dtype, {ntok, (nh + nkvh * 2) * dh}, rsrc.memory_pool); auto qkv_buf = Tensor::buffer(dtype, {ntok, (nh + nkvh * 2) * dh}, rsrc.memory_pool);
auto o_buf = Tensor::buffer(dtype, {ntok, nh * dh}, rsrc.memory_pool); auto o_buf = Tensor::buffer(dtype, {ntok, nh * dh}, rsrc.memory_pool);
auto gate_up_buf = Tensor::buffer(dtype, {ntok, 2*di}, rsrc.memory_pool); auto gate_up_buf = Tensor::buffer(dtype, {ntok, 2 * di}, rsrc.memory_pool);
auto prob_buf = Tensor::buffer(dtype, {nreq, dvoc}, rsrc.memory_pool); auto prob_buf = Tensor::buffer(dtype, {nreq, dvoc}, rsrc.memory_pool);
auto result_buf = Tensor::buffer(INFINI_DTYPE_I64, {nreq}, rsrc.memory_pool); auto result_buf = Tensor::buffer(INFINI_DTYPE_I64, {nreq}, rsrc.memory_pool);
...@@ -354,12 +350,12 @@ void inferDeviceBatchText(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc, ...@@ -354,12 +350,12 @@ void inferDeviceBatchText(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc,
auto q_buf = qkv_rope->slice(1, 0, nh); auto q_buf = qkv_rope->slice(1, 0, nh);
auto k_buf = qkv_rope->slice(1, nh, nkvh); auto k_buf = qkv_rope->slice(1, nh, nkvh);
//Prepare inputs // Prepare inputs
auto batch_pos_ids = std::vector<uint32_t>(ntok); auto batch_pos_ids = std::vector<uint32_t>(ntok);
size_t req_start = 0; size_t req_start = 0;
for (uint32_t req = 0; req < nreq; req++) { for (uint32_t req = 0; req < nreq; req++) {
for (uint32_t i = 0; i < req_lens[req]; i++) { // req_len 本次query长度,req_pos 历史长度 for (uint32_t i = 0; i < req_lens[req]; i++) { // req_len 本次query长度,req_pos 历史长度
batch_pos_ids[req_start + i] = req_pos[req] + i; //batch_pos_ids 展平后每个token的pos batch_pos_ids[req_start + i] = req_pos[req] + i; // batch_pos_ids 展平后每个token的pos
} }
req_start += req_lens[req]; req_start += req_lens[req];
} }
...@@ -372,7 +368,7 @@ void inferDeviceBatchText(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc, ...@@ -372,7 +368,7 @@ void inferDeviceBatchText(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc,
INFINIRT_MEMCPY_H2D, stream)); INFINIRT_MEMCPY_H2D, stream));
} }
//convert tokens to embeddings // convert tokens to embeddings
for (uint32_t i = 0; i < ntok; i++) { for (uint32_t i = 0; i < ntok; i++) {
RUN_INFINI(infinirtMemcpyAsync(logits_in->data(i * d), RUN_INFINI(infinirtMemcpyAsync(logits_in->data(i * d),
weights->w_lang->in_embd->data(tokens[i] * d), weights->w_lang->in_embd->data(tokens[i] * d),
...@@ -401,51 +397,51 @@ void inferDeviceBatchText(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc, ...@@ -401,51 +397,51 @@ void inferDeviceBatchText(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc,
auto gate_buf = gate_up_buf->slice(1, 0, di); auto gate_buf = gate_up_buf->slice(1, 0, di);
auto up_buf = gate_up_buf->slice(1, di, di); auto up_buf = gate_up_buf->slice(1, di, di);
//Compute // Compute
for (uint32_t i = 0; i < nlayer; i++){ for (uint32_t i = 0; i < nlayer; i++) {
// attn norm // attn norm
rmsnorm(logits_out,logits_in,weights->w_lang->layers[i].attn_norm,epsilon); rmsnorm(logits_out, logits_in, weights->w_lang->layers[i].attn_norm, epsilon);
// qkv_proj // qkv_proj
linear(qkv_buf,logits_out,weights->w_lang->layers[i].attn_qkv_proj,1.0,0.0,nullptr,nullptr); linear(qkv_buf, logits_out, weights->w_lang->layers[i].attn_qkv_proj, 1.0, 0.0, nullptr, nullptr);
// qk_norm // qk_norm
rmsnorm(q_buf,q_buf,weights->w_lang->layers[i].attn_q_norm,epsilon); rmsnorm(q_buf, q_buf, weights->w_lang->layers[i].attn_q_norm, epsilon);
rmsnorm(k_buf,k_buf,weights->w_lang->layers[i].attn_k_norm,epsilon); rmsnorm(k_buf, k_buf, weights->w_lang->layers[i].attn_k_norm, epsilon);
// rope // rope
rope_v2(q_buf,q_buf,pos_ids_buf,weights->sin_table,weights->cos_table); rope_v2(q_buf, q_buf, pos_ids_buf, weights->sin_table, weights->cos_table);
rope_v2(k_buf,k_buf,pos_ids_buf,weights->sin_table,weights->cos_table); rope_v2(k_buf, k_buf, pos_ids_buf, weights->sin_table, weights->cos_table);
// 逐个req处理 // 逐个req处理
size_t token_offset = 0; size_t token_offset = 0;
for(uint32_t req=0; req < nreq; req++){ for (uint32_t req = 0; req < nreq; req++) {
auto past_len = req_pos[req]; auto past_len = req_pos[req];
auto seq_len = req_lens[req]; auto seq_len = req_lens[req];
auto total_len = past_len + seq_len; auto total_len = past_len + seq_len;
auto o = o_buf->slice(0,token_offset,seq_len)->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3});// [nkvh, ngroup, seq_len, dh] auto o = o_buf->slice(0, token_offset, seq_len)->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3}); // [nkvh, ngroup, seq_len, dh]
auto q = qkv_rope->slice({{0,token_offset,seq_len},{1,0,nh}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3});// [nkvh, ngroup, seq_len, dh] auto q = qkv_rope->slice({{0, token_offset, seq_len}, {1, 0, nh}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3}); // [nkvh, ngroup, seq_len, dh]
auto k = qkv_rope->slice({{0,token_offset,seq_len},{1,nh,nkvh}});// [ntok, nkvh, dh] auto k = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh, nkvh}}); // [ntok, nkvh, dh]
auto v = qkv_rope->slice({{0,token_offset,seq_len},{1,nh+nkvh,nkvh}});// [ntok, nkvh, dh] auto v = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}}); // [ntok, nkvh, dh]
// concat to cache // concat to cache
rearrange(caches[req]->k_rot[idev][i]->slice(0,past_len,seq_len),k); rearrange(caches[req]->k_rot[idev][i]->slice(0, past_len, seq_len), k);
rearrange(caches[req]->v[idev][i]->slice(0,past_len,seq_len),v); rearrange(caches[req]->v[idev][i]->slice(0, past_len, seq_len), v);
//fill full_k full_v // fill full_k full_v
auto full_k_buff = caches[req]->k_rot[idev][i]->slice(0,0,total_len)->permute({1,2,0});// [nkvh, dh, total_len] auto full_k_buff = caches[req]->k_rot[idev][i]->slice(0, 0, total_len)->permute({1, 2, 0}); // [nkvh, dh, total_len]
auto full_v_buff = caches[req]->v[idev][i]->slice(0,0,total_len)->permute({1,0,2});// [nkvh, total_len, dh] auto full_v_buff = caches[req]->v[idev][i]->slice(0, 0, total_len)->permute({1, 0, 2}); // [nkvh, total_len, dh]
//self-attn // self-attn
rearrange(q_rearrange->slice(2, 0, seq_len), q); rearrange(q_rearrange->slice(2, 0, seq_len), q);
auto attn_score_req = qk_buf->slice(0,0,nh*seq_len*total_len)->view({nkvh, ngroup*seq_len, total_len}); auto attn_score_req = qk_buf->slice(0, 0, nh * seq_len * total_len)->view({nkvh, ngroup * seq_len, total_len});
// [nkvh, ngroup * seq_len, dh] @ [nkvh, dh, total_len] = [nkvh, ngroup * seq_len, total_len] // [nkvh, ngroup * seq_len, dh] @ [nkvh, dh, total_len] = [nkvh, ngroup * seq_len, total_len]
linear(attn_score_req,rearrange_q_buf->slice(1, 0, ngroup * seq_len),full_k_buff,1.f / float(sqrt(dh)), 0.f, nullptr, nullptr); linear(attn_score_req, rearrange_q_buf->slice(1, 0, ngroup * seq_len), full_k_buff, 1.f / float(sqrt(dh)), 0.f, nullptr, nullptr);
// softmax // softmax
auto qk_softmax = attn_score_req->view({nh, seq_len, total_len}); auto qk_softmax = attn_score_req->view({nh, seq_len, total_len});
causalSoftmax(qk_softmax,qk_softmax); causalSoftmax(qk_softmax, qk_softmax);
// [nkvh, ngroup * seq_len, total_len] @ [nkvh, total_len, dh] = [nkvh, ngroup * seq_len, dh] // [nkvh, ngroup * seq_len, total_len] @ [nkvh, total_len, dh] = [nkvh, ngroup * seq_len, dh]
linear(attn_val_buf->slice(1, 0, ngroup * seq_len), attn_score_req, full_v_buff, 1.0, 0.0, nullptr, nullptr); linear(attn_val_buf->slice(1, 0, ngroup * seq_len), attn_score_req, full_v_buff, 1.0, 0.0, nullptr, nullptr);
//printf("rearrage o; layer[%d]\n",i); // printf("rearrage o; layer[%d]\n",i);
rearrange(o,attn_val_gemm->slice(2, 0, seq_len)); rearrange(o, attn_val_gemm->slice(2, 0, seq_len));
token_offset += seq_len; token_offset += seq_len;
} }
linear(logits_in, o_buf, weights->w_lang->layers[i].attn_o_proj, 1.0, 0.0, idev == 0 ? logits_in : nullptr, nullptr); linear(logits_in, o_buf, weights->w_lang->layers[i].attn_o_proj, 1.0, 0.0, idev == 0 ? logits_in : nullptr, nullptr);
...@@ -458,14 +454,14 @@ void inferDeviceBatchText(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc, ...@@ -458,14 +454,14 @@ void inferDeviceBatchText(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc,
} }
// mlp norm // mlp norm
rmsnorm(logits_out,logits_in,weights->w_lang->layers[i].mlp_norm,epsilon); rmsnorm(logits_out, logits_in, weights->w_lang->layers[i].mlp_norm, epsilon);
// mlp gate_up // mlp gate_up
linear(gate_up_buf,logits_out,weights->w_lang->layers[i].mlp_gate_up,1.0,0.0,nullptr,nullptr); linear(gate_up_buf, logits_out, weights->w_lang->layers[i].mlp_gate_up, 1.0, 0.0, nullptr, nullptr);
// silu // silu
silu(gate_buf,gate_buf); silu(gate_buf, gate_buf);
mul(gate_buf,gate_buf,up_buf); mul(gate_buf, gate_buf, up_buf);
// mlp down // mlp down
linear(logits_in,gate_buf,weights->w_lang->layers[i].mlp_down,1.0, 0.0, idev == 0 ? logits_in : nullptr, nullptr); linear(logits_in, gate_buf, weights->w_lang->layers[i].mlp_down, 1.0, 0.0, idev == 0 ? logits_in : nullptr, nullptr);
// All_reduce if distributed // All_reduce if distributed
if (rsrc.comm != nullptr) { if (rsrc.comm != nullptr) {
RUN_INFINI(infinicclAllReduce( RUN_INFINI(infinicclAllReduce(
...@@ -518,7 +514,7 @@ void inferDeviceBatchText(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc, ...@@ -518,7 +514,7 @@ void inferDeviceBatchText(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc,
void inferDeviceBatch(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc, void inferDeviceBatch(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc,
uint32_t idev, uint32_t ndev, InferState &state, InferRequest &req) { uint32_t idev, uint32_t ndev, InferState &state, InferRequest &req) {
// infer vision + sync // infer vision + sync
if (req.num_images > 0 || req.num_videos > 0){ if (req.num_images > 0 || req.num_videos > 0) {
inferDeviceBatchVision(meta, rsrc, idev, ndev, req); inferDeviceBatchVision(meta, rsrc, idev, ndev, req);
std::unique_lock<std::mutex> lock(state.mtx_sync); std::unique_lock<std::mutex> lock(state.mtx_sync);
...@@ -526,14 +522,14 @@ void inferDeviceBatch(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc, ...@@ -526,14 +522,14 @@ void inferDeviceBatch(const Qwen3vlMeta &meta, Qwen3vlDeviceResource &rsrc,
if (state.sync_cnt == 0) { if (state.sync_cnt == 0) {
state.cv_sync.notify_all(); state.cv_sync.notify_all();
} else { } else {
state.cv_sync.wait(lock, [&] {return state.sync_cnt == 0;}); state.cv_sync.wait(lock, [&] { return state.sync_cnt == 0; });
} }
} }
// infer text // infer text
inferDeviceBatchText(meta, rsrc, idev, ndev, req); inferDeviceBatchText(meta, rsrc, idev, ndev, req);
} }
__C void __INFINI_C void
inferBatchQwen3vl(struct Qwen3vlModel *model, inferBatchQwen3vl(struct Qwen3vlModel *model,
const uint32_t *tokens, uint32_t ntok, const uint32_t *tokens, uint32_t ntok,
void *pixel_values, uint32_t total_patches, void *pixel_values, uint32_t total_patches,
...@@ -581,7 +577,7 @@ inferBatchQwen3vl(struct Qwen3vlModel *model, ...@@ -581,7 +577,7 @@ inferBatchQwen3vl(struct Qwen3vlModel *model,
} }
} }
__C void __INFINI_C void
forwardBatchQwen3vl(struct Qwen3vlModel *model, forwardBatchQwen3vl(struct Qwen3vlModel *model,
const uint32_t *tokens, uint32_t ntok, const uint32_t *tokens, uint32_t ntok,
void *pixel_values, uint32_t total_patches, void *pixel_values, uint32_t total_patches,
...@@ -667,7 +663,6 @@ void launchDevice(const Qwen3vlMeta &meta, std::shared_ptr<Qwen3vlDeviceWeights> ...@@ -667,7 +663,6 @@ void launchDevice(const Qwen3vlMeta &meta, std::shared_ptr<Qwen3vlDeviceWeights>
setInferenceContext(nullptr); // Clear the context when done setInferenceContext(nullptr); // Clear the context when done
} }
Qwen3vlModel::Qwen3vlModel(const Qwen3vlMeta *_meta, const Qwen3vlWeights *weights) : meta(*_meta) { Qwen3vlModel::Qwen3vlModel(const Qwen3vlMeta *_meta, const Qwen3vlWeights *weights) : meta(*_meta) {
auto device_weights = weights->device_weights; auto device_weights = weights->device_weights;
int ndev = device_weights.size(); int ndev = device_weights.size();
...@@ -694,14 +689,14 @@ Qwen3vlModel::Qwen3vlModel(const Qwen3vlMeta *_meta, const Qwen3vlWeights *weigh ...@@ -694,14 +689,14 @@ Qwen3vlModel::Qwen3vlModel(const Qwen3vlMeta *_meta, const Qwen3vlWeights *weigh
} }
} }
__C struct Qwen3vlModel * __INFINI_C struct Qwen3vlModel *
createQwen3vlModel(const Qwen3vlMeta *_meta, createQwen3vlModel(const Qwen3vlMeta *_meta,
const Qwen3vlWeights *weights) { const Qwen3vlWeights *weights) {
Qwen3vlModel *model = new Qwen3vlModel(_meta, weights); Qwen3vlModel *model = new Qwen3vlModel(_meta, weights);
return model; return model;
} }
__C void __INFINI_C void
destroyQwen3vlModel(struct Qwen3vlModel *model) { destroyQwen3vlModel(struct Qwen3vlModel *model) {
auto ndev = model->dev_resources.size(); auto ndev = model->dev_resources.size();
......
#include "qwen3vl_impl.hpp" #include "qwen3vl_impl.hpp"
__C struct Qwen3vlCache * __INFINI_C struct Qwen3vlCache *
createQwen3vlCache(const struct Qwen3vlModel *model) { createQwen3vlCache(const struct Qwen3vlModel *model) {
Qwen3vlCache *cache = new Qwen3vlCache(); Qwen3vlCache *cache = new Qwen3vlCache();
auto ndev = model->dev_resources.size(); auto ndev = model->dev_resources.size();
...@@ -27,7 +27,7 @@ createQwen3vlCache(const struct Qwen3vlModel *model) { ...@@ -27,7 +27,7 @@ createQwen3vlCache(const struct Qwen3vlModel *model) {
//////还有visual deepstack需要cache? //////还有visual deepstack需要cache?
__C void __INFINI_C void
dropQwen3vlCache(const struct Qwen3vlModel *model, dropQwen3vlCache(const struct Qwen3vlModel *model,
struct Qwen3vlCache *cache) { struct Qwen3vlCache *cache) {
auto ndev = model->dev_resources.size(); auto ndev = model->dev_resources.size();
......
...@@ -45,7 +45,6 @@ struct MergerWeight { ...@@ -45,7 +45,6 @@ struct MergerWeight {
std::shared_ptr<Tensor> norm_weight, norm_bias; std::shared_ptr<Tensor> norm_weight, norm_bias;
}; };
struct Qwen3vlVisualEncoderWeight { struct Qwen3vlVisualEncoderWeight {
std::shared_ptr<Tensor> patch_embed_weight, patch_embed_bias, pos_embed_weight; std::shared_ptr<Tensor> patch_embed_weight, patch_embed_bias, pos_embed_weight;
std::vector<Qwen3vlVisBlockWeight> blocks; std::vector<Qwen3vlVisBlockWeight> blocks;
...@@ -53,9 +52,8 @@ struct Qwen3vlVisualEncoderWeight { ...@@ -53,9 +52,8 @@ struct Qwen3vlVisualEncoderWeight {
std::shared_ptr<MergerWeight> merger; std::shared_ptr<MergerWeight> merger;
}; };
struct Qwen3vlDeviceWeights { struct Qwen3vlDeviceWeights {
std::shared_ptr<Tensor> sin_table,cos_table; std::shared_ptr<Tensor> sin_table, cos_table;
std::shared_ptr<Qwen3vlLanguageModelWeight> w_lang; std::shared_ptr<Qwen3vlLanguageModelWeight> w_lang;
std::shared_ptr<Qwen3vlVisualEncoderWeight> w_vis; std::shared_ptr<Qwen3vlVisualEncoderWeight> w_vis;
infiniDevice_t device; infiniDevice_t device;
......
...@@ -23,7 +23,7 @@ inline std::shared_ptr<Tensor> getOutEmbd( ...@@ -23,7 +23,7 @@ inline std::shared_ptr<Tensor> getOutEmbd(
} }
inline void getLayerWeight( inline void getLayerWeight(
const Qwen3vlMeta *meta, Qwen3vlLayerWeight& layer, int ndev) { const Qwen3vlMeta *meta, Qwen3vlLayerWeight &layer, int ndev) {
auto nkvh = meta->text_meta.num_key_value_heads; auto nkvh = meta->text_meta.num_key_value_heads;
auto nh = meta->text_meta.num_attention_heads; auto nh = meta->text_meta.num_attention_heads;
auto dh = meta->text_meta.head_dim; auto dh = meta->text_meta.head_dim;
...@@ -47,11 +47,10 @@ inline void getLayerWeight( ...@@ -47,11 +47,10 @@ inline void getLayerWeight(
layer.mlp_down = Tensor::weight(nullptr, meta->dtype, down_shape); layer.mlp_down = Tensor::weight(nullptr, meta->dtype, down_shape);
} }
inline void getVisualWeight( inline void getVisualWeight(
const Qwen3vlMeta *meta, std::shared_ptr<Qwen3vlVisualEncoderWeight> w_vis) { const Qwen3vlMeta *meta, std::shared_ptr<Qwen3vlVisualEncoderWeight> w_vis) {
Qwen3vlVisMeta vis_meta = meta->vis_meta; Qwen3vlVisMeta vis_meta = meta->vis_meta;
auto patch_embed_shape = std::vector<size_t>({vis_meta.hidden_size , vis_meta.in_channels, vis_meta.temporal_patch_size, vis_meta.patch_size, vis_meta.patch_size}); auto patch_embed_shape = std::vector<size_t>({vis_meta.hidden_size, vis_meta.in_channels, vis_meta.temporal_patch_size, vis_meta.patch_size, vis_meta.patch_size});
w_vis->patch_embed_weight = Tensor::weight(nullptr, meta->dtype, patch_embed_shape); w_vis->patch_embed_weight = Tensor::weight(nullptr, meta->dtype, patch_embed_shape);
w_vis->patch_embed_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size}); w_vis->patch_embed_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size});
w_vis->pos_embed_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.num_position_embeddings, vis_meta.hidden_size}); w_vis->pos_embed_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.num_position_embeddings, vis_meta.hidden_size});
...@@ -64,10 +63,10 @@ inline void getVisualWeight( ...@@ -64,10 +63,10 @@ inline void getVisualWeight(
w_vis->merger->norm_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size}); w_vis->merger->norm_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size});
w_vis->blocks = std::vector<Qwen3vlVisBlockWeight>(vis_meta.depth); w_vis->blocks = std::vector<Qwen3vlVisBlockWeight>(vis_meta.depth);
for (size_t i = 0; i < vis_meta.depth; i++) { for (size_t i = 0; i < vis_meta.depth; i++) {
w_vis->blocks[i].attn_proj_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size,vis_meta.hidden_size}); w_vis->blocks[i].attn_proj_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size, vis_meta.hidden_size});
w_vis->blocks[i].attn_proj_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size}); w_vis->blocks[i].attn_proj_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size});
w_vis->blocks[i].attn_qkv_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.in_channels*vis_meta.hidden_size,vis_meta.hidden_size}); w_vis->blocks[i].attn_qkv_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.in_channels * vis_meta.hidden_size, vis_meta.hidden_size});
w_vis->blocks[i].attn_qkv_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.in_channels*vis_meta.hidden_size}); w_vis->blocks[i].attn_qkv_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.in_channels * vis_meta.hidden_size});
w_vis->blocks[i].mlp_linear_fc1_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.intermediate_size, vis_meta.hidden_size}); w_vis->blocks[i].mlp_linear_fc1_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.intermediate_size, vis_meta.hidden_size});
w_vis->blocks[i].mlp_linear_fc1_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.intermediate_size}); w_vis->blocks[i].mlp_linear_fc1_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.intermediate_size});
w_vis->blocks[i].mlp_linear_fc2_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size, vis_meta.intermediate_size}); w_vis->blocks[i].mlp_linear_fc2_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size, vis_meta.intermediate_size});
...@@ -78,18 +77,16 @@ inline void getVisualWeight( ...@@ -78,18 +77,16 @@ inline void getVisualWeight(
w_vis->blocks[i].norm2_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size}); w_vis->blocks[i].norm2_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.hidden_size});
} }
w_vis->deepstack_mergers = std::vector<DeepstackMergerWeight>(3); w_vis->deepstack_mergers = std::vector<DeepstackMergerWeight>(3);
for (size_t i = 0; i < 3; i++){ for (size_t i = 0; i < 3; i++) {
w_vis->deepstack_mergers[i].linear_fc1_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.intermediate_size,vis_meta.intermediate_size}); w_vis->deepstack_mergers[i].linear_fc1_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.intermediate_size, vis_meta.intermediate_size});
w_vis->deepstack_mergers[i].linear_fc2_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.out_hidden_size,vis_meta.intermediate_size}); w_vis->deepstack_mergers[i].linear_fc2_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.out_hidden_size, vis_meta.intermediate_size});
w_vis->deepstack_mergers[i].linear_fc1_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.intermediate_size}); w_vis->deepstack_mergers[i].linear_fc1_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.intermediate_size});
w_vis->deepstack_mergers[i].linear_fc2_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.out_hidden_size}); w_vis->deepstack_mergers[i].linear_fc2_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.out_hidden_size});
w_vis->deepstack_mergers[i].norm_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.intermediate_size}); w_vis->deepstack_mergers[i].norm_weight = Tensor::weight(nullptr, meta->dtype, {vis_meta.intermediate_size});
w_vis->deepstack_mergers[i].norm_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.intermediate_size}); w_vis->deepstack_mergers[i].norm_bias = Tensor::weight(nullptr, meta->dtype, {vis_meta.intermediate_size});
} }
} }
inline std::shared_ptr<Tensor> getSinTable(const Qwen3vlMeta *meta) { inline std::shared_ptr<Tensor> getSinTable(const Qwen3vlMeta *meta) {
auto half_dh = meta->text_meta.head_dim / 2; auto half_dh = meta->text_meta.head_dim / 2;
auto unit = dsize(meta->dtype); auto unit = dsize(meta->dtype);
...@@ -172,7 +169,6 @@ Qwen3vlWeights::Qwen3vlWeights( ...@@ -172,7 +169,6 @@ Qwen3vlWeights::Qwen3vlWeights(
} }
getVisualWeight(meta, device_weights[dev]->w_vis); getVisualWeight(meta, device_weights[dev]->w_vis);
} }
} }
...@@ -201,8 +197,8 @@ void load_output_embd(Qwen3vlWeights *weights, void *cpu_ptr) { ...@@ -201,8 +197,8 @@ void load_output_embd(Qwen3vlWeights *weights, void *cpu_ptr) {
auto weight = weights->device_weights[dev]; auto weight = weights->device_weights[dev];
RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id));
weight->w_lang->out_embd->load(cpu_ptr, weight->load_stream); weight->w_lang->out_embd->load(cpu_ptr, weight->load_stream);
if(weights->transpose_weight) { if (weights->transpose_weight) {
weight->w_lang->out_embd->permute({1,0}); //[d,voc] weight->w_lang->out_embd->permute({1, 0}); //[d,voc]
} }
} }
} }
...@@ -239,9 +235,8 @@ void load_attn_qkv_proj(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) { ...@@ -239,9 +235,8 @@ void load_attn_qkv_proj(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) {
size_t offset = idev * ((nkvh * 2 + nh) / ndev * dh) * d * dsize(weights->meta->dtype); size_t offset = idev * ((nkvh * 2 + nh) / ndev * dh) * d * dsize(weights->meta->dtype);
RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id));
weight->w_lang->layers[layer].attn_qkv_proj->load((char *)cpu_ptr + offset, weight->load_stream); weight->w_lang->layers[layer].attn_qkv_proj->load((char *)cpu_ptr + offset, weight->load_stream);
if(weights->transpose_weight) { if (weights->transpose_weight) {
weight->w_lang->layers[layer].attn_qkv_proj = weight->w_lang->layers[layer].attn_qkv_proj = weight->w_lang->layers[layer].attn_qkv_proj->permute({1, 0}); //[d, (nh+2*nkvh)*dh]
weight->w_lang->layers[layer].attn_qkv_proj->permute({1,0}); //[d, (nh+2*nkvh)*dh]
} }
} }
} }
...@@ -267,9 +262,8 @@ void load_attn_o_proj(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) { ...@@ -267,9 +262,8 @@ void load_attn_o_proj(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) {
size_t offset = idev * d * (nh / ndev * dh) * dsize(weights->meta->dtype); size_t offset = idev * d * (nh / ndev * dh) * dsize(weights->meta->dtype);
RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id));
weight->w_lang->layers[layer].attn_o_proj->load((char *)cpu_ptr + offset, weight->load_stream); weight->w_lang->layers[layer].attn_o_proj->load((char *)cpu_ptr + offset, weight->load_stream);
if(weights->transpose_weight) { if (weights->transpose_weight) {
weight->w_lang->layers[layer].attn_o_proj = weight->w_lang->layers[layer].attn_o_proj = weight->w_lang->layers[layer].attn_o_proj->permute({1, 0}); //[nh/ndev*dh, d]
weight->w_lang->layers[layer].attn_o_proj->permute({1,0}); //[nh/ndev*dh, d]
} }
} }
} }
...@@ -295,9 +289,8 @@ void load_mlp_gate_up(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) { ...@@ -295,9 +289,8 @@ void load_mlp_gate_up(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) {
size_t offset = idev * (2 * di / ndev) * d * dsize(weights->meta->dtype); size_t offset = idev * (2 * di / ndev) * d * dsize(weights->meta->dtype);
RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id));
weight->w_lang->layers[layer].mlp_gate_up->load((char *)cpu_ptr + offset, weight->load_stream); weight->w_lang->layers[layer].mlp_gate_up->load((char *)cpu_ptr + offset, weight->load_stream);
if(weights->transpose_weight) { if (weights->transpose_weight) {
weight->w_lang->layers[layer].mlp_gate_up = weight->w_lang->layers[layer].mlp_gate_up = weight->w_lang->layers[layer].mlp_gate_up->permute({1, 0}); //[d, 2*di/ndev]
weight->w_lang->layers[layer].mlp_gate_up->permute({1,0}); //[d, 2*di/ndev]
} }
} }
} }
...@@ -313,9 +306,8 @@ void load_mlp_down(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) { ...@@ -313,9 +306,8 @@ void load_mlp_down(Qwen3vlWeights *weights, void *cpu_ptr, size_t layer) {
size_t offset = idev * d * (di / ndev) * dsize(weights->meta->dtype); size_t offset = idev * d * (di / ndev) * dsize(weights->meta->dtype);
RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id)); RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id));
weight->w_lang->layers[layer].mlp_down->load((char *)cpu_ptr + offset, weight->load_stream); weight->w_lang->layers[layer].mlp_down->load((char *)cpu_ptr + offset, weight->load_stream);
if(weights->transpose_weight) { if (weights->transpose_weight) {
weight->w_lang->layers[layer].mlp_down = weight->w_lang->layers[layer].mlp_down = weight->w_lang->layers[layer].mlp_down->permute({1, 0}); //[di/ndev, d]
weight->w_lang->layers[layer].mlp_down->permute({1,0}); //[di/ndev, d]
} }
} }
} }
...@@ -569,7 +561,6 @@ void load_merger_norm_bias(Qwen3vlWeights *weights, void *cpu_ptr) { ...@@ -569,7 +561,6 @@ void load_merger_norm_bias(Qwen3vlWeights *weights, void *cpu_ptr) {
} }
} }
static Qwen3vlWeightLoader weight_loader = { static Qwen3vlWeightLoader weight_loader = {
// Language model loaders // Language model loaders
.lang_loader = { .lang_loader = {
...@@ -614,10 +605,9 @@ static Qwen3vlWeightLoader weight_loader = { ...@@ -614,10 +605,9 @@ static Qwen3vlWeightLoader weight_loader = {
.load_merger_linear_fc2_bias = load_merger_linear_fc2_bias, .load_merger_linear_fc2_bias = load_merger_linear_fc2_bias,
.load_merger_norm_weight = load_merger_norm_weight, .load_merger_norm_weight = load_merger_norm_weight,
.load_merger_norm_bias = load_merger_norm_bias, .load_merger_norm_bias = load_merger_norm_bias,
} }};
};
__C Qwen3vlWeights * __INFINI_C Qwen3vlWeights *
createQwen3vlWeights(const Qwen3vlMeta *meta, createQwen3vlWeights(const Qwen3vlMeta *meta,
infiniDevice_t device, infiniDevice_t device,
int ndev, int ndev,
...@@ -640,7 +630,7 @@ createQwen3vlWeights(const Qwen3vlMeta *meta, ...@@ -640,7 +630,7 @@ createQwen3vlWeights(const Qwen3vlMeta *meta,
return weights; return weights;
}; };
__C Qwen3vlWeightLoader * __INFINI_C Qwen3vlWeightLoader *
createQwen3vlWeightLoader() { createQwen3vlWeightLoader() {
return &weight_loader; return &weight_loader;
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment