Commit e5d707db authored by zhuwenwen's avatar zhuwenwen
Browse files

support tn/nn

parent fc92ed40
...@@ -22,7 +22,8 @@ def set_default_torch_dtype(dtype: torch.dtype): ...@@ -22,7 +22,8 @@ def set_default_torch_dtype(dtype: torch.dtype):
def get_model_architecture( def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", []) architectures = getattr(model_config.hf_config, "architectures", [])
if architectures == ['LlamaForCausalLM']: if architectures == ['LlamaForCausalLM'] or architectures == ['ChatGLMModel'] or architectures == ['BaichuanForCausalLM']:
if os.getenv('LLAMA_NN') != '0':
os.environ['LLAMA_NN'] = '1' os.environ['LLAMA_NN'] = '1'
# Special handling for quantized Mixtral. # Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack. # FIXME(woosuk): This is a temporary hack.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment