Commit b8d41f2f authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.7.2-dev-w8a8' into 'v0.7.2-dev'

V0.7.2 dev w8a8

See merge request dcutoolkit/deeplearing/vllm!69
parents 49ff8ab5 41d0696e
......@@ -16,6 +16,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.vocab_parallel_embedding import UnquantizedEmbeddingMethod
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501
......@@ -86,11 +87,11 @@ class CompressedTensorsConfig(QuantizationConfig):
if should_ignore_layer(prefix,
ignore=self.ignore,
fused_mapping=self.packed_modules_mapping):
return UnquantizedLinearMethod()
return UnquantizedEmbeddingMethod()#UnquantizedLinearMethod()
if isinstance(layer, LinearBase):
scheme = self.get_scheme(layer=layer, layer_name=prefix)
if scheme is None:
return UnquantizedLinearMethod()
return UnquantizedEmbeddingMethod()#UnquantizedLinearMethod()
layer.scheme = scheme
return CompressedTensorsLinearMethod(self)
if isinstance(layer, Attention):
......
......@@ -255,8 +255,10 @@ def apply_int8_linear(
k=x_q.shape[1]
n=weight.shape[1]
#print("m:{},n:{},k:{}".format(m,n,k))
if len(W8A8_TRITONJSON.triton_json_dict)==0:
best_config=None
if f"1_{n}_{k}" in W8A8_TRITONJSON.triton_json_dict[0]:
elif f"1_{n}_{k}" in W8A8_TRITONJSON.triton_json_dict[0]:
if m<=16:
m_=m
#best_config=W8A8_TRITONJSON.triton_json_dict[0][f"{m}_{n}_{k}"]
......
......@@ -18,7 +18,8 @@
# limitations under the License.
"""Inference-only GPT-NeoX model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Set, Tuple, Union
import os
import re
import torch
from torch import nn
from transformers import GPTNeoXConfig
......@@ -45,7 +46,8 @@ from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
from vllm.utils import is_hip,W8a8GetCacheJSON
from vllm import _custom_ops as ops
class GPTNeoXAttention(nn.Module):
......@@ -278,6 +280,13 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
self.make_empty_intermediate_tensors = (
self.gpt_neox.make_empty_intermediate_tensors)
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
self.tritonsingleton= W8a8GetCacheJSON()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.gpt_neox.get_input_embeddings(input_ids)
......@@ -349,4 +358,52 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
#当为triton支持推理的时候不能进行处理
if self.quant_method == "compressed_tensors":
os.environ['LM_NN'] = '0'
lay_key_words = [
"attention.query_key_value.weight",
"attention.dense.weight",
"mlp.dense_h_to_4h.weight",
"mlp.dense_4h_to_h.weight",
]
combined_words = "|".join(lay_key_words)
weight_shapes=[]
all_json={}
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches and "scale" not in layername:
weight_data =params_dict[layername]
n=weight_data.shape[0]
k=weight_data.shape[1]
#rocblas和cutlass目前都需要weight做处理,但是triton不用
if self.w8a8_strategy!=1:
_weight=weight_data.T.contiguous().reshape(n,-1)
weight_data.data.copy_(_weight)
#下面是针对模型记录模型出现k和n值
elif len(weight_shapes)<4:
#k=weight_data.shape[1]
#print("n:{},k:{}".format(n,k))
weight_shapes.append({n,k})
json_file=self.tritonsingleton.get_w8a8json_name(n,k)
configs_dict=self.tritonsingleton.get_triton_cache(json_file,n,k)
if configs_dict:
all_json.update(configs_dict)
#("weight_shapes:",weight_shapes)
if self.w8a8_strategy==1:
self.tritonsingleton.triton_json_dict.append(all_json)
#print("self.tritonsingleton.triton_json_dict:",self.tritonsingleton.triton_json_dict)
#找到的所有config都进行一次warmup
for key, value in all_json.items():
m=int(key.split('_')[0])
n=int(key.split('_')[1])
k=int(key.split('_')[2])
ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,best_config=value)
return loaded_params
......@@ -355,6 +355,7 @@ class LlamaModel(nn.Module):
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
self.tritonsingleton= W8a8GetCacheJSON()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
# self.use_lm_nn = os.environ.get('LM_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
......@@ -549,6 +550,7 @@ class LlamaModel(nn.Module):
#当为triton支持推理的时候不能进行处理
if self.quant_method == "compressed_tensors":
os.environ['LM_NN'] = '0'
lay_key_words = [
"self_attn.qkv_proj.weight",
"self_attn.o_proj.weight",
......@@ -644,7 +646,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.model = self._init_model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.tritonsingleton= W8a8GetCacheJSON()
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
......
......@@ -1174,6 +1174,7 @@ class QWenBaseModel(nn.Module, SupportsPP, SupportsLoRA):
qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous()
if self.quant_method == "compressed_tensors":
os.environ['LM_NN'] = '0'
lay_key_words = [
"attn.c_attn.weight",
"attn.c_proj.weight",
......
......@@ -527,6 +527,7 @@ class Qwen2Model(nn.Module):
qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous()
if self.quant_method == "compressed_tensors":
os.environ['LM_NN'] = '0'
lay_key_words = [
"self_attn.qkv_proj.weight",
"self_attn.o_proj.weight",
......
......@@ -58,6 +58,7 @@ if TYPE_CHECKING:
from vllm.config import VllmConfig
logger = init_logger(__name__)
import json
# Exception strings for non-implemented encoder/decoder scenarios
......@@ -1608,6 +1609,7 @@ class W8a8GetCacheJSON:
return configs_dict
def get_w8a8json_name(self,n,k):
from vllm.platforms import current_platform
device_name = current_platform.get_device_name().replace(" ", "_")
return self.triton_json_dir+f"/W8A8_{n}_{k}_{device_name}.json"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment