Commit 0a130908 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.7.2-dev-w8a8' into 'v0.7.2-dev'

V0.7.2 dev w8a8

See merge request dcutoolkit/deeplearing/vllm!76
parents 6b7651af 6fd8c21b
......@@ -370,7 +370,8 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
combined_words = "|".join(lay_key_words)
weight_shapes=[]
all_json={}
matched_key_words=set()
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches and "scale" not in layername:
......@@ -384,20 +385,19 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
weight_data.data.copy_(_weight)
#下面是针对模型记录模型出现k和n值
elif len(weight_shapes)<4:
#k=weight_data.shape[1]
#print("n:{},k:{}".format(n,k))
elif len(matched_key_words) < 4 and matches[0] not in matched_key_words:
matched_key_words.add(matches[0])
weight_shapes.append({n,k})
json_file=self.tritonsingleton.get_w8a8json_name(n,k)
configs_dict=self.tritonsingleton.get_triton_cache(json_file,n,k)
if configs_dict:
all_json.update(configs_dict)
#("weight_shapes:",weight_shapes)
if self.w8a8_strategy==1:
self.tritonsingleton.triton_json_dict.append(all_json)
#print("self.tritonsingleton.triton_json_dict:",self.tritonsingleton.triton_json_dict)
#找到的所有config都进行一次warmup
for key, value in all_json.items():
m=int(key.split('_')[0])
......
......@@ -562,6 +562,7 @@ class LlamaModel(nn.Module):
combined_words = "|".join(lay_key_words)
weight_shapes=[]
all_json={}
matched_key_words=set()
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
......@@ -579,7 +580,8 @@ class LlamaModel(nn.Module):
weight_data.data.copy_(_weight)
#下面是针对模型记录模型出现k和n值
elif len(weight_shapes)<4:
elif len(matched_key_words) < 4 and matches[0] not in matched_key_words:
matched_key_words.add(matches[0])
k=weight_data.shape[1]
weight_shapes.append({n,k})
......
......@@ -1185,6 +1185,7 @@ class QWenBaseModel(nn.Module, SupportsPP, SupportsLoRA):
combined_words = "|".join(lay_key_words)
weight_shapes=[]
all_json={}
matched_key_words=set()
for layername in loaded_params:
weight = params_dict[layername]
......@@ -1199,7 +1200,8 @@ class QWenBaseModel(nn.Module, SupportsPP, SupportsLoRA):
weight_data.data.copy_(_weight)
#下面是针对模型记录模型出现k和n值
elif len(weight_shapes)<4:
elif len(matched_key_words) < 4 and matches[0] not in matched_key_words:
matched_key_words.add(matches[0])
k=weight_data.shape[1]
weight_shapes.append({n,k})
......
......@@ -539,6 +539,7 @@ class Qwen2Model(nn.Module):
combined_words = "|".join(lay_key_words)
weight_shapes=[]
all_json={}
matched_key_words=set()
for layername in loaded_params:
weight = params_dict[layername]
......@@ -553,7 +554,8 @@ class Qwen2Model(nn.Module):
weight_data.data.copy_(_weight)
#下面是针对模型记录模型出现k和n值
elif len(weight_shapes)<4:
elif len(matched_key_words) < 4 and matches[0] not in matched_key_words:
matched_key_words.add(matches[0])
k=weight_data.shape[1]
weight_shapes.append({n,k})
......
......@@ -1534,8 +1534,6 @@ class W8a8GetCacheJSON:
def _initialize(self):
current_folder_path = os.path.dirname(os.path.abspath(__file__))
json_folder_path=current_folder_path+'/../lmslim/configs/w8a8'
if not os.path.exists(json_folder_path):
json_folder_path=current_folder_path+'/model_executor/layers/quantization/configs/w8a8'
self.triton_json_dir=(os.getenv('TRITON_JSON_DIR', json_folder_path))
self.triton_json_dict=[]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment