Commit baf7d06a authored by zhuwenwen's avatar zhuwenwen
Browse files

1. 修复w8a8找不到对应config的bug

2. 删除vllm的w8a8 config
parent c5888d31
......@@ -370,6 +370,7 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
combined_words = "|".join(lay_key_words)
weight_shapes=[]
all_json={}
matched_key_words=set()
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
......@@ -384,20 +385,18 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
weight_data.data.copy_(_weight)
#下面是针对模型记录模型出现k和n值
elif len(weight_shapes)<4:
#k=weight_data.shape[1]
#print("n:{},k:{}".format(n,k))
elif len(matched_key_words) < 4 and matches[0] not in matched_key_words:
matched_key_words.add(matches[0])
weight_shapes.append({n,k})
json_file=self.tritonsingleton.get_w8a8json_name(n,k)
configs_dict=self.tritonsingleton.get_triton_cache(json_file,n,k)
if configs_dict:
all_json.update(configs_dict)
#("weight_shapes:",weight_shapes)
if self.w8a8_strategy==1:
self.tritonsingleton.triton_json_dict.append(all_json)
#print("self.tritonsingleton.triton_json_dict:",self.tritonsingleton.triton_json_dict)
#找到的所有config都进行一次warmup
for key, value in all_json.items():
m=int(key.split('_')[0])
......
......@@ -565,6 +565,7 @@ class LlamaModel(nn.Module):
combined_words = "|".join(lay_key_words)
weight_shapes=[]
all_json={}
matched_key_words=set()
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
......@@ -582,7 +583,8 @@ class LlamaModel(nn.Module):
weight_data.data.copy_(_weight)
#下面是针对模型记录模型出现k和n值
elif len(weight_shapes)<4:
elif len(matched_key_words) < 4 and matches[0] not in matched_key_words:
matched_key_words.add(matches[0])
k=weight_data.shape[1]
weight_shapes.append({n,k})
......
......@@ -460,6 +460,7 @@ class QWenBaseModel(nn.Module):
combined_words = "|".join(lay_key_words)
weight_shapes=[]
all_json={}
matched_key_words=set()
for layername in loaded_params:
weight = params_dict[layername]
......@@ -474,7 +475,8 @@ class QWenBaseModel(nn.Module):
weight_data.data.copy_(_weight)
#下面是针对模型记录模型出现k和n值
elif len(weight_shapes)<4:
elif len(matched_key_words) < 4 and matches[0] not in matched_key_words:
matched_key_words.add(matches[0])
k=weight_data.shape[1]
weight_shapes.append({n,k})
......
......@@ -539,6 +539,7 @@ class Qwen2Model(nn.Module):
combined_words = "|".join(lay_key_words)
weight_shapes=[]
all_json={}
matched_key_words=set()
for layername in loaded_params:
weight = params_dict[layername]
......@@ -553,7 +554,8 @@ class Qwen2Model(nn.Module):
weight_data.data.copy_(_weight)
#下面是针对模型记录模型出现k和n值
elif len(weight_shapes)<4:
elif len(matched_key_words) < 4 and matches[0] not in matched_key_words:
matched_key_words.add(matches[0])
k=weight_data.shape[1]
weight_shapes.append({n,k})
......
......@@ -1505,8 +1505,6 @@ class W8a8GetCacheJSON:
def _initialize(self):
current_folder_path = os.path.dirname(os.path.abspath(__file__))
json_folder_path=current_folder_path+'/../lmslim/configs/w8a8'
if not os.path.exists(json_folder_path):
json_folder_path=current_folder_path+'/model_executor/layers/quantization/configs/w8a8'
self.triton_json_dir=(os.getenv('TRITON_JSON_DIR', json_folder_path))
self.triton_json_dict=[]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment