Commit c6187ade authored by jujl1's avatar jujl1
Browse files

fix: 修复W8A8INT8读config问题

parent 7017f30c
......@@ -1516,6 +1516,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
f"{self.weight_quant}, {self.input_quant}")
self.static_input_scales = not self.input_quant.dynamic
self.tritonsingleton= W8a8GetCacheJSON()
def create_weights(self, layer: torch.nn.Module, num_experts: int,
......@@ -1570,6 +1571,27 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
layer.w13_input_scale = None
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
E=layer.w13_weight.shape[0]
N1=layer.w13_weight.shape[1]
N2=layer.w2_weight.shape[1]
K=layer.w2_weight.shape[2]
if [E,N1,N2,K] not in self.tritonsingleton.moe_weight_shapes:
self.tritonsingleton.moe_weight_shapes.append([E,N1,N2,K])
TOPK= self.tritonsingleton.topk
json_file=self.tritonsingleton.get_moeint8json_name(E,N1,N2,K,TOPK)
configs_dict=self.tritonsingleton.get_moeint8_triton_cache(json_file,E,N1,N2,K,TOPK)
#warmup
if configs_dict:
self.tritonsingleton.triton_moejson_dict.update(configs_dict)
#生成模型配置文件
#self.tritonsingleton.gen_model_json(block_size)
return
def apply(
self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment