Commit 083b80ea authored by zhuwenwen
Browse files

增加w8a8相关修改

parent 09428eec
...@@ -238,27 +238,40 @@ def apply_int8_linear( ...@@ -238,27 +238,40 @@ def apply_int8_linear(
m=x_q.shape[0] m=x_q.shape[0]
k=x_q.shape[1] k=x_q.shape[1]
n=weight.shape[1] n=weight.shape[1]
if f"{m}_{n}_{k}" in W8A8_TRITONJSON.triton_json_dict[0]: #print("m:{},n:{},k:{}".format(m,n,k))
best_config=W8A8_TRITONJSON.triton_json_dict[0][f"{m}_{n}_{k}"]
#print("json files:",best_config) if f"1_{n}_{k}" in W8A8_TRITONJSON.triton_json_dict[0]:
elif f"1_{n}_{k}" in W8A8_TRITONJSON.triton_json_dict[0]: if m<=16:
if m<64: m_=m
m_= 32 #best_config=W8A8_TRITONJSON.triton_json_dict[0][f"{m}_{n}_{k}"]
elif m<128: elif m<=64:
m_=64 m_= (m + 3) & -4 #取值到最近的4的倍数
elif m<256: elif m<=160:
m_=128 m_=(m + 7) & -8
elif m<512:
elif m<200: #256
m_=160
elif m<480: #512
m_=256 m_=256
elif m<1024: elif m<960: #1024
m_=512 m_=512
else: elif m<2048:
m_=1024 m_=1024
elif m<4096:
m_=2048
elif m<6000:
m_=4096
else:
m_=8192
best_config=W8A8_TRITONJSON.triton_json_dict[0][f"{m_}_{n}_{k}"] best_config=W8A8_TRITONJSON.triton_json_dict[0][f"{m_}_{n}_{k}"]
else: else:
best_config=None best_config=None
print("config not found!") # if best_config==None:
# print("m:{},n:{},k:{}".format(m,n,k))
# print("config not found!")
return ops.triton_scaled_mm(x_q, return ops.triton_scaled_mm(x_q,
weight, weight,
......
...@@ -462,7 +462,7 @@ class FalconForCausalLM(nn.Module, SupportsPP): ...@@ -462,7 +462,7 @@ class FalconForCausalLM(nn.Module, SupportsPP):
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1' self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '0')) self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.get_input_embeddings(input_ids) return self.transformer.get_input_embeddings(input_ids)
......
...@@ -355,7 +355,7 @@ class LlamaModel(nn.Module): ...@@ -355,7 +355,7 @@ class LlamaModel(nn.Module):
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1' self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '0')) self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
...@@ -553,6 +553,10 @@ class LlamaModel(nn.Module): ...@@ -553,6 +553,10 @@ class LlamaModel(nn.Module):
if matches and "scale" not in layername: if matches and "scale" not in layername:
weight_data =params_dict[layername] weight_data =params_dict[layername]
n=weight_data.shape[0] n=weight_data.shape[0]
# k=weight_data.shape[1]
# #判断当前size是否在优化的范围内,假如存在则走triton,假如不存在则走rocblas
# json_file=self.tritonsingleton.get_w8a8json_name(n,k)
#rocblas和cutlass目前都需要weight做处理,但是triton不用 #rocblas和cutlass目前都需要weight做处理,但是triton不用
if self.w8a8_strategy!=1: if self.w8a8_strategy!=1:
......
...@@ -913,12 +913,13 @@ class QWenBaseModel(nn.Module, SupportsPP, SupportsLoRA): ...@@ -913,12 +913,13 @@ class QWenBaseModel(nn.Module, SupportsPP, SupportsLoRA):
if quant_config is not None: if quant_config is not None:
self.quant_method=quant_config.get_name() self.quant_method=quant_config.get_name()
self.quant_config=quant_config self.quant_config=quant_config
self.tritonsingleton= W8a8GetCacheJSON()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1' self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '0')) self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
def _get_image_input_type( def _get_image_input_type(
self, self,
...@@ -1147,7 +1148,7 @@ class QWenBaseModel(nn.Module, SupportsPP, SupportsLoRA): ...@@ -1147,7 +1148,7 @@ class QWenBaseModel(nn.Module, SupportsPP, SupportsLoRA):
m=int(key.split('_')[0]) m=int(key.split('_')[0])
n=int(key.split('_')[1]) n=int(key.split('_')[1])
k=int(key.split('_')[2]) k=int(key.split('_')[2])
ops._int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,best_config=value) ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,best_config=value)
return loaded_params return loaded_params
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment