Commit 7fe40ced authored by zhuwenwen
Browse files

modify gemm pad strategy

parent e661266e
......@@ -867,17 +867,19 @@ class RowParallelLinear(LinearBase):
assert param_data.shape == loaded_weight.shape
# if self.use_llama_nn:
# loaded_weight = loaded_weight.transpose(0, 1)
# loaded_weight=loaded_weight.reshape(param_data.shape[0],-1)
# param_data.copy_(loaded_weight)
param_data.copy_(loaded_weight)
if self.use_llama_nn:
if not self.use_gemm_pad:
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight=loaded_weight.reshape(param_data.shape[0],-1)
param_data.copy_(loaded_weight)
else:
param_data.copy_(loaded_weight)
if gemm_bank_conf(param.data.shape[0]) and self.use_gemm_pad:
param.data = pad_weight(param.data, 32)
param.data = param.data.transpose(0, 1)
param.data=param.data.reshape(param.data.shape[1],-1)
else:
param_data.copy_(loaded_weight)
def forward(self, input_):
# Set up backprop all-reduce.
......
......@@ -25,8 +25,8 @@ def get_model_architecture(
if architectures == ['LlamaForCausalLM'] or architectures == ['Qwen2ForCausalLM'] or architectures == ['ChatGLMModel'] or architectures == ['BaichuanForCausalLM']:
if os.getenv('LLAMA_NN') != '0':
os.environ['LLAMA_NN'] = '1'
if os.getenv('GEMM_PAD') != '0':
os.environ['GEMM_PAD'] = '1'
if os.getenv('GEMM_PAD') != '1':
os.environ['GEMM_PAD'] = '0'
if os.getenv('FA_PAD') != '1':
os.environ['FA_PAD'] = '0'
else:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment