Commit 9ddee6b1 authored by zhuwenwen's avatar zhuwenwen
Browse files

support falcon and optimize layout

parent 2d5a25cd
......@@ -19,6 +19,7 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention
| BloomForCausalLM | BLOOM | Yes | No | - |
| InternLMForCausalLM | InternLM | Yes | No | - |
| InternLM2ForCausalLM | InternLM2 | Yes | No | - |
| FalconForCausalLM | falcon | Yes | No | - |
| TeleChat12BForCausalLM (#TelechatForCausalLM) | TeleChat-12B | Yes | No | - |
| MiniCPMForCausalLM | MiniCPM | Yes | No | - |
| MiniCPM3ForCausalLM | MiniCPM3 | Yes | No | - |
......
......@@ -5,12 +5,13 @@ from typing import Dict, Tuple
import numpy as np
import pytest
import os
from PIL import Image
from transformers import AutoConfig, AutoTokenizer
from vllm.multimodal.utils import (async_fetch_image, fetch_image,
repeat_and_pad_placeholder_tokens)
from ..utils import urls_port
from ..utils import models_path_prefix, urls_port
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_URLS = [
......@@ -85,7 +86,7 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
assert _image_equals(data_image_sync, data_image_async)
@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize("model", [os.path.join(models_path_prefix, "llava-hf/llava-v1.6-mistral-7b-hf")])
def test_repeat_and_pad_placeholder_tokens(model):
config = AutoConfig.from_pretrained(model)
image_token_id = config.image_token_index
......
......@@ -23,14 +23,14 @@ def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", [])
visions = getattr(model_config.hf_config, "visual", []) or getattr(model_config.hf_config, "vision_config", [])
support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'Qwen2VLForConditionalGeneration', 'ChatGLMModel', 'ChatGLMForConditionalGeneration', 'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel', 'MixtralForCausalLM', 'MLPSpeculatorPreTrainedModel']
support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'Qwen2VLForConditionalGeneration', 'ChatGLMModel', 'ChatGLMForConditionalGeneration', 'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel', 'MixtralForCausalLM', 'MLPSpeculatorPreTrainedModel', 'FalconForCausalLM']
if any(arch in architectures for arch in support_nn_architectures):
if os.getenv('LLAMA_NN') != '0':
if (architectures == ['QWenLMHeadModel'] or architectures == ['ChatGLMModel'] ) and visions != []:
os.environ['LLAMA_NN'] = '0'
else:
os.environ['LLAMA_NN'] = '1'
if architectures == ['BloomForCausalLM'] or os.getenv('LM_NN') == '0':
if (architectures == ['BloomForCausalLM'] or architectures == ['FalconForCausalLM']) or os.getenv('LM_NN') == '0':
os.environ['LM_NN'] = '0'
else:
os.environ['LM_NN'] = '1'
......
......@@ -21,6 +21,8 @@
import math
from typing import Iterable, List, Optional, Tuple, Union
import os
import re
import torch
from torch import nn
from torch.nn import LayerNorm
......@@ -47,6 +49,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import RWConfig
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
FalconConfig = Union[HF_FalconConfig, RWConfig]
......@@ -176,6 +181,11 @@ class FalconAttention(nn.Module):
cache_config=cache_config,
quant_config=quant_config)
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
def forward(
self,
positions: torch.Tensor,
......@@ -184,6 +194,8 @@ class FalconAttention(nn.Module):
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, bias = self.query_key_value(hidden_states)
if os.environ.get('FA_PAD') == '1' and self.quant_method is None:
qkv = qkv[...,:-32]
if bias is not None:
qkv += bias
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
......@@ -246,6 +258,9 @@ class FalconDecoderLayer(nn.Module):
self.mlp = FalconMLP(config, quant_config)
self.config = config
if (not hasattr(config, "num_ln_in_parallel_attn")):
config.num_ln_in_parallel_attn = None
if (config.num_ln_in_parallel_attn is None
and config.new_decoder_architecture):
config.num_ln_in_parallel_attn = 2
......@@ -404,6 +419,17 @@ class FalconForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '0'))
def forward(
self,
input_ids: torch.LongTensor,
......@@ -481,3 +507,33 @@ class FalconForCausalLM(nn.Module):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
if self.use_llama_nn and self.quant_method is None :
lay_key_words = [
"self_attention.query_key_value.weight",
"self_attention.dense.weight",
"mlp.dense_h_to_4h.weight",
"mlp.dense_4h_to_h.weight",
]
combined_words = "|".join(lay_key_words)
lay_qkv_words = ["self_attention.query_key_value.weight"]
qkv_words = "|".join(lay_qkv_words)
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches:
if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
if self.use_fa_pad and (re.findall(qkv_words, layername)):
if not gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1], -1)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment