Commit 9ddee6b1 authored by zhuwenwen's avatar zhuwenwen
Browse files

Support Falcon and optimize layout

parent 2d5a25cd
...@@ -19,6 +19,7 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention ...@@ -19,6 +19,7 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention
| BloomForCausalLM | BLOOM | Yes | No | - | | BloomForCausalLM | BLOOM | Yes | No | - |
| InternLMForCausalLM | InternLM | Yes | No | - | | InternLMForCausalLM | InternLM | Yes | No | - |
| InternLM2ForCausalLM | InternLM2 | Yes | No | - | | InternLM2ForCausalLM | InternLM2 | Yes | No | - |
| FalconForCausalLM | Falcon | Yes | No | - |
| TeleChat12BForCausalLM (#TelechatForCausalLM) | TeleChat-12B | Yes | No | - | | TeleChat12BForCausalLM (#TelechatForCausalLM) | TeleChat-12B | Yes | No | - |
| MiniCPMForCausalLM | MiniCPM | Yes | No | - | | MiniCPMForCausalLM | MiniCPM | Yes | No | - |
| MiniCPM3ForCausalLM | MiniCPM3 | Yes | No | - | | MiniCPM3ForCausalLM | MiniCPM3 | Yes | No | - |
......
...@@ -5,12 +5,13 @@ from typing import Dict, Tuple ...@@ -5,12 +5,13 @@ from typing import Dict, Tuple
import numpy as np import numpy as np
import pytest import pytest
import os
from PIL import Image from PIL import Image
from transformers import AutoConfig, AutoTokenizer from transformers import AutoConfig, AutoTokenizer
from vllm.multimodal.utils import (async_fetch_image, fetch_image, from vllm.multimodal.utils import (async_fetch_image, fetch_image,
repeat_and_pad_placeholder_tokens) repeat_and_pad_placeholder_tokens)
from ..utils import urls_port from ..utils import models_path_prefix, urls_port
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_URLS = [ TEST_IMAGE_URLS = [
...@@ -85,7 +86,7 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image], ...@@ -85,7 +86,7 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
assert _image_equals(data_image_sync, data_image_async) assert _image_equals(data_image_sync, data_image_async)
@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"]) @pytest.mark.parametrize("model", [os.path.join(models_path_prefix, "llava-hf/llava-v1.6-mistral-7b-hf")])
def test_repeat_and_pad_placeholder_tokens(model): def test_repeat_and_pad_placeholder_tokens(model):
config = AutoConfig.from_pretrained(model) config = AutoConfig.from_pretrained(model)
image_token_id = config.image_token_index image_token_id = config.image_token_index
......
...@@ -23,14 +23,14 @@ def get_model_architecture( ...@@ -23,14 +23,14 @@ def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", []) architectures = getattr(model_config.hf_config, "architectures", [])
visions = getattr(model_config.hf_config, "visual", []) or getattr(model_config.hf_config, "vision_config", []) visions = getattr(model_config.hf_config, "visual", []) or getattr(model_config.hf_config, "vision_config", [])
support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'Qwen2VLForConditionalGeneration', 'ChatGLMModel', 'ChatGLMForConditionalGeneration', 'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel', 'MixtralForCausalLM', 'MLPSpeculatorPreTrainedModel'] support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'Qwen2VLForConditionalGeneration', 'ChatGLMModel', 'ChatGLMForConditionalGeneration', 'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel', 'MixtralForCausalLM', 'MLPSpeculatorPreTrainedModel', 'FalconForCausalLM']
if any(arch in architectures for arch in support_nn_architectures): if any(arch in architectures for arch in support_nn_architectures):
if os.getenv('LLAMA_NN') != '0': if os.getenv('LLAMA_NN') != '0':
if (architectures == ['QWenLMHeadModel'] or architectures == ['ChatGLMModel'] ) and visions != []: if (architectures == ['QWenLMHeadModel'] or architectures == ['ChatGLMModel'] ) and visions != []:
os.environ['LLAMA_NN'] = '0' os.environ['LLAMA_NN'] = '0'
else: else:
os.environ['LLAMA_NN'] = '1' os.environ['LLAMA_NN'] = '1'
if architectures == ['BloomForCausalLM'] or os.getenv('LM_NN') == '0': if (architectures == ['BloomForCausalLM'] or architectures == ['FalconForCausalLM']) or os.getenv('LM_NN') == '0':
os.environ['LM_NN'] = '0' os.environ['LM_NN'] = '0'
else: else:
os.environ['LM_NN'] = '1' os.environ['LM_NN'] = '1'
......
...@@ -21,6 +21,8 @@ ...@@ -21,6 +21,8 @@
import math import math
from typing import Iterable, List, Optional, Tuple, Union from typing import Iterable, List, Optional, Tuple, Union
import os
import re
import torch import torch
from torch import nn from torch import nn
from torch.nn import LayerNorm from torch.nn import LayerNorm
...@@ -47,6 +49,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -47,6 +49,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import RWConfig from vllm.transformers_utils.configs import RWConfig
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
FalconConfig = Union[HF_FalconConfig, RWConfig] FalconConfig = Union[HF_FalconConfig, RWConfig]
...@@ -176,6 +181,11 @@ class FalconAttention(nn.Module): ...@@ -176,6 +181,11 @@ class FalconAttention(nn.Module):
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config)
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
def forward( def forward(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
...@@ -184,6 +194,8 @@ class FalconAttention(nn.Module): ...@@ -184,6 +194,8 @@ class FalconAttention(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, bias = self.query_key_value(hidden_states) qkv, bias = self.query_key_value(hidden_states)
if os.environ.get('FA_PAD') == '1' and self.quant_method is None:
qkv = qkv[...,:-32]
if bias is not None: if bias is not None:
qkv += bias qkv += bias
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
...@@ -246,6 +258,9 @@ class FalconDecoderLayer(nn.Module): ...@@ -246,6 +258,9 @@ class FalconDecoderLayer(nn.Module):
self.mlp = FalconMLP(config, quant_config) self.mlp = FalconMLP(config, quant_config)
self.config = config self.config = config
if (not hasattr(config, "num_ln_in_parallel_attn")):
config.num_ln_in_parallel_attn = None
if (config.num_ln_in_parallel_attn is None if (config.num_ln_in_parallel_attn is None
and config.new_decoder_architecture): and config.new_decoder_architecture):
config.num_ln_in_parallel_attn = 2 config.num_ln_in_parallel_attn = 2
...@@ -404,6 +419,17 @@ class FalconForCausalLM(nn.Module): ...@@ -404,6 +419,17 @@ class FalconForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '0'))
def forward( def forward(
self, self,
input_ids: torch.LongTensor, input_ids: torch.LongTensor,
...@@ -481,3 +507,33 @@ class FalconForCausalLM(nn.Module): ...@@ -481,3 +507,33 @@ class FalconForCausalLM(nn.Module):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
if self.use_llama_nn and self.quant_method is None :
lay_key_words = [
"self_attention.query_key_value.weight",
"self_attention.dense.weight",
"mlp.dense_h_to_4h.weight",
"mlp.dense_4h_to_h.weight",
]
combined_words = "|".join(lay_key_words)
lay_qkv_words = ["self_attention.query_key_value.weight"]
qkv_words = "|".join(lay_qkv_words)
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches:
if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
if self.use_fa_pad and (re.findall(qkv_words, layername)):
if not gemm_bank_conf(weight.data.shape[0]):
weight.data = pad_weight(weight.data, 32)
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1], -1)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment