Commit e58014d7 authored by zhuwenwen's avatar zhuwenwen
Browse files

add gemm padding

parent e2df3544
...@@ -44,6 +44,34 @@ def adjust_bitsandbytes_shard(param: Parameter, ...@@ -44,6 +44,34 @@ def adjust_bitsandbytes_shard(param: Parameter,
return quantized_size, quantized_offset return quantized_size, quantized_offset
def pad_weight(weight: torch.Tensor, num_pad: int, pad_dim: int = 0) -> torch.Tensor:
    """Return a new tensor equal to ``weight`` with zeros appended at the end.

    The padding has the same dtype and device as ``weight``; the input
    tensor itself is not modified.

    Args:
        weight: A 1D or 2D tensor to pad.
        num_pad: Number of zero entries (1D) or zero rows/columns (2D) to
            append. Must be non-negative.
        pad_dim: Axis to extend for a 2D tensor: 0 appends rows, 1 appends
            columns. Not consulted for a 1D tensor, which has a single axis.

    Returns:
        The padded tensor, whose size along the padded axis has grown by
        ``num_pad``.

    Raises:
        ValueError: If ``weight`` is not 1D or 2D, if ``pad_dim`` is not 0
            or 1 for a 2D tensor, or if ``num_pad`` is negative.
    """
    ndim = weight.dim()
    if ndim == 1:
        # Only one axis exists, so the caller's pad_dim is irrelevant here.
        axis = 0
    elif ndim == 2:
        if pad_dim not in (0, 1):
            raise ValueError("pad_dim must be 0 or 1")
        axis = pad_dim
    else:
        raise ValueError("Weight tensor must be 1D or 2D")

    if num_pad < 0:
        # Fail with a clear message instead of an opaque allocation error.
        raise ValueError("num_pad must be non-negative")

    # The zero block matches the input on every axis except the padded one.
    pad_shape = list(weight.shape)
    pad_shape[axis] = num_pad
    padding = weight.new_zeros(pad_shape)
    return torch.cat([weight, padding], dim=axis)
def gemm_bank_conf(weight: int) -> bool:
    """Tell whether a size is both a multiple of 2048 and a power of two.

    Despite its name, ``weight`` is an integer dimension size, not a tensor.
    NOTE(review): the function name suggests such sizes cause GEMM bank
    conflicts on the target hardware -- confirm against the kernel docs.

    Args:
        weight: The integer size to test.

    Returns:
        True when ``weight`` is a non-zero power of two that is divisible
        by 2048 (i.e. 2048, 4096, 8192, ...); False otherwise.
    """
    is_mul_of_2048 = weight % 2048 == 0
    # A power of two has exactly one bit set, so clearing its lowest set bit
    # yields zero; zero itself is excluded explicitly.
    is_power_of_two = (weight & (weight - 1)) == 0 and weight != 0
    return is_mul_of_2048 and is_power_of_two
class LinearMethodBase(QuantizeMethodBase): class LinearMethodBase(QuantizeMethodBase):
"""Base class for different (maybe quantized) linear methods.""" """Base class for different (maybe quantized) linear methods."""
...@@ -118,7 +146,10 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -118,7 +146,10 @@ class UnquantizedLinearMethod(LinearMethodBase):
if bias is not None: if bias is not None:
return torch.matmul(x, weight) + bias return torch.matmul(x, weight) + bias
else: else:
return torch.matmul(x, weight) if gemm_bank_conf(weight.shape[1] - 32) and os.environ['GEMM_PAD'] == '1':
return torch.matmul(x, weight[:,:-32])
else:
return torch.matmul(x, weight)
else: else:
return F.linear(x, weight, bias) return F.linear(x, weight, bias)
...@@ -806,6 +837,7 @@ class RowParallelLinear(LinearBase): ...@@ -806,6 +837,7 @@ class RowParallelLinear(LinearBase):
else: else:
self.register_parameter("bias", None) self.register_parameter("bias", None)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
# Special case for Fp8 scales. # Special case for Fp8 scales.
...@@ -831,10 +863,18 @@ class RowParallelLinear(LinearBase): ...@@ -831,10 +863,18 @@ class RowParallelLinear(LinearBase):
loaded_weight = loaded_weight.reshape(1) loaded_weight = loaded_weight.reshape(1)
assert param_data.shape == loaded_weight.shape assert param_data.shape == loaded_weight.shape
if self.use_llama_nn:
loaded_weight = loaded_weight.transpose(0, 1) # if self.use_llama_nn:
loaded_weight=loaded_weight.reshape(param_data.shape[0],-1) # loaded_weight = loaded_weight.transpose(0, 1)
# loaded_weight=loaded_weight.reshape(param_data.shape[0],-1)
# param_data.copy_(loaded_weight)
param_data.copy_(loaded_weight) param_data.copy_(loaded_weight)
if self.use_llama_nn:
if gemm_bank_conf(param.data.shape[0]) and self.use_gemm_pad:
param.data = pad_weight(param.data, 32)
param.data = param.data.transpose(0, 1)
param.data=param.data.reshape(param.data.shape[1],-1)
def forward(self, input_): def forward(self, input_):
# Set up backprop all-reduce. # Set up backprop all-reduce.
......
...@@ -25,6 +25,8 @@ def get_model_architecture( ...@@ -25,6 +25,8 @@ def get_model_architecture(
if architectures == ['LlamaForCausalLM'] or architectures == ['ChatGLMModel'] or architectures == ['BaichuanForCausalLM']: if architectures == ['LlamaForCausalLM'] or architectures == ['ChatGLMModel'] or architectures == ['BaichuanForCausalLM']:
if os.getenv('LLAMA_NN') != '0': if os.getenv('LLAMA_NN') != '0':
os.environ['LLAMA_NN'] = '1' os.environ['LLAMA_NN'] = '1'
if os.getenv('GEMM_PAD') != '0':
os.environ['GEMM_PAD'] = '1'
# Special handling for quantized Mixtral. # Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack. # FIXME(woosuk): This is a temporary hack.
if (model_config.quantization is not None if (model_config.quantization is not None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment