Commit caf953b6 authored by zhuwenwen
Browse files

update chatglm nn layout

parent 66a7ebd8
......@@ -548,6 +548,15 @@ class ChatGLMModel(nn.Module):
self.make_empty_intermediate_tensors = (
self.encoder.make_empty_intermediate_tensors)
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
def _parse_and_validate_image_input(
self, **kwargs: object) -> GLMImagePixelInputs:
......@@ -732,14 +741,6 @@ class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP):
self.logits_processor = LogitsProcessor(config.padded_vocab_size)
self.sampler = get_sampler()
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
def forward(self,
input_ids: torch.Tensor,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment