Commit 083b80ea authored by zhuwenwen
Browse files

增加w8a8相关修改

parent 09428eec
...@@ -60,6 +60,7 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, ...@@ -60,6 +60,7 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
maybe_prefix) maybe_prefix)
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf from vllm.model_executor.utils import pad_weight, gemm_bank_conf
from vllm.utils import W8a8GetCacheJSON
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -329,12 +330,13 @@ class Qwen2Model(nn.Module): ...@@ -329,12 +330,13 @@ class Qwen2Model(nn.Module):
if quant_config is not None: if quant_config is not None:
self.quant_method=quant_config.get_name() self.quant_method=quant_config.get_name()
self.quant_config=quant_config self.quant_config=quant_config
self.tritonsingleton= W8a8GetCacheJSON()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1' self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '0')) self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
...@@ -510,14 +512,39 @@ class Qwen2Model(nn.Module): ...@@ -510,14 +512,39 @@ class Qwen2Model(nn.Module):
"mlp.down_proj.weight", "mlp.down_proj.weight",
] ]
combined_words = "|".join(lay_key_words) combined_words = "|".join(lay_key_words)
weight_shapes=[]
all_json={}
for layername, weight in params_dict.items(): for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername) matches = re.findall(combined_words, layername)
if matches: if matches and "scale" not in layername:
weight_data =params_dict[layername] weight_data =params_dict[layername]
k=weight_data.shape[0] n=weight_data.shape[0]
_weight=weight_data.T.contiguous().reshape(k,-1)
weight_data.data.copy_(_weight) #rocblas和cutlass目前都需要weight做处理,但是triton不用
if self.w8a8_strategy!=1:
_weight=weight_data.T.contiguous().reshape(n,-1)
weight_data.data.copy_(_weight)
#下面是针对模型记录模型出现k和n值
elif len(weight_shapes)<4:
k=weight_data.shape[1]
weight_shapes.append({n,k})
json_file=self.tritonsingleton.get_w8a8json_name(n,k)
configs_dict=self.tritonsingleton.get_triton_cache(json_file,n,k)
if configs_dict:
all_json.update(configs_dict)
if self.w8a8_strategy==1:
self.tritonsingleton.triton_json_dict.append(all_json)
#找到的所有config都进行一次warmup
for key, value in all_json.items():
m=int(key.split('_')[0])
n=int(key.split('_')[1])
k=int(key.split('_')[2])
ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,best_config=value)
return loaded_params return loaded_params
......
...@@ -1466,7 +1466,11 @@ class W8a8GetCacheJSON: ...@@ -1466,7 +1466,11 @@ class W8a8GetCacheJSON:
def _initialize(self): def _initialize(self):
current_folder_path = os.path.dirname(os.path.abspath(__file__)) current_folder_path = os.path.dirname(os.path.abspath(__file__))
self.triton_json_dir=(os.getenv('TRITON_JSON_DIR', current_folder_path+'/model_executor/layers/quantization/configs/w8a8')) json_folder_path=current_folder_path+'/../lmslim/configs/w8a8'
if not os.path.exists(json_folder_path):
json_folder_path=current_folder_path+'/model_executor/layers/quantization/configs/w8a8'
self.triton_json_dir=(os.getenv('TRITON_JSON_DIR', json_folder_path))
self.triton_json_dict=[] self.triton_json_dict=[]
def getspec_config(self,configs_dict,M,N,K): def getspec_config(self,configs_dict,M,N,K):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment