Commit bf3bf955 authored by zhuwenwen
Browse files

修改美团 deepseek channel-wise模型 moe config获取

去除AutoTuning info提示信息
parent acf1b6c6
...@@ -248,6 +248,7 @@ class W8A8Int8MoEMethod: ...@@ -248,6 +248,7 @@ class W8A8Int8MoEMethod:
def __init__(self, quant_config): def __init__(self, quant_config):
self.quant_config = quant_config self.quant_config = quant_config
self.tritonsingleton= W8a8GetCacheJSON()
def create_weights( def create_weights(
self, self,
...@@ -302,6 +303,22 @@ class W8A8Int8MoEMethod: ...@@ -302,6 +303,22 @@ class W8A8Int8MoEMethod:
layer.register_parameter("w2_input_scale", w2_input_scale) layer.register_parameter("w2_input_scale", w2_input_scale)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
E=layer.w13_weight.shape[0]
N1=layer.w13_weight.shape[1]
N2=layer.w2_weight.shape[1]
K=layer.w2_weight.shape[2]
if [E,N1,N2,K] not in self.tritonsingleton.moe_weight_shapes:
self.tritonsingleton.moe_weight_shapes.append([E,N1,N2,K])
TOPK= self.tritonsingleton.topk
json_file=self.tritonsingleton.get_moeint8json_name(E,N1,N2,K,TOPK)
configs_dict=self.tritonsingleton.get_moeint8_triton_cache(json_file,E,N1,N2,K,TOPK)
#warmup
if configs_dict:
self.tritonsingleton.triton_moejson_dict.update(configs_dict)
layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False) layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False) layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
layer.w13_weight_scale = Parameter( layer.w13_weight_scale = Parameter(
......
...@@ -1904,7 +1904,7 @@ class W8a8GetCacheJSON: ...@@ -1904,7 +1904,7 @@ class W8a8GetCacheJSON:
json_dir = os.getenv('LMSLIM_TUNING_JSON', "None") json_dir = os.getenv('LMSLIM_TUNING_JSON', "None")
if json_dir is not "None" and os.path.exists(json_dir): if json_dir is not "None" and os.path.exists(json_dir):
#生成模型配置文件 #生成模型配置文件
logger.info("model_tuning.json is at LMSLIM_TUNING_JSON:%s", json_dir) # logger.info("model_tuning.json is at LMSLIM_TUNING_JSON:%s", json_dir)
config = { config = {
"layers": { "layers": {
"linear": { "linear": {
...@@ -1942,8 +1942,8 @@ class W8a8GetCacheJSON: ...@@ -1942,8 +1942,8 @@ class W8a8GetCacheJSON:
with open(json_dir+"/model.json", 'w') as f: with open(json_dir+"/model.json", 'w') as f:
json.dump(config, f, indent=4) json.dump(config, f, indent=4)
else: # else:
logger.info("LMSLIM_TUNING_JSON is not set") # logger.info("LMSLIM_TUNING_JSON is not set")
def getspec_config(self,configs_dict,M,N,K): def getspec_config(self,configs_dict,M,N,K):
if f"{M}_{N}_{K}" in configs_dict: if f"{M}_{N}_{K}" in configs_dict:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment