Commit e58014d7 authored by zhuwenwen's avatar zhuwenwen
Browse files

add gemm padding

parent e2df3544
......@@ -44,6 +44,34 @@ def adjust_bitsandbytes_shard(param: Parameter,
return quantized_size, quantized_offset
def pad_weight(weight: torch.Tensor, num_pad: int, pad_dim: int = 0):
    """Return a copy of ``weight`` with ``num_pad`` zeros appended along one axis.

    Args:
        weight: A 1D or 2D tensor. The zero filler is created with the same
            dtype and device as this tensor.
        num_pad: Number of zero entries (1D) or zero rows/columns (2D) to
            append after the existing data.
        pad_dim: Axis to extend for 2D input; must be 0 (rows) or 1
            (columns). Ignored for 1D input, which is always extended
            along dim 0.

    Returns:
        A new tensor produced by concatenating ``weight`` with the zero filler.

    Raises:
        ValueError: If ``weight`` is not 1D or 2D, or if ``weight`` is 2D and
            ``pad_dim`` is neither 0 nor 1.
    """
    rank = weight.dim()
    if rank not in (1, 2):
        raise ValueError("Weight tensor must be 1D or 2D")

    # Work out the filler shape and the concatenation axis first, so the
    # allocation and the concatenation are written only once.
    if rank == 1:
        filler_shape = (num_pad,)
        cat_dim = 0
    elif pad_dim == 0:
        filler_shape = (num_pad, weight.shape[1])
        cat_dim = 0
    elif pad_dim == 1:
        filler_shape = (weight.shape[0], num_pad)
        cat_dim = 1
    else:
        raise ValueError("pad_dim must be 0 or 1")

    filler = torch.zeros(filler_shape,
                         dtype=weight.dtype,
                         device=weight.device)
    return torch.cat([weight, filler], dim=cat_dim)
def gemm_bank_conf(weight):
    """Report whether a size is both a multiple of 2048 and a power of two.

    Despite the parameter name, the visible call sites pass an integer
    dimension size (an entry of a tensor's ``shape``), not a tensor.
    NOTE(review): the function name suggests this flags GEMM shapes prone to
    memory-bank conflicts -- confirm against the kernel documentation.

    Args:
        weight: Integer size to test.

    Returns:
        True if ``weight`` is non-zero, divisible by 2048, and has exactly
        one bit set; False otherwise.
    """
    remainder = weight % 2048
    # A power of two has a single set bit, so clearing the lowest set bit
    # (n & (n - 1)) leaves zero. Zero itself passes that test and is
    # excluded explicitly below.
    cleared = weight & (weight - 1)
    return bool(remainder == 0 and cleared == 0 and weight != 0)
class LinearMethodBase(QuantizeMethodBase):
"""Base class for different (maybe quantized) linear methods."""
......@@ -118,7 +146,10 @@ class UnquantizedLinearMethod(LinearMethodBase):
if bias is not None:
return torch.matmul(x, weight) + bias
else:
return torch.matmul(x, weight)
if gemm_bank_conf(weight.shape[1] - 32) and os.environ['GEMM_PAD'] == '1':
return torch.matmul(x, weight[:,:-32])
else:
return torch.matmul(x, weight)
else:
return F.linear(x, weight, bias)
......@@ -806,6 +837,7 @@ class RowParallelLinear(LinearBase):
else:
self.register_parameter("bias", None)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
# Special case for Fp8 scales.
......@@ -831,10 +863,18 @@ class RowParallelLinear(LinearBase):
loaded_weight = loaded_weight.reshape(1)
assert param_data.shape == loaded_weight.shape
if self.use_llama_nn:
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight=loaded_weight.reshape(param_data.shape[0],-1)
# if self.use_llama_nn:
# loaded_weight = loaded_weight.transpose(0, 1)
# loaded_weight=loaded_weight.reshape(param_data.shape[0],-1)
# param_data.copy_(loaded_weight)
param_data.copy_(loaded_weight)
if self.use_llama_nn:
if gemm_bank_conf(param.data.shape[0]) and self.use_gemm_pad:
param.data = pad_weight(param.data, 32)
param.data = param.data.transpose(0, 1)
param.data=param.data.reshape(param.data.shape[1],-1)
def forward(self, input_):
# Set up backprop all-reduce.
......
......@@ -25,6 +25,8 @@ def get_model_architecture(
if architectures == ['LlamaForCausalLM'] or architectures == ['ChatGLMModel'] or architectures == ['BaichuanForCausalLM']:
if os.getenv('LLAMA_NN') != '0':
os.environ['LLAMA_NN'] = '1'
if os.getenv('GEMM_PAD') != '0':
os.environ['GEMM_PAD'] = '1'
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
if (model_config.quantization is not None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment