Commit 1a11f127 authored by zhuwenwen
Browse files

support qwen2-vl video

parent 083b80ea
......@@ -5,7 +5,7 @@ requests >= 2.26.0
tqdm
blake3
py-cpuinfo
transformers == 4.45.2 # Required for Llama 3.2 and Qwen2-VL.
transformers == 4.47.0 # Required for Llama 3.2 and Qwen2-VL.
tokenizers >= 0.19.1 # Required for Llama 3.
protobuf # Required by LlamaTokenizer.
fastapi >= 0.107.0, < 0.113.0; python_version < '3.9'
......
......@@ -26,7 +26,8 @@ def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", [])
visions = getattr(model_config.hf_config, "visual", []) or getattr(model_config.hf_config, "vision_config", [])
support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'Qwen2VLForConditionalGeneration', 'ChatGLMModel', 'ChatGLMForConditionalGeneration', 'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel', 'MixtralForCausalLM', 'MLPSpeculatorPreTrainedModel', 'FalconForCausalLM']
# 'Qwen2VLForConditionalGeneration'
support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'ChatGLMModel', 'ChatGLMForConditionalGeneration', 'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel', 'MixtralForCausalLM', 'MLPSpeculatorPreTrainedModel', 'FalconForCausalLM']
if any(arch in architectures for arch in support_nn_architectures):
if os.getenv('LLAMA_NN') != '0':
if (architectures == ['QWenLMHeadModel'] or architectures == ['ChatGLMModel'] ) and visions != []:
......
......@@ -430,11 +430,11 @@ class Qwen2Model(nn.Module):
]
combined_words = "|".join(lay_key_words)
lay_qkv_words = ["self_attn.qkv_proj.weight"]
qkv_words = "|".join(lay_qkv_words)
# lay_qkv_words = ["self_attn.qkv_proj.weight"]
# qkv_words = "|".join(lay_qkv_words)
lay_qkv_bias_words = ["self_attn.qkv_proj.bias"]
qkv_bias_words = "|".join(lay_qkv_bias_words)
# lay_qkv_bias_words = ["self_attn.qkv_proj.bias"]
# qkv_bias_words = "|".join(lay_qkv_bias_words)
for layername, weight in params_dict.items():
if "lm_head.weight" in layername and weight.shape[1] >= 3584:
......
......@@ -535,6 +535,16 @@ class Qwen2VisionTransformer(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.merger",
)
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
@property
def dtype(self) -> torch.dtype:
......@@ -658,11 +668,11 @@ class Qwen2VisionTransformer(nn.Module):
]
combined_words = "|".join(lay_key_words)
lay_qkv_words = ["attn.qkv.weight"]
qkv_words = "|".join(lay_qkv_words)
# lay_qkv_words = ["attn.qkv.weight"]
# qkv_words = "|".join(lay_qkv_words)
lay_qkv_bias_words = ["attn.qkv.bias"]
qkv_bias_words = "|".join(lay_qkv_bias_words)
# lay_qkv_bias_words = ["attn.qkv.bias"]
# qkv_bias_words = "|".join(lay_qkv_bias_words)
for layername, weight in params_dict.items():
# if self.use_fa_pad and (re.findall(qkv_bias_words, layername)):
......@@ -670,8 +680,8 @@ class Qwen2VisionTransformer(nn.Module):
matches = re.findall(combined_words, layername)
if matches:
if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
# if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
# weight.data = pad_weight(weight.data, 32)
# if self.use_fa_pad and (re.findall(qkv_words, layername)):
# if not gemm_bank_conf(weight.data.shape[0]):
......@@ -983,16 +993,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
@cached_property
def sampler(self):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment