Commit 823b3332 authored by gaoqiong's avatar gaoqiong
Browse files

修复0.8.5 config找不到的bug

parent 1150b65c
...@@ -529,6 +529,14 @@ void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size] ...@@ -529,6 +529,14 @@ void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size]
}); });
break; break;
case 8:
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "moe_sum_kernel", [&] {
vllm::moe::moe_sum_kernel<scalar_t, 8><<<grid, block, 0, stream>>>(
output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
hidden_size);
});
break;
default: default:
at::sum_out(output, input, 1); at::sum_out(output, input, 1);
break; break;
......
...@@ -31,7 +31,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( ...@@ -31,7 +31,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
should_ignore_layer) should_ignore_layer)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import W8a8GetCacheJSON
import os
from vllm import _custom_ops as ops
logger = init_logger(__name__) logger = init_logger(__name__)
__all__ = ["CompressedTensorsLinearMethod"] __all__ = ["CompressedTensorsLinearMethod"]
...@@ -540,10 +543,32 @@ class CompressedTensorsLinearMethod(LinearMethodBase): ...@@ -540,10 +543,32 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
def __init__(self, quantization_config: CompressedTensorsConfig): def __init__(self, quantization_config: CompressedTensorsConfig):
self.quantization_config = quantization_config self.quantization_config = quantization_config
self.tritonsingleton= W8a8GetCacheJSON()
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.scheme.process_weights_after_loading(layer) n=layer.weight.shape[0]
k=layer.weight.shape[1]
if self.w8a8_strategy==1:
if {n,k} not in self.tritonsingleton.weight_shapes:
self.tritonsingleton.weight_shapes.append({n,k})
json_file=self.tritonsingleton.get_w8a8json_name(n,k)
configs_dict=self.tritonsingleton.get_triton_cache(json_file,n,k)
if configs_dict:
self.tritonsingleton.triton_json_dict.update(configs_dict)
for key, value in configs_dict.items():
m=int(key.split('_')[0])
ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,best_config=value)
else:
weight_data=layer.weight.data
_weight=weight_data.T.contiguous().reshape(n,-1)
layer.weight.data=_weight
layer.scheme.process_weights_after_loading(layer)
def create_weights(self, layer: torch.nn.Module, def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int, input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int, output_partition_sizes: List[int], input_size: int,
......
...@@ -453,12 +453,12 @@ def w8a8_block_int8_matmul( ...@@ -453,12 +453,12 @@ def w8a8_block_int8_matmul(
# "num_stages": 3, # "num_stages": 3,
# } # }
#print("W8A8_TRITONJSON.triton_json_dict[0]:",W8A8_TRITONJSON.triton_json_dict[0]) #print("W8A8_TRITONJSON.triton_json_list[0]:",W8A8_TRITONJSON.triton_json_list[0])
if len(W8A8_TRITONJSON.triton_json_dict)==0: if len(W8A8_TRITONJSON.triton_json_list)==0:
config=None config=None
#print("len(W8A8_TRITONJSON.triton_json_dict)=0:",len(W8A8_TRITONJSON.triton_json_dict)) #print("len(W8A8_TRITONJSON.triton_json_list)=0:",len(W8A8_TRITONJSON.triton_json_list)) triton_json
elif f"1_{N}_{K}_block[{block_n},{block_k}]" in W8A8_TRITONJSON.triton_json_dict[0]: elif f"1_{N}_{K}_block[{block_n},{block_k}]" in W8A8_TRITONJSON.triton_json_list[0]:
if M<=16: if M<=16:
m_=M m_=M
elif M<=64: elif M<=64:
...@@ -481,7 +481,7 @@ def w8a8_block_int8_matmul( ...@@ -481,7 +481,7 @@ def w8a8_block_int8_matmul(
else: else:
m_=8192 m_=8192
#print("==================m:{},n:{},k:{}".format(M,N,K)) #print("==================m:{},n:{},k:{}".format(M,N,K))
config=W8A8_TRITONJSON.triton_json_dict[0][f"{m_}_{N}_{K}_block[{block_n},{block_k}]"] config=W8A8_TRITONJSON.triton_json_list[0][f"{m_}_{N}_{K}_block[{block_n},{block_k}]"]
else: else:
config=None config=None
......
...@@ -420,7 +420,7 @@ def apply_int8_linear( ...@@ -420,7 +420,7 @@ def apply_int8_linear(
if len(W8A8_TRITONJSON.triton_json_dict)==0: if len(W8A8_TRITONJSON.triton_json_dict)==0:
best_config=None best_config=None
elif f"1_{n}_{k}" in W8A8_TRITONJSON.triton_json_dict[0]: elif f"1_{n}_{k}" in W8A8_TRITONJSON.triton_json_dict:
if m<=16: if m<=16:
m_=m m_=m
#best_config=W8A8_TRITONJSON.triton_json_dict[0][f"{m}_{n}_{k}"] #best_config=W8A8_TRITONJSON.triton_json_dict[0][f"{m}_{n}_{k}"]
...@@ -444,7 +444,7 @@ def apply_int8_linear( ...@@ -444,7 +444,7 @@ def apply_int8_linear(
else: else:
m_=8192 m_=8192
best_config=W8A8_TRITONJSON.triton_json_dict[0][f"{m_}_{n}_{k}"] best_config=W8A8_TRITONJSON.triton_json_dict[f"{m_}_{n}_{k}"]
else: else:
best_config=None best_config=None
......
...@@ -961,7 +961,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): ...@@ -961,7 +961,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
if configs_dict: if configs_dict:
all_json.update(configs_dict) all_json.update(configs_dict)
self.tritonsingleton.triton_json_dict.append(all_json) self.tritonsingleton.triton_json_list.append(all_json)
#print("self.tritonsingleton.triton_json_dict[0].shape:",len(self.tritonsingleton.triton_json_dict[0])) #print("self.tritonsingleton.triton_json_dict[0].shape:",len(self.tritonsingleton.triton_json_dict[0]))
for key, value in all_json.items(): for key, value in all_json.items():
m=int(key.split('_')[0]) m=int(key.split('_')[0])
......
...@@ -46,7 +46,6 @@ from .interfaces import SupportsPP ...@@ -46,7 +46,6 @@ from .interfaces import SupportsPP
from .utils import (AutoWeightsLoader, is_pp_missing_parameter, from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
from vllm.utils import is_hip,W8a8GetCacheJSON
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
class GPTNeoXAttention(nn.Module): class GPTNeoXAttention(nn.Module):
...@@ -219,12 +218,10 @@ class GPTNeoXModel(nn.Module): ...@@ -219,12 +218,10 @@ class GPTNeoXModel(nn.Module):
make_empty_intermediate_tensors_factory(["hidden_states"], make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size)) config.hidden_size))
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
self.quant_method = None self.quant_method = None
if quant_config is not None: if quant_config is not None:
self.quant_method=quant_config.get_name() self.quant_method=quant_config.get_name()
self.quant_config=quant_config self.quant_config=quant_config
self.tritonsingleton= W8a8GetCacheJSON()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_in(input_ids) return self.embed_in(input_ids)
...@@ -287,53 +284,7 @@ class GPTNeoXModel(nn.Module): ...@@ -287,53 +284,7 @@ class GPTNeoXModel(nn.Module):
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name) loaded_params.add(name)
#当为triton支持推理的时候不能进行处理
if self.quant_method == "compressed_tensors":
os.environ['LM_NN'] = '0'
lay_key_words = [
"attention.query_key_value.weight",
"attention.dense.weight",
"mlp.dense_h_to_4h.weight",
"mlp.dense_4h_to_h.weight",
]
combined_words = "|".join(lay_key_words)
weight_shapes=[]
all_json={}
matched_key_words=set()
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches and "scale" not in layername:
weight_data =params_dict[layername]
n=weight_data.shape[0]
k=weight_data.shape[1]
#rocblas和cutlass目前都需要weight做处理,但是triton不用
if self.w8a8_strategy!=1:
_weight=weight_data.T.contiguous().reshape(n,-1)
weight_data.data.copy_(_weight)
#下面是针对模型记录模型出现k和n值
elif len(matched_key_words) < 4 and matches[0] not in matched_key_words:
matched_key_words.add(matches[0])
weight_shapes.append({n,k})
json_file=self.tritonsingleton.get_w8a8json_name(n,k)
configs_dict=self.tritonsingleton.get_triton_cache(json_file,n,k)
if configs_dict:
all_json.update(configs_dict)
if self.w8a8_strategy==1:
self.tritonsingleton.triton_json_dict.append(all_json)
#找到的所有config都进行一次warmup
for key, value in all_json.items():
m=int(key.split('_')[0])
n=int(key.split('_')[1])
k=int(key.split('_')[2])
ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,best_config=value)
return loaded_params return loaded_params
......
...@@ -50,7 +50,6 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -50,7 +50,6 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import W8a8GetCacheJSON
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
...@@ -356,14 +355,12 @@ class LlamaModel(nn.Module): ...@@ -356,14 +355,12 @@ class LlamaModel(nn.Module):
if quant_config is not None: if quant_config is not None:
self.quant_method=quant_config.get_name() self.quant_method=quant_config.get_name()
self.quant_config=quant_config self.quant_config=quant_config
self.tritonsingleton= W8a8GetCacheJSON()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
# self.use_lm_nn = os.environ.get('LM_NN') == '1' # self.use_lm_nn = os.environ.get('LM_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1' self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
...@@ -517,97 +514,7 @@ class LlamaModel(nn.Module): ...@@ -517,97 +514,7 @@ class LlamaModel(nn.Module):
else: else:
os.environ['LM_NN'] = '0' os.environ['LM_NN'] = '0'
os.environ['LLAMA_NN'] = '0' os.environ['LLAMA_NN'] = '0'
# if self.quant_method == "awq" and not envs.VLLM_USE_TRITON_AWQ:
# lay_key_words = [
# "self_attn.qkv_proj.qweight",
# "self_attn.o_proj.qweight",
# "mlp.gate_up_proj.qweight",
# "mlp.down_proj.qweight"
# ]
# combined_words = "|".join(lay_key_words)
# for layername in loaded_params:
# weight = params_dict[layername]
# matches = re.findall(combined_words, layername)
# if matches:
# qweight =params_dict[layername]
# qzeros=params_dict[layername.replace("qweight", "qzeros")]
# scales=params_dict[layername.replace("qweight", "scales")]
# zeros_and_scalse =params_dict[layername.replace("qweight", "zeros_and_scales")]
# group_size= self.quant_config.group_size
# dim_n = scales.data.shape[1]
# dim_k = qweight.data.shape[0]
# pad_group=2
# _qw, _sz=ops.convert_s4(qweight,qzeros,scales,int(group_size))
# sz = ops.sz_permute(_sz).reshape(-1,dim_n)
# zeros_and_scalse.data.copy_(sz)
# qweight.data.copy_(_qw)
# #reshape
# zeros_and_scalse.data=zeros_and_scalse.reshape(dim_n,-1) #[k/greop_size,n]------>[n,k/group_size]
# qweight.data=qweight.data.reshape(dim_n,-1) #[k,n/8]---->[n,k/8]
# if dim_k % 4096==0 and self.use_awq_pad:
# zeros_and_scalse_pad= torch.zeros(dim_n,pad_group,dtype=torch.int32).cuda()
# zeros_and_scalse.data=torch.cat((zeros_and_scalse.data,zeros_and_scalse_pad),dim=1).contiguous()
# qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda()
# qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous()
#当为triton支持推理的时候不能进行处理
if self.quant_method == "compressed_tensors":
lay_key_words = [
"self_attn.qkv_proj.weight",
"self_attn.o_proj.weight",
"mlp.gate_up_proj.weight",
"mlp.down_proj.weight",
]
combined_words = "|".join(lay_key_words)
weight_shapes=[]
all_json={}
matched_key_words=set()
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches and "scale" not in layername:
weight_data =params_dict[layername]
n=weight_data.shape[0]
# k=weight_data.shape[1]
# #判断当前size是否在优化的范围内,假如存在则走triton,假如不存在则走rocblas
# json_file=self.tritonsingleton.get_w8a8json_name(n,k)
#rocblas和cutlass目前都需要weight做处理,但是triton不用
if self.w8a8_strategy!=1:
_weight=weight_data.T.contiguous().reshape(n,-1)
weight_data.data.copy_(_weight)
#下面是针对模型记录模型出现k和n值
elif len(matched_key_words) < 4 and matches[0] not in matched_key_words:
matched_key_words.add(matches[0])
k=weight_data.shape[1]
weight_shapes.append({n,k})
json_file=self.tritonsingleton.get_w8a8json_name(n,k)
configs_dict=self.tritonsingleton.get_triton_cache(json_file,n,k)
if configs_dict:
all_json.update(configs_dict)
if self.w8a8_strategy==1:
self.tritonsingleton.triton_json_dict.append(all_json)
#找到的所有config都进行一次warmup
for key, value in all_json.items():
m=int(key.split('_')[0])
n=int(key.split('_')[1])
k=int(key.split('_')[2])
ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,best_config=value)
return loaded_params return loaded_params
......
...@@ -39,7 +39,6 @@ from .utils import (is_pp_missing_parameter, ...@@ -39,7 +39,6 @@ from .utils import (is_pp_missing_parameter,
maybe_prefix) maybe_prefix)
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf from vllm.model_executor.utils import pad_weight, gemm_bank_conf
from vllm.utils import W8a8GetCacheJSON
class QWenMLP(nn.Module): class QWenMLP(nn.Module):
...@@ -291,13 +290,11 @@ class QWenBaseModel(nn.Module): ...@@ -291,13 +290,11 @@ class QWenBaseModel(nn.Module):
if quant_config is not None: if quant_config is not None:
self.quant_method=quant_config.get_name() self.quant_method=quant_config.get_name()
self.quant_config=quant_config self.quant_config=quant_config
self.tritonsingleton= W8a8GetCacheJSON()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1' self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
def compute_logits( def compute_logits(
self, self,
...@@ -384,51 +381,7 @@ class QWenBaseModel(nn.Module): ...@@ -384,51 +381,7 @@ class QWenBaseModel(nn.Module):
weight.data.copy_(_weight) weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1],-1) weight.data=weight.data.reshape(ori_shape[1],-1)
# if self.quant_method == "awq":
# os.environ['LM_NN'] = '0'
# lay_key_words = [
# "attn.c_attn.qweight",
# "attn.c_proj.qweight",
# "mlp.gate_up_proj.qweight",
# "mlp.c_proj.qweight"
# ]
# combined_words = "|".join(lay_key_words)
# for layername in loaded_params:
# weight = params_dict[layername]
# matches = re.findall(combined_words, layername)
# if matches:
# qweight =params_dict[layername]
# qzeros=params_dict[layername.replace("qweight", "qzeros")]
# scales=params_dict[layername.replace("qweight", "scales")]
# zeros_and_scalse =params_dict[layername.replace("qweight", "zeros_and_scales")]
# group_size= self.quant_config.group_size
# dim_n = scales.data.shape[1]
# dim_k = qweight.data.shape[0]
# pad_group=2
# _qw, _sz=ops.convert_s4(qweight,qzeros,scales,int(group_size))
# sz = ops.sz_permute(_sz).reshape(-1,dim_n)
# zeros_and_scalse.data.copy_(sz)
# qweight.data.copy_(_qw)
# #reshape
# zeros_and_scalse.data=zeros_and_scalse.reshape(dim_n,-1) #[k/greop_size,n]------>[n,k/group_size]
# qweight.data=qweight.data.reshape(dim_n,-1) #[k,n/8]---->[n,k/8]
# if dim_k % 4096==0 and self.use_awq_pad:
# zeros_and_scalse_pad= torch.zeros(dim_n,pad_group,dtype=torch.int32).cuda()
# zeros_and_scalse.data=torch.cat((zeros_and_scalse.data,zeros_and_scalse_pad),dim=1).contiguous()
# qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda()
# qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous()
if self.quant_method == "compressed_tensors":
os.environ['LM_NN'] = '0' os.environ['LM_NN'] = '0'
lay_key_words = [ lay_key_words = [
"attn.c_attn.weight", "attn.c_attn.weight",
......
...@@ -63,7 +63,6 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, ...@@ -63,7 +63,6 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
maybe_prefix) maybe_prefix)
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf from vllm.model_executor.utils import pad_weight, gemm_bank_conf
from vllm.utils import W8a8GetCacheJSON
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -338,13 +337,11 @@ class Qwen2Model(nn.Module): ...@@ -338,13 +337,11 @@ class Qwen2Model(nn.Module):
if quant_config is not None: if quant_config is not None:
self.quant_method=quant_config.get_name() self.quant_method=quant_config.get_name()
self.quant_config=quant_config self.quant_config=quant_config
self.tritonsingleton= W8a8GetCacheJSON()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1' self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
...@@ -485,93 +482,6 @@ class Qwen2Model(nn.Module): ...@@ -485,93 +482,6 @@ class Qwen2Model(nn.Module):
else: else:
os.environ['LM_NN'] = '0' os.environ['LM_NN'] = '0'
os.environ['LLAMA_NN'] = '0' os.environ['LLAMA_NN'] = '0'
# if self.quant_method == "awq" and not envs.VLLM_USE_TRITON_AWQ:
# lay_key_words = [
# "self_attn.qkv_proj.qweight",
# "self_attn.o_proj.qweight",
# "mlp.gate_up_proj.qweight",
# "mlp.down_proj.qweight"
# ]
# combined_words = "|".join(lay_key_words)
# for layername in loaded_params:
# weight = params_dict[layername]
# matches = re.findall(combined_words, layername)
# if matches:
# qweight =params_dict[layername]
# qzeros=params_dict[layername.replace("qweight", "qzeros")]
# scales=params_dict[layername.replace("qweight", "scales")]
# zeros_and_scalse =params_dict[layername.replace("qweight", "zeros_and_scales")]
# group_size= self.quant_config.group_size
# dim_n = scales.data.shape[1]
# dim_k = qweight.data.shape[0]
# pad_group=2
# _qw, _sz=ops.convert_s4(qweight,qzeros,scales,int(group_size))
# sz = ops.sz_permute(_sz).reshape(-1,dim_n)
# zeros_and_scalse.data.copy_(sz)
# qweight.data.copy_(_qw)
# #reshape
# zeros_and_scalse.data=zeros_and_scalse.reshape(dim_n,-1) #[k/greop_size,n]------>[n,k/group_size]
# qweight.data=qweight.data.reshape(dim_n,-1) #[k,n/8]---->[n,k/8]
# if dim_k % 4096==0 and self.use_awq_pad:
# zeros_and_scalse_pad= torch.zeros(dim_n,pad_group,dtype=torch.int32).cuda()
# zeros_and_scalse.data=torch.cat((zeros_and_scalse.data,zeros_and_scalse_pad),dim=1).contiguous()
# qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda()
# qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous()
if self.quant_method == "compressed_tensors":
lay_key_words = [
"self_attn.qkv_proj.weight",
"self_attn.o_proj.weight",
"mlp.gate_up_proj.weight",
"mlp.down_proj.weight",
]
combined_words = "|".join(lay_key_words)
weight_shapes=[]
all_json={}
matched_key_words=set()
for layername in loaded_params:
weight = params_dict[layername]
matches = re.findall(combined_words, layername)
if matches and "scale" not in layername:
weight_data =params_dict[layername]
n=weight_data.shape[0]
#rocblas和cutlass目前都需要weight做处理,但是triton不用
if self.w8a8_strategy!=1:
_weight=weight_data.T.contiguous().reshape(n,-1)
weight_data.data.copy_(_weight)
#下面是针对模型记录模型出现k和n值
elif len(matched_key_words) < 4 and matches[0] not in matched_key_words:
matched_key_words.add(matches[0])
k=weight_data.shape[1]
weight_shapes.append({n,k})
json_file=self.tritonsingleton.get_w8a8json_name(n,k)
configs_dict=self.tritonsingleton.get_triton_cache(json_file,n,k)
if configs_dict:
all_json.update(configs_dict)
if self.w8a8_strategy==1:
self.tritonsingleton.triton_json_dict.append(all_json)
#找到的所有config都进行一次warmup
for key, value in all_json.items():
m=int(key.split('_')[0])
n=int(key.split('_')[1])
k=int(key.split('_')[2])
ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,best_config=value)
return loaded_params return loaded_params
......
...@@ -1725,7 +1725,9 @@ class W8a8GetCacheJSON: ...@@ -1725,7 +1725,9 @@ class W8a8GetCacheJSON:
json_folder_path=current_folder_path+'/../lmslim/configs/w8a8' json_folder_path=current_folder_path+'/../lmslim/configs/w8a8'
self.triton_json_dir=(os.getenv('TRITON_JSON_DIR', json_folder_path)) self.triton_json_dir=(os.getenv('TRITON_JSON_DIR', json_folder_path))
self.triton_json_dict=[] self.triton_json_dict={}
self.triton_json_list=[]
self.weight_shapes=[]
def getspec_config(self,configs_dict,M,N,K): def getspec_config(self,configs_dict,M,N,K):
if f"{M}_{N}_{K}" in configs_dict: if f"{M}_{N}_{K}" in configs_dict:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment