Commit 2a9c497e authored by zhuwenwen
Browse files

add LM_TN for bloom lm_head weight

parent 85e8224c
...@@ -22,6 +22,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase): ...@@ -22,6 +22,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
def __init__(self): def __init__(self):
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_lm_tn = os.environ.get('LM_TN') == '1'
def create_weights(self, layer: torch.nn.Module, def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int, input_size_per_partition: int,
...@@ -41,7 +42,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase): ...@@ -41,7 +42,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.use_llama_nn: if self.use_llama_nn and not self.use_lm_tn:
if bias is not None: if bias is not None:
if len(x.shape) == 2: if len(x.shape) == 2:
return torch.addmm(bias, x, layer.weight) return torch.addmm(bias, x, layer.weight)
......
...@@ -30,6 +30,10 @@ def get_model_architecture( ...@@ -30,6 +30,10 @@ def get_model_architecture(
os.environ['LLAMA_NN'] = '0' os.environ['LLAMA_NN'] = '0'
else: else:
os.environ['LLAMA_NN'] = '1' os.environ['LLAMA_NN'] = '1'
if architectures == ['BloomForCausalLM']:
os.environ['LM_TN'] = '1'
else:
os.environ['LM_TN'] = '0'
if os.getenv('GEMM_PAD') != '1': if os.getenv('GEMM_PAD') != '1':
os.environ['GEMM_PAD'] = '0' os.environ['GEMM_PAD'] = '0'
if os.getenv('FA_PAD') != '1': if os.getenv('FA_PAD') != '1':
...@@ -46,6 +50,7 @@ def get_model_architecture( ...@@ -46,6 +50,7 @@ def get_model_architecture(
os.environ['AWQ_PAD'] = '0' os.environ['AWQ_PAD'] = '0'
else: else:
os.environ['LLAMA_NN'] = '0' os.environ['LLAMA_NN'] = '0'
os.environ['LM_TN'] = '0'
os.environ['GEMM_PAD'] = '0' os.environ['GEMM_PAD'] = '0'
os.environ['FA_PAD'] = '0' os.environ['FA_PAD'] = '0'
os.environ['AWQ_PAD'] = '0' os.environ['AWQ_PAD'] = '0'
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment