Commit acf1b6c6 authored by zhuwenwen's avatar zhuwenwen
Browse files

修改channelwise moe相关配置读取

parent 6781a21e
...@@ -139,6 +139,7 @@ class BlockInt8LinearMethod(LinearMethodBase): ...@@ -139,6 +139,7 @@ class BlockInt8LinearMethod(LinearMethodBase):
assert self.quant_config.weight_block_size is not None assert self.quant_config.weight_block_size is not None
assert self.quant_config.is_checkpoint_int8_serialized assert self.quant_config.is_checkpoint_int8_serialized
self.tritonsingleton= W8a8GetCacheJSON()
def create_weights( def create_weights(
self, self,
...@@ -438,8 +439,8 @@ class BlockInt8MoEMethod: ...@@ -438,8 +439,8 @@ class BlockInt8MoEMethod:
TOPK= self.tritonsingleton.topk TOPK= self.tritonsingleton.topk
block_size=self.quant_config.weight_block_size block_size=self.quant_config.weight_block_size
json_file=self.tritonsingleton.get_moeblockint8json_name(block_size,E,N1,N2,K,TOPK) json_file=self.tritonsingleton.get_moeint8json_name(E,N1,N2,K,TOPK,block_size,)
configs_dict=self.tritonsingleton.get_moeblockint8_triton_cache(json_file,block_size,E,N1,N2,K,TOPK) configs_dict=self.tritonsingleton.get_moeint8_triton_cache(json_file,E,N1,N2,K,TOPK)
#warmup #warmup
if configs_dict: if configs_dict:
......
...@@ -29,6 +29,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( ...@@ -29,6 +29,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
from vllm.utils import W8a8GetCacheJSON
has_pplx = importlib.util.find_spec("pplx_kernels") is not None has_pplx = importlib.util.find_spec("pplx_kernels") is not None
...@@ -141,6 +142,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -141,6 +142,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
is_rocm_aiter_moe_enabled) is_rocm_aiter_moe_enabled)
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
self.tritonsingleton= W8a8GetCacheJSON()
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int, hidden_size: int, intermediate_size_per_partition: int,
...@@ -225,6 +227,27 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -225,6 +227,27 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
else: else:
layer.w13_input_scale = None layer.w13_input_scale = None
layer.w2_input_scale = None layer.w2_input_scale = None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
E=layer.w13_weight.shape[0]
N1=layer.w13_weight.shape[1]
N2=layer.w2_weight.shape[1]
K=layer.w2_weight.shape[2]
if [E,N1,N2,K] not in self.tritonsingleton.moe_weight_shapes:
self.tritonsingleton.moe_weight_shapes.append([E,N1,N2,K])
TOPK= self.tritonsingleton.topk
json_file=self.tritonsingleton.get_moeint8json_name(E,N1,N2,K,TOPK)
configs_dict=self.tritonsingleton.get_moeint8_triton_cache(json_file,E,N1,N2,K,TOPK)
#warmup
if configs_dict:
self.tritonsingleton.triton_moejson_dict.update(configs_dict)
#生成模型配置文件
#self.tritonsingleton.gen_model_json(block_size)
return
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Fp8 moe kernels require a single activation scale. # Fp8 moe kernels require a single activation scale.
......
...@@ -1994,10 +1994,14 @@ class W8a8GetCacheJSON: ...@@ -1994,10 +1994,14 @@ class W8a8GetCacheJSON:
def get_blockint8json_name(self,n,k,block_n,block_k): def get_blockint8json_name(self,n,k,block_n,block_k):
return self.triton_json_dir+f"/linear_{n}_{k}_block[{block_n},{block_k}]_{self.device_name}.json" return self.triton_json_dir+f"/linear_{n}_{k}_block[{block_n},{block_k}]_{self.device_name}.json"
def get_moeblockint8json_name(self,block_size,E,N1,N2,K,TOPK): def get_moeint8json_name(self,E,N1,N2,K,TOPK,
return self.triton_json_dir+f"/MOE_BLOCKINT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json" block_size:Optional[list]=None):
if block_size is not None:
return self.triton_json_dir+f"/MOE_BLOCKINT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
else:
return self.triton_json_dir+f"/MOE_W8A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
def get_moeblockint8_triton_cache(self,file_path,block_size,E,N1,N2,K,TOPK): def get_moeint8_triton_cache(self,file_path,E,N1,N2,K,TOPK):
cache_json_file=file_path cache_json_file=file_path
if os.path.exists(file_path): if os.path.exists(file_path):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment