Commit a8134c13 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.8.5.post-dev-w8a8' into 'v0.8.5.post1-dev'

V0.8.5.post dev w8a8

See merge request dcutoolkit/deeplearing/vllm!131
parents a68aef25 53250530
......@@ -23,6 +23,10 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.quantization.utils.int8_utils import (
apply_w8a8_block_int8_linear)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import W8a8GetCacheJSON
import os
from vllm import _custom_ops as ops
ACTIVATION_SCHEMES = ["static", "dynamic"]
......@@ -128,6 +132,7 @@ class BlockInt8LinearMethod(LinearMethodBase):
def __init__(self, quant_config: BlockInt8Config):
self.quant_config = quant_config
self.tritonsingleton= W8a8GetCacheJSON()
assert self.quant_config.weight_block_size is not None
assert self.quant_config.is_checkpoint_int8_serialized
......@@ -219,6 +224,27 @@ class BlockInt8LinearMethod(LinearMethodBase):
def process_weights_after_loading(self, layer: Module) -> None:
# Block quant doesn't need to process weights after loading
# Use torch Parameter to avoid cuda graph capturing issue
n=layer.weight.shape[0]
k=layer.weight.shape[1]
block_n=self.quant_config.weight_block_size[0]
block_k=self.quant_config.weight_block_size[1]
block_size=[block_n,block_k]
#print("layer.weight.device:",layer.weight.device)
if {n,k} not in self.tritonsingleton.weight_shapes:
self.tritonsingleton.weight_shapes.append({n,k})
json_file=self.tritonsingleton.get_blockint8json_name(n,k,block_n,block_k)
configs_dict=self.tritonsingleton.get_blockint8_triton_cache(json_file,n,k,block_n,block_k)
if configs_dict:
self.tritonsingleton.triton_json_dict.update(configs_dict)
for key, value in configs_dict.items():
m=int(key.split('_')[0])
#ops.triton_blockint8_gemm_helper(m=m,n=n,k=k,block_size=block_size,use_bias=False,out_dtype=torch.bfloat16,device=layer.weight.device,best_config=value)
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
layer.weight_scale_inv = torch.nn.Parameter(
layer.weight_scale_inv.data, requires_grad=False
......
......@@ -53,7 +53,6 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.utils import W8a8GetCacheJSON
from .interfaces import SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
......@@ -704,7 +703,6 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
os.environ['LM_NN'] = '0'
self.use_w4a16_moe_sz = os.environ.get('AWQ_MOE_SZ') == '1'
self.tritonsingleton= W8a8GetCacheJSON()
self.config = config
self.quant_config = quant_config
......@@ -928,48 +926,6 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
sz_tensor = self.restore_qzeros_tensor(qzeros, scales)
scales.data = sz_tensor
if hasattr(self.config, "quantization_config") and self.config.quantization_config["quant_method"] == "blockwise_int8":
lay_key_words = [
"self_attn.q_a_proj.weight",
"self_attn.q_b_proj.weight",
"self_attn.kv_b_proj.weight",
"self_attn.kv_a_proj_with_mqa.weight",
"self_attn.o_proj.weight",
"mlp.gate_up_proj.weight",
"mlp.down_proj.weight",
"mlp.shared_experts.gate_up_proj.weight",
"mlp.shared_experts.down_proj.weight"
]
combined_words = "|".join(lay_key_words)
weight_shapes=[]
all_json={}
matched_key_words=set()
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches and "scale" not in layername:
weight_data =params_dict[layername]
n=weight_data.shape[0]
if len(matched_key_words) < 9 and matches[0] not in matched_key_words:
matched_key_words.add(matches[0])
k=weight_data.shape[1]
weight_shapes.append({n,k})
#print("n:{},k:{}".format(n,k))
json_file=self.tritonsingleton.get_blockint8json_name(n,k,128,128)
configs_dict=self.tritonsingleton.get_blockint8_triton_cache(json_file,n,k,128,128)
if configs_dict:
all_json.update(configs_dict)
self.tritonsingleton.triton_json_list.append(all_json)
#print("self.tritonsingleton.triton_json_dict[0].shape:",len(self.tritonsingleton.triton_json_dict[0]))
for key, value in all_json.items():
m=int(key.split('_')[0])
n=int(key.split('_')[1])
k=int(key.split('_')[2])
# ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,best_config=value)
return loaded_params
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment